3 changes: 1 addition & 2 deletions pyproject.toml
@@ -29,8 +29,7 @@ dependencies = [
     "httpx>=0.28.0",
    "httpcore>=1.0.9",  # Required for Python 3.14 compatibility

-    # Token counting (OpenAI models)
-    "tiktoken>=0.12.0",


     # Core functionality
     "pydantic>=2.12.0",
44 changes: 5 additions & 39 deletions src/gac/ai_utils.py
@@ -8,14 +8,11 @@
 import os
 import time
 from collections.abc import Callable
-from functools import lru_cache
 from typing import Any, cast

-import tiktoken
 from rich.console import Console
 from rich.status import Status

-from gac.constants import EnvDefaults, Utility
 from gac.errors import AIError
 from gac.oauth import QwenOAuthProvider, refresh_token_if_expired
 from gac.oauth.token_store import TokenStore
@@ -25,29 +22,16 @@
 console = Console()


-@lru_cache(maxsize=1)
-def _should_skip_tiktoken_counting() -> bool:
-    """Return True when token counting should avoid tiktoken calls entirely."""
-    value = os.getenv("GAC_NO_TIKTOKEN", str(EnvDefaults.NO_TIKTOKEN))
-    return value.lower() in ("true", "1", "yes", "on")
-
-
 def count_tokens(content: str | list[dict[str, str]] | dict[str, Any], model: str) -> int:
-    """Count tokens in content using the model's tokenizer."""
+    """Count tokens in content using character-based estimation (1 token per 3.4 characters)."""
     text = extract_text_content(content)
     if not text:
         return 0

-    if _should_skip_tiktoken_counting():
-        return len(text) // 4
-
-    try:
-        encoding = get_encoding(model)
-        return len(encoding.encode(text))
-    except (KeyError, UnicodeError, ValueError) as e:
-        logger.error(f"Error counting tokens: {e}")
-        # Fallback to rough estimation (4 chars per token on average)
-        return len(text) // 4
+    # Use simple character-based estimation: 1 token per 3.4 characters (rounded)
+    result = round(len(text) / 3.4)
+    # Ensure at least 1 token for non-empty text
+    return result if result > 0 else 1


 def extract_text_content(content: str | list[dict[str, str]] | dict[str, Any]) -> str:
@@ -61,24 +45,6 @@ def extract_text_content(content: str | list[dict[str, str]] | dict[str, Any]) -> str:
     return ""


-@lru_cache(maxsize=1)
-def get_encoding(model: str) -> tiktoken.Encoding:
-    """Get the appropriate encoding for a given model."""
-    provider, model_name = model.split(":", 1) if ":" in model else (None, model)
-
-    if provider != "openai":
-        return tiktoken.get_encoding(Utility.DEFAULT_ENCODING)
-
-    try:
-        return tiktoken.encoding_for_model(model_name)
-    except KeyError:
-        # Fall back to default encoding if model not found
-        return tiktoken.get_encoding(Utility.DEFAULT_ENCODING)
-    except (OSError, ConnectionError):
-        # If there are any network/SSL issues, fall back to default encoding
-        return tiktoken.get_encoding(Utility.DEFAULT_ENCODING)
-
-
 def generate_with_retries(
     provider_funcs: dict[str, Callable[..., str]],
     model: str,
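With tiktoken gone, `count_tokens` reduces to a pure function of text length, identical for every provider. A minimal sketch of the heuristic the new code implements — `estimate_tokens` is a hypothetical name used here for illustration; the real `count_tokens` additionally flattens message lists and dicts via `extract_text_content`:

```python
# Sketch only: mirrors the estimator added in this diff, not the actual API.
def estimate_tokens(text: str) -> int:
    if not text:
        return 0
    # 1 token per 3.4 characters, rounded, with a floor of 1 for non-empty text
    return max(1, round(len(text) / 3.4))

print(estimate_tokens(""))             # 0
print(estimate_tokens("a"))            # round(1 / 3.4) == 0, floored to 1
print(estimate_tokens("Hello world"))  # round(11 / 3.4) == 3
```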
2 changes: 0 additions & 2 deletions src/gac/config.py
@@ -27,7 +27,6 @@ class GACConfig(TypedDict, total=False):
     warning_limit_tokens: int
     always_include_scope: bool
     skip_secret_scan: bool
-    no_tiktoken: bool
     no_verify_ssl: bool
     verbose: bool
     system_prompt_path: str | None
@@ -110,7 +109,6 @@ def load_config() -> GACConfig:
         in ("true", "1", "yes", "on"),
         "skip_secret_scan": os.getenv("GAC_SKIP_SECRET_SCAN", str(EnvDefaults.SKIP_SECRET_SCAN)).lower()
         in ("true", "1", "yes", "on"),
-        "no_tiktoken": os.getenv("GAC_NO_TIKTOKEN", str(EnvDefaults.NO_TIKTOKEN)).lower() in ("true", "1", "yes", "on"),
         "no_verify_ssl": os.getenv("GAC_NO_VERIFY_SSL", str(EnvDefaults.NO_VERIFY_SSL)).lower()
         in ("true", "1", "yes", "on"),
         "verbose": os.getenv("GAC_VERBOSE", str(EnvDefaults.VERBOSE)).lower() in ("true", "1", "yes", "on"),
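Worth noting while here: `load_config` repeats the `.lower() in ("true", "1", "yes", "on")` idiom for every boolean flag. A hypothetical helper (not part of this PR) showing the pattern each of these entries expands to:

```python
import os

# Hypothetical helper, for illustration only — this does not exist in gac.
def _env_bool(name: str, default: bool) -> bool:
    """Read an env var as a truthy string, matching the loader's idiom."""
    return os.getenv(name, str(default)).lower() in ("true", "1", "yes", "on")

# The removed flag would have been read as:
# no_tiktoken = _env_bool("GAC_NO_TIKTOKEN", False)
```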
2 changes: 0 additions & 2 deletions src/gac/constants/defaults.py
@@ -13,7 +13,6 @@ class EnvDefaults:
     ALWAYS_INCLUDE_SCOPE: bool = False
     SKIP_SECRET_SCAN: bool = False
     VERBOSE: bool = False
-    NO_TIKTOKEN: bool = False
     NO_VERIFY_SSL: bool = False  # Skip SSL certificate verification (for corporate proxies)
     HOOK_TIMEOUT: int = 120  # Timeout for pre-commit and lefthook hooks in seconds

@@ -34,7 +33,6 @@ class Logging:
 class Utility:
     """General utility constants."""

-    DEFAULT_ENCODING: str = "cl100k_base"  # llm encoding
     DEFAULT_DIFF_TOKEN_LIMIT: int = 15000  # Maximum tokens for diff processing
     MAX_WORKERS: int = os.cpu_count() or 4  # Maximum number of parallel workers
     MAX_DISPLAYED_SECRET_LENGTH: int = 50  # Maximum length for displaying secrets
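Since limits like `DEFAULT_DIFF_TOKEN_LIMIT` stay expressed in tokens, the new estimator gives them a direct character interpretation. Assuming the 3.4 chars-per-token ratio from ai_utils.py above, the arithmetic works out to roughly a 51,000-character diff budget:

```python
# Illustrative arithmetic only — the values come from this diff, the
# character-budget framing is ours.
DEFAULT_DIFF_TOKEN_LIMIT = 15000  # from gac.constants.Utility
CHARS_PER_TOKEN = 3.4             # the estimator's divisor

approx_char_budget = int(DEFAULT_DIFF_TOKEN_LIMIT * CHARS_PER_TOKEN)
print(approx_char_budget)  # 51000 characters of diff, give or take
```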
141 changes: 66 additions & 75 deletions tests/test_ai.py
@@ -3,13 +3,11 @@
 from unittest.mock import MagicMock, patch

 import pytest
-import tiktoken

 from gac.ai import generate_commit_message, generate_grouped_commits
 from gac.ai_utils import (
     count_tokens,
     extract_text_content,
-    get_encoding,
 )
 from gac.errors import AIError
 from gac.providers import PROVIDER_REGISTRY, SUPPORTED_PROVIDERS
@@ -34,46 +32,40 @@ def test_extract_text_content(self):
         # Test empty input
         assert extract_text_content({}) == ""

-    def test_get_encoding_known_model(self):
-        """Test getting encoding for known models with optimized mocking."""
-        # Create a mock encoding to avoid slow tiktoken loading
-        mock_encoding = MagicMock(spec=tiktoken.Encoding)
-        mock_encoding.name = "cl100k_base"
-        mock_encoding.encode.return_value = [9906, 1917]  # Tokens for "Hello world"
-        mock_encoding.decode.return_value = "Hello world"
-
-        with patch("tiktoken.encoding_for_model", return_value=mock_encoding):
-            # Test with a well-known OpenAI model that should map to cl100k_base
-            encoding = get_encoding("openai:gpt-4")
-            assert isinstance(encoding, tiktoken.Encoding)
-            assert encoding.name == "cl100k_base"
-
-            # Verify encoding behavior
-            tokens = encoding.encode("Hello world")
-            assert len(tokens) > 0
-            assert isinstance(tokens[0], int)
-
-            # Decode should round-trip correctly
-            decoded = encoding.decode(tokens)
-            assert decoded == "Hello world"
+    def test_character_based_counting_simple(self):
+        """Test simple character-based counting without external dependencies."""
+        # Test basic functionality
+        text = "Hello world"
+        result = count_tokens(text, "any:model")
+        expected = round(len(text) / 3.4)
+        assert result == expected
+
+        # Test with empty string
+        assert count_tokens("", "any:model") == 0
+
+        # Test with single character
+        assert count_tokens("a", "any:model") == 1

     def test_count_tokens(self):
         """Test token counting functionality."""
         # Test with string content
         text = "Hello, world!"
         token_count = count_tokens(text, "openai:gpt-4")
-        assert token_count > 0
+        expected = round(len(text) / 3.4)
+        assert token_count == expected
         assert isinstance(token_count, int)

-    @patch("gac.ai_utils.count_tokens")
-    def test_count_tokens_anthropic_mock(self, mock_count_tokens):
-        """Test that anthropic models are handled correctly."""
-        # This tests the code path, not the actual implementation
-        mock_count_tokens.return_value = 5
+    def test_count_tokens_all_models_same(self):
+        """Test that all models work the same with character-based counting."""
+        text = "Test message"
+        expected = round(len(text) / 3.4)

-        # Test that anthropic model strings are recognized
-        model = "anthropic:claude-3-haiku"
-        assert model.startswith("anthropic")
+        # Test that all providers give same result
+        models = ["anthropic:claude-3-haiku", "openai:gpt-4", "groq:llama3", "gemini:gemini-pro"]
+
+        for model in models:
+            result = count_tokens(text, model)
+            assert result == expected, f"Model {model} should give {expected}, got {result}"

     def test_count_tokens_empty_content(self):
         """Test token counting with empty content."""
@@ -84,55 +76,54 @@ def test_count_tokens_empty_content(self):
         # Test with list of messages
         messages = [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
         token_count = count_tokens(messages, "openai:gpt-4")
-        assert token_count > 0
+        expected = round(len("Hello\nHi there!") / 3.4)
+        assert token_count == expected

         # Test with dict content
         message = {"role": "user", "content": "Test message"}
         token_count = count_tokens(message, "openai:gpt-4")
-        assert token_count > 0
-
-    def test_get_encoding_unknown_model(self):
-        """Test getting encoding for unknown models falls back to default."""
-        # Create a mock default encoding to avoid slow tiktoken loading
-        mock_encoding = MagicMock(spec=tiktoken.Encoding)
-        mock_encoding.name = "cl100k_base"
-
-        with patch("tiktoken.get_encoding", return_value=mock_encoding):
-            # Clear the cache first to ensure fresh test
-            get_encoding.cache_clear()
-
-            # Test with unknown model should fall back to default encoding
-            encoding = get_encoding("unknown:model-xyz")
-            assert isinstance(encoding, tiktoken.Encoding)
-            # Should use the default cl100k_base encoding
-            assert encoding.name == "cl100k_base"
-
-    def test_count_tokens_error_handling(self):
-        """Test error handling in count_tokens function."""
-        # Test with a model that will cause encoding error
-        with patch("gac.ai_utils.get_encoding") as mock_encoding:
-            mock_encoding.side_effect = ValueError("Encoding error")
-
-            # Should fall back to character-based estimation (len/4)
-            token_count = count_tokens("Hello world", "test:model")
-            assert token_count == len("Hello world") // 4
+        expected = round(len("Test message") / 3.4)
+        assert token_count == expected
+
+    def test_character_based_all_providers_same(self):
+        """Test that character-based counting works the same for all providers."""
+        text = "Sample test message"
+        expected = round(len(text) / 3.4)
+
+        providers = ["openai:gpt-4", "anthropic:claude-3", "groq:llama3-70b", "gemini:gemini-pro"]
+
+        for provider in providers:
+            result = count_tokens(text, provider)
+            assert result == expected, f"Provider {provider} should give {expected}, got {result}"
+
+    def test_character_based_no_errors(self):
+        """Test that character-based counting never raises errors."""
+        # Various inputs that should always work
+        test_cases = [
+            "",
+            "Simple text",
+            "Unicode: café résumé",
+            "Emoji: 🎉🚀",
+            "New\nline\tand tabs",
+        ]
+
+        for text in test_cases:
+            result = count_tokens(text, "any:model")
+            assert isinstance(result, int)
+            assert result >= 0

     def test_count_tokens_with_various_content_types(self):
         """Test count_tokens with different content formats."""
-        # Mock encoding to avoid slow tiktoken loading
-        mock_encoding = MagicMock(spec=tiktoken.Encoding)
-        mock_encoding.encode.return_value = [1, 2, 3, 4, 5]  # Mock tokens
-
-        with patch("gac.ai_utils.get_encoding", return_value=mock_encoding):
-            # Test with list containing invalid items
-            messages = [
-                {"role": "user", "content": "Valid message"},
-                {"role": "assistant"},  # Missing content
-                "invalid",  # Not a dict
-                {"content": "No role"},  # Has content
-            ]
-            token_count = count_tokens(messages, "openai:gpt-4")
-            assert token_count == 5  # Should return mock token count
+        # Test with list containing various items
+        messages = [
+            {"role": "user", "content": "Valid message"},
+            {"role": "assistant"},  # Missing content
+            "invalid",  # Not a dict
+            {"content": "No role"},  # Has content
+        ]
+        token_count = count_tokens(messages, "openai:gpt-4")
+        expected = round(len("Valid message\nNo role") / 3.4)
+        assert token_count == expected


 class TestGenerateCommitMessage:
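The updated assertions (e.g. `round(len("Hello\nHi there!") / 3.4)` and `round(len("Valid message\nNo role") / 3.4)`) implicitly pin down how `extract_text_content` flattens message lists: dict items with a `content` key are kept and joined with newlines, everything else is dropped. A sketch of that assumed behavior, not the actual implementation in gac.ai_utils:

```python
from typing import Any

# Sketch of the flattening the tests imply; names and structure are ours.
def extract_text_content_sketch(content: str | list[dict[str, str]] | dict[str, Any]) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        return str(content.get("content", ""))
    if isinstance(content, list):
        # Keep only dict items carrying a "content" key, join with newlines
        return "\n".join(str(m["content"]) for m in content if isinstance(m, dict) and "content" in m)
    return ""

messages = [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there!"}]
assert extract_text_content_sketch(messages) == "Hello\nHi there!"
```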