diff --git a/.gitignore b/.gitignore index d42c6b8..1a75e86 100644 --- a/.gitignore +++ b/.gitignore @@ -54,7 +54,6 @@ coverage.xml *.pot # Django stuff: -*.log local_settings.py db.sqlite3 db.sqlite3-journal @@ -124,4 +123,8 @@ test_proxy.py start_proxy.bat key_usage.json staged_changes.txt +launcher_config.json +cache/antigravity/thought_signatures.json logs/ +cache/ +*.env diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index bd4c6c1..b5a9493 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -57,6 +57,7 @@ client = RotatingClient( - `whitelist_models` (`Optional[Dict[str, List[str]]]`, default: `None`): Whitelist of models to always include, overriding `ignore_models`. - `enable_request_logging` (`bool`, default: `False`): If `True`, enables detailed per-request file logging. - `max_concurrent_requests_per_key` (`Optional[Dict[str, int]]`, default: `None`): Max concurrent requests allowed for a single API key per provider. +- `rotation_tolerance` (`float`, default: `3.0`): Controls the credential rotation strategy. See Section 2.2 for details. #### Core Responsibilities @@ -110,8 +111,16 @@ The `acquire_key` method uses a sophisticated strategy to balance load: 2. **Tiering**: Valid keys are split into two tiers: * **Tier 1 (Ideal)**: Keys that are completely idle (0 concurrent requests). * **Tier 2 (Acceptable)**: Keys that are busy but still under their configured `MAX_CONCURRENT_REQUESTS_PER_KEY_` limit for the requested model. This allows a single key to be used multiple times for the same model, maximizing throughput. -3. **Prioritization**: Within each tier, keys with the **lowest daily usage** are prioritized to spread costs evenly. +3. **Selection Strategy** (configurable via `rotation_tolerance`): + * **Deterministic (tolerance=0.0)**: Within each tier, keys are sorted by daily usage count and the least-used key is always selected. This provides perfect load balance but predictable patterns. + * **Weighted Random (tolerance>0, default)**: Keys are selected randomly with weights biased toward less-used ones: + - Formula: `weight = (max_usage - credential_usage) + tolerance + 1` + - `tolerance=2.0` (recommended): Balanced randomness - credentials within 2 uses of the maximum can still be selected with reasonable probability + - `tolerance=5.0+`: High randomness - even heavily-used credentials have significant probability + - **Security Benefit**: Unpredictable selection patterns make rate limit detection and fingerprinting harder + - **Load Balance**: Lower-usage credentials still preferred, maintaining reasonable distribution 4. **Concurrency Limits**: Checks against `max_concurrent` limits to prevent overloading a single key. +5. **Priority Groups**: When credential prioritization is enabled, higher-tier credentials (lower priority numbers) are tried first before moving to lower tiers. #### Failure Handling & Cooldowns @@ -313,6 +322,294 @@ The `CooldownManager` handles IP or account-level rate limiting that affects all - If so, `CooldownManager.start_cooldown()` is called for the entire provider - All subsequent `acquire_key()` calls for that provider will wait until the cooldown expires + +### 2.10. Credential Prioritization System (`client.py` & `usage_manager.py`) + +The library now includes an intelligent credential prioritization system that automatically detects credential tiers and ensures optimal credential selection for each request. 
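+
+The sketch below shows, in simplified form, how the priority grouping described in this section (see "Usage Manager Integration" below) composes with the tier and weighted-random selection from Section 2.2. It is an illustration only: the function name, its arguments, and the `available` set are invented for the example and do not mirror the real `UsageManager` internals.
+
+```python
+import random
+from typing import Dict, Set
+
+def pick_credential(
+    usage: Dict[str, int],      # daily usage count per credential
+    priority: Dict[str, int],   # priority level per credential (1 = highest)
+    available: Set[str],        # credentials not cooling down / over their concurrency limit
+    tolerance: float = 0.0,
+) -> str:
+    # Try priority groups in ascending order; only fall through to the next
+    # group when every credential in the current group is unavailable.
+    for level in sorted(set(priority.values())):
+        group = [c for c, p in priority.items() if p == level and c in available]
+        if not group:
+            continue
+        if tolerance == 0.0:
+            # Deterministic: always the least-used credential in the group.
+            return min(group, key=lambda c: usage.get(c, 0))
+        # Weighted random: weight = (max_usage - credential_usage) + tolerance + 1
+        max_usage = max(usage.get(c, 0) for c in group)
+        weights = [(max_usage - usage.get(c, 0)) + tolerance + 1 for c in group]
+        return random.choices(group, weights=weights, k=1)[0]
+    raise RuntimeError("No credential available for this model")
+```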
+ +**Key Concepts:** + +- **Provider-Level Priorities**: Providers can implement `get_credential_priority()` to return a priority level (1=highest, 10=lowest) for each credential +- **Model-Level Requirements**: Providers can implement `get_model_tier_requirement()` to specify minimum priority required for specific models +- **Automatic Filtering**: The client automatically filters out incompatible credentials before making requests +- **Priority-Aware Selection**: The `UsageManager` prioritizes higher-tier credentials (lower numbers) within the same priority group + +**Implementation Example (Gemini CLI):** + +```python +def get_credential_priority(self, credential: str) -> Optional[int]: + """Returns priority based on Gemini tier.""" + tier = self.project_tier_cache.get(credential) + if not tier: + return None # Not yet discovered + + # Paid tiers get highest priority + if tier not in ['free-tier', 'legacy-tier', 'unknown']: + return 1 + + # Free tier gets lower priority + if tier == 'free-tier': + return 2 + + return 10 + +def get_model_tier_requirement(self, model: str) -> Optional[int]: + """Returns minimum priority required for model.""" + if model.startswith("gemini-3-"): + return 1 # Only paid tier (priority 1) credentials + + return None # All other models have no restrictions +``` + +**Usage Manager Integration:** + +The `acquire_key()` method has been enhanced to: +1. Group credentials by priority level +2. Try highest priority group first (priority 1, then 2, etc.) +3. Within each group, use existing tier1/tier2 logic (idle keys first, then busy keys) +4. Load balance within priority groups by usage count +5. Only move to next priority if all higher-priority credentials are exhausted + +**Benefits:** + +- Ensures paid-tier credentials are always used for premium models +- Prevents failed requests due to tier restrictions +- Optimal cost distribution (free tier used when possible, paid when required) +- Graceful fallback if primary credentials are unavailable + +--- + +### 2.11. Provider Cache System (`providers/provider_cache.py`) + +A modular, shared caching system for providers to persist conversation state across requests. + +**Architecture:** + +- **Dual-TTL Design**: Short-lived memory cache (default: 1 hour) + longer-lived disk persistence (default: 24 hours) +- **Background Persistence**: Batched disk writes every 60 seconds (configurable) +- **Automatic Cleanup**: Background task removes expired entries from memory cache + +### 3.5. Antigravity (`antigravity_provider.py`) + +The most sophisticated provider implementation, supporting Google's internal Antigravity API for Gemini and Claude models. 
+ +#### Architecture + +- **Unified Streaming/Non-Streaming**: Single code path handles both response types with optimal transformations +- **Thought Signature Caching**: Server-side caching of encrypted signatures for multi-turn Gemini 3 conversations +- **Model-Specific Logic**: Automatic configuration based on model type (Gemini 2.5, Gemini 3, Claude) + +#### Model Support + +**Gemini 2.5 (Pro/Flash):** +- Uses `thinkingBudget` parameter (integer tokens: -1 for auto, 0 to disable, or specific value) +- Standard safety settings and toolConfig +- Stream processing with thinking content separation + +**Gemini 3 (Pro/Image):** +- Uses `thinkingLevel` parameter (string: "low" or "high") +- **Tool Hallucination Prevention**: + - Automatic system instruction injection explaining custom tool schema rules + - Parameter signature injection into tool descriptions (e.g., "STRICT PARAMETERS: files (ARRAY_OF_OBJECTS[path: string REQUIRED, ...])") + - Namespace prefix for tool names (`gemini3_` prefix) to avoid training data conflicts + - Malformed JSON auto-correction (handles extra trailing braces) +- **ThoughtSignature Management**: + - Caching signatures from responses for reuse in follow-up messages + - Automatic injection into functionCalls for multi-turn conversations + - Fallback to bypass value if signature unavailable + +**Claude Sonnet 4.5:** +- Proxied through Antigravity API (uses internal model name `claude-sonnet-4-5-thinking`) +- Uses `thinkingBudget` parameter like Gemini 2.5 +- **Thinking Preservation**: Caches thinking content using composite keys (tool_call_id + text_hash) +- **Schema Cleaning**: Removes unsupported properties (`$schema`, `additionalProperties`, `const` → `enum`) + +#### Base URL Fallback + +Automatic fallback chain for resilience: +1. `daily-cloudcode-pa.sandbox.googleapis.com` (primary sandbox) +2. `autopush-cloudcode-pa.sandbox.googleapis.com` (fallback sandbox) +3. `cloudcode-pa.googleapis.com` (production fallback) + +#### Message Transformation + +**OpenAI → Gemini Format:** +- System messages → `systemInstruction` with parts array +- Multi-part content (text + images) → `inlineData` format +- Tool calls → `functionCall` with args and id +- Tool responses → `functionResponse` with name and response +- ThoughtSignatures preserved/injected as needed + +**Tool Response Grouping:** +- Converts linear format (call, response, call, response) to grouped format +- Groups all function calls in one `model` message +- Groups all responses in one `user` message +- Required for Antigravity API compatibility + +#### Configuration (Environment Variables) + +```env +# Cache control +ANTIGRAVITY_SIGNATURE_CACHE_TTL=3600 # Memory cache TTL +ANTIGRAVITY_SIGNATURE_DISK_TTL=86400 # Disk cache TTL +ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true + +# Feature flags +ANTIGRAVITY_PRESERVE_THOUGHT_SIGNATURES=true # Include signatures in client responses +ANTIGRAVITY_ENABLE_DYNAMIC_MODELS=false # Use API model discovery +ANTIGRAVITY_GEMINI3_TOOL_FIX=true # Enable Gemini 3 hallucination prevention +ANTIGRAVITY_CLAUDE_THINKING_SANITIZATION=true # Enable Claude thinking mode auto-correction + +# Gemini 3 tool fix customization +ANTIGRAVITY_GEMINI3_TOOL_PREFIX="gemini3_" # Namespace prefix +ANTIGRAVITY_GEMINI3_DESCRIPTION_PROMPT="\n\nSTRICT PARAMETERS: {params}." +ANTIGRAVITY_GEMINI3_SYSTEM_INSTRUCTION="..." 
# Full system prompt +``` + +#### Claude Extended Thinking Sanitization + +The provider includes automatic sanitization for Claude's extended thinking mode, handling common error scenarios: + +**Problem**: Claude's extended thinking API requires strict consistency in thinking blocks: +- If thinking is enabled, the final assistant turn must start with a thinking block +- If thinking is disabled, no thinking blocks can be present in the final turn +- Tool use loops are part of a single "assistant turn" +- You **cannot** toggle thinking mode mid-turn (this is invalid per Claude API) + +**Scenarios Handled**: + +| Scenario | Action | +|----------|--------| +| Tool loop WITH thinking + thinking enabled | Preserve thinking, continue normally | +| Tool loop WITHOUT thinking + thinking enabled | **Inject synthetic closure** to start fresh turn with thinking | +| Thinking disabled | Strip all thinking blocks | +| Normal conversation (no tool loop) | Strip old thinking, new response adds thinking naturally | + +**Solution**: The `_sanitize_thinking_for_claude()` method: +- Analyzes conversation state to detect incomplete tool use loops +- When enabling thinking in a tool loop that started without thinking: + - Injects a minimal synthetic assistant message: `"[Tool execution completed. Processing results.]"` + - This **closes** the previous turn, allowing Claude to start a **fresh turn with thinking** +- Strips thinking from old turns (Claude API ignores them anyway) +- Preserves thinking when the turn was started with thinking enabled + +**Key Insight**: Instead of force-disabling thinking, we close the tool loop with a synthetic message. This allows seamless model switching (e.g., Gemini → Claude with thinking) without losing the ability to think. + +**Example**: +``` +Before sanitization: + User: "What's the weather?" + Assistant: [tool_use: get_weather] ← Made by Gemini (no thinking) + User: [tool_result: "20C sunny"] + +After sanitization (thinking enabled): + User: "What's the weather?" + Assistant: [tool_use: get_weather] + User: [tool_result: "20C sunny"] + Assistant: "[Tool execution completed. Processing results.]" ← INJECTED + + → Claude now starts a NEW turn and CAN think! +``` + +**Configuration**: +```env +ANTIGRAVITY_CLAUDE_THINKING_SANITIZATION=true # Enable/disable auto-correction +``` + +#### File Logging + +Optional transaction logging for debugging: +- Enabled via `enable_request_logging` parameter +- Creates `logs/antigravity_logs/TIMESTAMP_MODEL_UUID/` directory per request +- Logs: `request_payload.json`, `response_stream.log`, `final_response.json`, `error.log` + +--- + + +- **Atomic Disk Writes**: Uses temp-file-and-move pattern to prevent corruption + +**Key Methods:** + +1. **`store(key, value)`**: Synchronously queues value for storage (schedules async write) +2. **`retrieve(key)`**: Synchronously retrieves from memory, optionally schedules disk fallback +3. **`store_async(key, value)`**: Awaitable storage for guaranteed persistence +4. 
**`retrieve_async(key)`**: Awaitable retrieval with disk fallback + +**Use Cases:** + +- **Gemini 3 ThoughtSignatures**: Caching tool call signatures for multi-turn conversations +- **Claude Thinking**: Preserving thinking content for consistency across conversation turns +- **Any Transient State**: Generic key-value storage for provider-specific needs + +**Configuration (Environment Variables):** + +```env +# Cache control (prefix can be customized per cache instance) +PROVIDER_CACHE_ENABLE=true +PROVIDER_CACHE_WRITE_INTERVAL=60 # seconds between disk writes +PROVIDER_CACHE_CLEANUP_INTERVAL=1800 # 30 min between cleanups + +# Gemini 3 specific +GEMINI_CLI_SIGNATURE_CACHE_ENABLE=true +GEMINI_CLI_SIGNATURE_CACHE_TTL=3600 # 1 hour memory TTL +GEMINI_CLI_SIGNATURE_DISK_TTL=86400 # 24 hours disk TTL +``` + +**File Structure:** + +``` +cache/ +├── gemini_cli/ +│ └── gemini3_signatures.json +└── antigravity/ + ├── gemini3_signatures.json + └── claude_thinking.json +``` + +--- + +### 2.12. Google OAuth Base (`providers/google_oauth_base.py`) + +A refactored, reusable OAuth2 base class that eliminates code duplication across Google-based providers. + +**Refactoring Benefits:** + +- **Single Source of Truth**: All OAuth logic centralized in one class +- **Easy Provider Addition**: New providers only need to override constants +- **Consistent Behavior**: Token refresh, expiry handling, and validation work identically across providers +- **Maintainability**: OAuth bugs fixed once apply to all inheriting providers + +**Provider Implementation:** + +```python +class AntigravityAuthBase(GoogleOAuthBase): + # Required overrides + CLIENT_ID = "antigravity-client-id" + CLIENT_SECRET = "antigravity-secret" + OAUTH_SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/cclog", # Antigravity-specific + "https://www.googleapis.com/auth/experimentsandconfigs", + ] + ENV_PREFIX = "ANTIGRAVITY" # Used for env var loading + + # Optional overrides (defaults provided) + CALLBACK_PORT = 51121 + CALLBACK_PATH = "/oauthcallback" +``` + +**Inherited Features:** + +- Automatic token refresh with exponential backoff +- Invalid grant re-authentication flow +- Stateless deployment support (env var loading) +- Atomic credential file writes +- Headless environment detection +- Sequential refresh queue processing + +--- + + --- ## 3. Provider Specific Implementations diff --git a/Deployment guide.md b/Deployment guide.md index 1d31c14..57acd53 100644 --- a/Deployment guide.md +++ b/Deployment guide.md @@ -79,6 +79,37 @@ If you are using providers that require complex OAuth files (like **Gemini CLI** 4. Copy the contents of this file and paste them directly into your `.env` file or Render's "Environment Variables" section. 5. The proxy will automatically detect and use these variables—no file upload required! + +### Advanced: Antigravity OAuth Provider + +The Antigravity provider requires OAuth2 authentication similar to Gemini CLI. It provides access to: +- Gemini 2.5 models (Pro/Flash) +- Gemini 3 models (Pro/Image-preview) - **requires paid-tier Google Cloud project** +- Claude Sonnet 4.5 via Google's Antigravity proxy + +**Setting up Antigravity locally:** +1. Run the credential tool: `python -m rotator_library.credential_tool` +2. Select "Add OAuth Credential" and choose "Antigravity" +3. Complete the OAuth flow in your browser +4. The credential is saved to `oauth_creds/antigravity_oauth_1.json` + +**Exporting for stateless deployment:** +1. 
Run: `python -m rotator_library.credential_tool` +2. Select "Export Antigravity to .env" +3. Copy the generated environment variables to your deployment platform: + ```env + ANTIGRAVITY_ACCESS_TOKEN="..." + ANTIGRAVITY_REFRESH_TOKEN="..." + ANTIGRAVITY_EXPIRY_DATE="..." + ANTIGRAVITY_EMAIL="your-email@gmail.com" + ``` + +**Important Notes:** +- Antigravity uses Google OAuth with additional scopes for cloud platform access +- Gemini 3 models require a paid-tier Google Cloud project (free tier will fail) +- The provider automatically handles thought signature caching for multi-turn conversations +- Tool hallucination prevention is enabled by default for Gemini 3 models + 4. Save the file. (We'll upload it to Render in Step 5.) diff --git a/README.md b/README.md index 6129d11..51399bd 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,19 @@ This project provides a powerful solution for developers building complex applic - **Provider Agnostic**: Compatible with any provider supported by `litellm`. - **OpenAI-Compatible Proxy**: Offers a familiar API interface with additional endpoints for model and provider discovery. - **Advanced Model Filtering**: Supports both blacklists and whitelists to give you fine-grained control over which models are available through the proxy. + +- **🆕 Antigravity Provider**: Full support for Google's internal Antigravity API, providing access to Gemini 2.5, Gemini 3, and Claude Sonnet 4.5 models with advanced features: + - Thought signature caching for multi-turn conversations + - Tool hallucination prevention via parameter signature injection + - Automatic thinking block sanitization for Claude models + - Note: Claude Sonnet 4.5 thinking mode requires careful conversation state management (see [Antigravity documentation](DOCUMENTATION.md#antigravity-claude-extended-thinking-sanitization) for details) +- **🆕 Credential Prioritization**: Automatic tier detection and priority-based credential selection ensures paid-tier credentials are used for premium models that require them. +- **🆕 Weighted Random Rotation**: Configurable credential rotation strategy - choose between deterministic (perfect balance) or weighted random (unpredictable, harder to fingerprint) selection. +- **🆕 Enhanced Gemini CLI**: Improved project discovery, paid vs free tier detection, and Gemini 3 support with thoughtSignature caching. +- **🆕 Temperature Override**: Global temperature=0 override option to prevent tool hallucination issues with low-temperature settings. +- **🆕 Provider Cache System**: Modular caching system for preserving conversation state (thought signatures, thinking content) across requests. +- **🆕 Refactored OAuth Base**: Shared [`GoogleOAuthBase`](src/rotator_library/providers/google_oauth_base.py) class eliminates code duplication across OAuth providers. + - **🆕 Interactive Launcher TUI**: Beautiful, cross-platform TUI for configuration and management with an integrated settings tool for advanced configuration. @@ -234,11 +247,12 @@ python src/proxy_app/main.py **Main Menu Features:** -1. **Add OAuth Credential** - Interactive OAuth flow for Gemini CLI, Qwen Code, and iFlow +1. 
**Add OAuth Credential** - Interactive OAuth flow for Gemini CLI, Antigravity, Qwen Code, and iFlow - Automatically opens your browser for authentication - Handles the entire OAuth flow including callbacks - Saves credentials to the local `oauth_creds/` directory - For Gemini CLI: Automatically discovers or creates a Google Cloud project + - For Antigravity: Similar to Gemini CLI with Antigravity-specific scopes - For Qwen Code: Uses Device Code flow (you'll enter a code in your browser) - For iFlow: Starts a local callback server on port 11451 @@ -488,6 +502,42 @@ The following advanced settings can be added to your `.env` file (or configured - **`SKIP_OAUTH_INIT_CHECK`**: Set to `true` to skip the interactive OAuth setup/validation check on startup. Essential for non-interactive environments like Docker containers or CI/CD pipelines. ```env SKIP_OAUTH_INIT_CHECK=true + + +#### **Antigravity (Advanced - Gemini 3 \Claude 4.5 Access)** +The newest and most sophisticated provider, offering access to cutting-edge models via Google's internal Antigravity API. + +**Supported Models:** +- Gemini 2.5 (Pro/Flash) with `thinkingBudget` parameter +- **Gemini 3 Pro (High/Low)** - Latest preview models +- **Claude Sonnet 4.5 + Thinking** via Antigravity proxy + +**Advanced Features:** +- **Thought Signature Caching**: Preserves encrypted signatures for multi-turn Gemini 3 conversations +- **Tool Hallucination Prevention**: Automatic system instruction and parameter signature injection for Gemini 3 to prevent tools from being called with incorrect parameters +- **Thinking Preservation**: Caches Claude thinking content for consistency across conversation turns +- **Automatic Fallback**: Tries sandbox endpoints before falling back to production +- **Schema Cleaning**: Handles Claude-specific tool schema requirements + +**Configuration:** +- **OAuth Setup**: Uses Google OAuth similar to Gemini CLI (separate scopes) +- **Stateless Deployment**: Full environment variable support +- **Paid Tier Recommended**: Gemini 3 models require a paid Google Cloud project + +**Environment Variables:** +```env +# Stateless deployment +ANTIGRAVITY_ACCESS_TOKEN="..." +ANTIGRAVITY_REFRESH_TOKEN="..." +ANTIGRAVITY_EXPIRY_DATE="..." +ANTIGRAVITY_EMAIL="user@gmail.com" + +# Feature toggles +ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true # Multi-turn conversation support +ANTIGRAVITY_GEMINI3_TOOL_FIX=true # Prevent tool hallucination +``` + + ``` #### Concurrency Control @@ -516,6 +566,71 @@ For providers that support custom model definitions (Qwen Code, iFlow), you can #### Provider-Specific Settings - **`GEMINI_CLI_PROJECT_ID`**: Manually specify a Google Cloud Project ID for Gemini CLI OAuth. Only needed if automatic discovery fails. + + +#### Antigravity Provider + +- **`ANTIGRAVITY_OAUTH_1`**: Path to Antigravity OAuth credential file (auto-discovered from `~/.antigravity/` or use the credential tool). + ```env + ANTIGRAVITY_OAUTH_1="/path/to/your/antigravity_creds.json" + ``` + +- **Stateless Deployment** (Environment Variables): + ```env + ANTIGRAVITY_ACCESS_TOKEN="ya29.your-access-token" + + +#### Credential Rotation Strategy + +- **`ROTATION_TOLERANCE`**: Controls how credentials are selected for requests. Set via environment variable or programmatically. + - `0.0`: **Deterministic** - Always selects the least-used credential for perfect load balance + - `3.0` (default, recommended): **Weighted Random** - Randomly selects with bias toward less-used credentials. 
Provides unpredictability (harder to fingerprint/detect) while maintaining good balance + - `5.0+`: **High Randomness** - Maximum unpredictability, even heavily-used credentials can be selected + + ```env + # For maximum security/unpredictability (recommended for production) + ROTATION_TOLERANCE=3.0 + + # For perfect load balancing (default) + ROTATION_TOLERANCE=0.0 + ``` + + **Why use weighted random?** + - Makes traffic patterns less predictable + - Still maintains good load distribution across keys + - Recommended for production environments with multiple credentials + + + ANTIGRAVITY_REFRESH_TOKEN="1//your-refresh-token" + ANTIGRAVITY_EXPIRY_DATE="1234567890000" + ANTIGRAVITY_EMAIL="your-email@gmail.com" + ``` + +- **`ANTIGRAVITY_ENABLE_SIGNATURE_CACHE`**: Enable/disable thought signature caching for Gemini 3 multi-turn conversations. Default: `true`. + ```env + ANTIGRAVITY_ENABLE_SIGNATURE_CACHE=true + ``` + +- **`ANTIGRAVITY_GEMINI3_TOOL_FIX`**: Enable/disable tool hallucination prevention for Gemini 3 models. Default: `true`. + ```env + ANTIGRAVITY_GEMINI3_TOOL_FIX=true + ``` + +#### Temperature Override (Global) + +- **`OVERRIDE_TEMPERATURE_ZERO`**: Prevents tool hallucination caused by temperature=0 settings. Modes: + - `"remove"`: Deletes temperature=0 from requests (lets provider use default) + - `"set"`: Changes temperature=0 to temperature=1.0 + - `"false"` or unset: Disabled (default) + +#### Credential Prioritization + +- **`GEMINI_CLI_PROJECT_ID`**: Manually specify a Google Cloud Project ID for Gemini CLI OAuth. Auto-discovered unless unexpected failure occurs. + ```env + GEMINI_CLI_PROJECT_ID="your-gcp-project-id" + ``` + + ```env GEMINI_CLI_PROJECT_ID="your-gcp-project-id" ``` diff --git a/src/proxy_app/launcher_tui.py b/src/proxy_app/launcher_tui.py index a14c0ae..26a36bf 100644 --- a/src/proxy_app/launcher_tui.py +++ b/src/proxy_app/launcher_tui.py @@ -100,7 +100,8 @@ def get_all_settings() -> dict: "custom_bases": SettingsDetector.detect_custom_api_bases(), "model_definitions": SettingsDetector.detect_model_definitions(), "concurrency_limits": SettingsDetector.detect_concurrency_limits(), - "model_filters": SettingsDetector.detect_model_filters() + "model_filters": SettingsDetector.detect_model_filters(), + "provider_settings": SettingsDetector.detect_provider_settings() } @staticmethod @@ -198,6 +199,45 @@ def detect_model_filters() -> dict: else: filters[provider]["has_whitelist"] = True return filters + + @staticmethod + def detect_provider_settings() -> dict: + """Detect provider-specific settings (Antigravity, Gemini CLI)""" + try: + from proxy_app.settings_tool import PROVIDER_SETTINGS_MAP + except ImportError: + # Fallback for direct execution or testing + from .settings_tool import PROVIDER_SETTINGS_MAP + + provider_settings = {} + env_vars = SettingsDetector._load_local_env() + + for provider, definitions in PROVIDER_SETTINGS_MAP.items(): + modified_count = 0 + for key, definition in definitions.items(): + env_value = env_vars.get(key) + if env_value is not None: + # Check if value differs from default + default = definition.get("default") + setting_type = definition.get("type", "str") + + try: + if setting_type == "bool": + current = env_value.lower() in ("true", "1", "yes") + elif setting_type == "int": + current = int(env_value) + else: + current = env_value + + if current != default: + modified_count += 1 + except (ValueError, AttributeError): + pass + + if modified_count > 0: + provider_settings[provider] = modified_count + + return provider_settings 
class LauncherTUI: @@ -300,7 +340,8 @@ def show_main_menu(self): self.console.print("━" * 70) provider_count = len(credentials) custom_count = len(custom_bases) - has_advanced = bool(settings["model_definitions"] or settings["concurrency_limits"] or settings["model_filters"]) + provider_settings = settings.get("provider_settings", {}) + has_advanced = bool(settings["model_definitions"] or settings["concurrency_limits"] or settings["model_filters"] or provider_settings) self.console.print(f" Providers: {provider_count} configured") self.console.print(f" Custom Providers: {custom_count} configured") @@ -422,6 +463,7 @@ def show_provider_settings_menu(self): model_defs = settings["model_definitions"] concurrency = settings["concurrency_limits"] filters = settings["model_filters"] + provider_settings = settings.get("provider_settings", {}) self.console.print(Panel.fit( "[bold cyan]📊 Provider & Advanced Settings[/bold cyan]", @@ -472,7 +514,7 @@ def show_provider_settings_menu(self): self.console.print("━" * 70) for provider, limit in concurrency.items(): self.console.print(f" • {provider:15} {limit} requests/key") - self.console.print(f" • Default: 1 request/key (all others)") + self.console.print(" • Default: 1 request/key (all others)") # Model Filters (basic info only) if filters: @@ -488,6 +530,22 @@ def show_provider_settings_menu(self): status = " + ".join(status_parts) if status_parts else "None" self.console.print(f" • {provider:15} ✅ {status}") + # Provider-Specific Settings + self.console.print() + self.console.print("[bold]🔬 Provider-Specific Settings[/bold]") + self.console.print("━" * 70) + try: + from proxy_app.settings_tool import PROVIDER_SETTINGS_MAP + except ImportError: + from .settings_tool import PROVIDER_SETTINGS_MAP + for provider in PROVIDER_SETTINGS_MAP.keys(): + display_name = provider.replace("_", " ").title() + modified = provider_settings.get(provider, 0) + if modified > 0: + self.console.print(f" • {display_name:20} [yellow]{modified} setting{'s' if modified > 1 else ''} modified[/yellow]") + else: + self.console.print(f" • {display_name:20} [dim]using defaults[/dim]") + # Actions self.console.print() self.console.print("━" * 70) diff --git a/src/proxy_app/main.py b/src/proxy_app/main.py index 94f2c38..aa1278d 100644 --- a/src/proxy_app/main.py +++ b/src/proxy_app/main.py @@ -38,6 +38,25 @@ # If we get here, we're ACTUALLY running the proxy - NOW show startup messages and start timer _start_time = time.time() +# Load all .env files from root folder (main .env first, then any additional *.env files) +from dotenv import load_dotenv +from glob import glob + +# Load main .env first +load_dotenv() + +# Load any additional .env files (e.g., antigravity_all_combined.env, gemini_cli_all_combined.env) +_root_dir = Path.cwd() +_env_files_found = list(_root_dir.glob("*.env")) +for _env_file in sorted(_root_dir.glob("*.env")): + if _env_file.name != ".env": # Skip main .env (already loaded) + load_dotenv(_env_file, override=False) # Don't override existing values + +# Log discovered .env files for deployment verification +if _env_files_found: + _env_names = [_ef.name for _ef in _env_files_found] + print(f"📁 Loaded {len(_env_files_found)} .env file(s): {', '.join(_env_names)}") + # Get proxy API key for display proxy_api_key = os.getenv("PROXY_API_KEY") if proxy_api_key: @@ -87,6 +106,7 @@ from rotator_library import RotatingClient from rotator_library.credential_manager import CredentialManager from rotator_library.background_refresher import BackgroundRefresher + from 
rotator_library.model_info_service import init_model_info_service from proxy_app.request_logger import log_request_to_console from proxy_app.batch_manager import EmbeddingBatcher from proxy_app.detailed_logger import DetailedLogger @@ -110,15 +130,59 @@ class EmbeddingRequest(BaseModel): user: Optional[str] = None class ModelCard(BaseModel): + """Basic model card for minimal response.""" id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) owned_by: str = "Mirro-Proxy" +class ModelCapabilities(BaseModel): + """Model capability flags.""" + tool_choice: bool = False + function_calling: bool = False + reasoning: bool = False + vision: bool = False + system_messages: bool = True + prompt_caching: bool = False + assistant_prefill: bool = False + +class EnrichedModelCard(BaseModel): + """Extended model card with pricing and capabilities.""" + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "unknown" + # Pricing (optional - may not be available for all models) + input_cost_per_token: Optional[float] = None + output_cost_per_token: Optional[float] = None + cache_read_input_token_cost: Optional[float] = None + cache_creation_input_token_cost: Optional[float] = None + # Limits (optional) + max_input_tokens: Optional[int] = None + max_output_tokens: Optional[int] = None + context_window: Optional[int] = None + # Capabilities + mode: str = "chat" + supported_modalities: List[str] = Field(default_factory=lambda: ["text"]) + supported_output_modalities: List[str] = Field(default_factory=lambda: ["text"]) + capabilities: Optional[ModelCapabilities] = None + # Debug info (optional) + _sources: Optional[List[str]] = None + _match_type: Optional[str] = None + + class Config: + extra = "allow" # Allow extra fields from the service + class ModelList(BaseModel): + """List of models response.""" object: str = "list" data: List[ModelCard] +class EnrichedModelList(BaseModel): + """List of enriched models with pricing and capabilities.""" + object: str = "list" + data: List[EnrichedModelCard] + # Calculate total loading time _elapsed = time.time() - _start_time print(f"✓ Server ready in {_elapsed:.2f}s ({_plugin_count} providers discovered in {_provider_time:.2f}s)") @@ -294,6 +358,11 @@ async def lifespan(app: FastAPI): if provider not in credentials_to_initialize: credentials_to_initialize[provider] = [] for path in paths: + # Skip env-based credentials (virtual paths) - they don't have metadata files + if path.startswith("env://"): + credentials_to_initialize[provider].append(path) + continue + try: with open(path, 'r') as f: data = json.load(f) @@ -395,19 +464,20 @@ async def process_credential(provider: str, path: str, provider_instance): final_oauth_credentials[provider] = [] final_oauth_credentials[provider].append(path) - # Update metadata - try: - with open(path, 'r+') as f: - data = json.load(f) - metadata = data.get("_proxy_metadata", {}) - metadata["email"] = email - metadata["last_check_timestamp"] = time.time() - data["_proxy_metadata"] = metadata - f.seek(0) - json.dump(data, f, indent=2) - f.truncate() - except Exception as e: - logging.error(f"Failed to update metadata for '{path}': {e}") + # Update metadata (skip for env-based credentials - they don't have files) + if not path.startswith("env://"): + try: + with open(path, 'r+') as f: + data = json.load(f) + metadata = data.get("_proxy_metadata", {}) + metadata["email"] = email + metadata["last_check_timestamp"] = time.time() + 
data["_proxy_metadata"] = metadata + f.seek(0) + json.dump(data, f, indent=2) + f.truncate() + except Exception as e: + logging.error(f"Failed to update metadata for '{path}': {e}") logging.info("OAuth credential processing complete.") oauth_credentials = final_oauth_credentials @@ -428,6 +498,12 @@ async def process_credential(provider: str, path: str, provider_instance): enable_request_logging=ENABLE_REQUEST_LOGGING, max_concurrent_requests_per_key=max_concurrent_requests_per_key ) + + # Log loaded credentials summary (compact, always visible for deployment verification) + _api_summary = ', '.join([f"{p}:{len(c)}" for p, c in api_keys.items()]) if api_keys else "none" + _oauth_summary = ', '.join([f"{p}:{len(c)}" for p, c in oauth_credentials.items()]) if oauth_credentials else "none" + _total_summary = ', '.join([f"{p}:{len(c)}" for p, c in client.all_credentials.items()]) + print(f"🔑 Credentials loaded: {_total_summary} (API: {_api_summary} | OAuth: {_oauth_summary})") client.background_refresher.start() # Start the background task app.state.rotating_client = client @@ -451,6 +527,12 @@ async def process_credential(provider: str, path: str, provider_instance): else: app.state.embedding_batcher = None logging.info("RotatingClient initialized (EmbeddingBatcher disabled).") + + # Start model info service in background (fetches pricing/capabilities data) + # This runs asynchronously and doesn't block proxy startup + model_info_service = await init_model_info_service() + app.state.model_info_service = model_info_service + logging.info("Model info service started (fetching pricing data in background).") yield @@ -459,6 +541,10 @@ async def process_credential(provider: str, path: str, provider_instance): await app.state.embedding_batcher.stop() await client.close() + # Stop model info service + if hasattr(app.state, 'model_info_service') and app.state.model_info_service: + await app.state.model_info_service.stop() + if app.state.embedding_batcher: logging.info("RotatingClient and EmbeddingBatcher closed.") else: @@ -589,7 +675,10 @@ async def streaming_response_wrapper( final_message["function_call"]["arguments"] += value["arguments"] else: # Generic key handling for other data like 'reasoning' - if key not in final_message: + # FIX: Role should always replace, never concatenate + if key == "role": + final_message[key] = value + elif key not in final_message: final_message[key] = value elif isinstance(final_message.get(key), str): final_message[key] += value @@ -605,6 +694,9 @@ async def streaming_response_wrapper( # --- Final Response Construction --- if aggregated_tool_calls: final_message["tool_calls"] = list(aggregated_tool_calls.values()) + # CRITICAL FIX: Override finish_reason when tool_calls exist + # This ensures OpenCode and other agentic systems continue the conversation loop + finish_reason = "tool_calls" # Ensure standard fields are present for consistent logging for field in ["content", "tool_calls", "function_call"]: @@ -652,19 +744,35 @@ async def chat_completions( except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid JSON in request body.") + # Global temperature=0 override (controlled by .env variable, default: OFF) + # Low temperature makes models deterministic and prone to following training data + # instead of actual schemas, which can cause tool hallucination + # Modes: "remove" = delete temperature key, "set" = change to 1.0, "false" = disabled + override_temp_zero = os.getenv("OVERRIDE_TEMPERATURE_ZERO", "false").lower() + + if 
override_temp_zero in ("remove", "set", "true", "1", "yes") and "temperature" in request_data and request_data["temperature"] == 0: + if override_temp_zero == "remove": + # Remove temperature key entirely + del request_data["temperature"] + logging.debug("OVERRIDE_TEMPERATURE_ZERO=remove: Removed temperature=0 from request") + else: + # Set to 1.0 (for "set", "true", "1", "yes") + request_data["temperature"] = 1.0 + logging.debug("OVERRIDE_TEMPERATURE_ZERO=set: Converting temperature=0 to temperature=1.0") + # If logging is enabled, perform all logging operations using the parsed data. if logger: logger.log_request(headers=request.headers, body=request_data) - # Extract and log specific reasoning parameters for monitoring. - model = request_data.get("model") - generation_cfg = request_data.get("generationConfig", {}) or request_data.get("generation_config", {}) or {} - reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get("reasoning_effort") - custom_reasoning_budget = request_data.get("custom_reasoning_budget") or generation_cfg.get("custom_reasoning_budget", False) + # Extract and log specific reasoning parameters for monitoring. + model = request_data.get("model") + generation_cfg = request_data.get("generationConfig", {}) or request_data.get("generation_config", {}) or {} + reasoning_effort = request_data.get("reasoning_effort") or generation_cfg.get("reasoning_effort") + custom_reasoning_budget = request_data.get("custom_reasoning_budget") or generation_cfg.get("custom_reasoning_budget", False) - logging.getLogger("rotator_library").info( - f"Handling reasoning parameters: model={model}, reasoning_effort={reasoning_effort}, custom_reasoning_budget={custom_reasoning_budget}" - ) + logging.getLogger("rotator_library").debug( + f"Handling reasoning parameters: model={model}, reasoning_effort={reasoning_effort}, custom_reasoning_budget={custom_reasoning_budget}" + ) # Log basic request info to console (this is a separate, simpler logger). log_request_to_console( @@ -806,17 +914,73 @@ async def embeddings( def read_root(): return {"Status": "API Key Proxy is running"} -@app.get("/v1/models", response_model=ModelList) +@app.get("/v1/models") async def list_models( + request: Request, client: RotatingClient = Depends(get_rotating_client), - _=Depends(verify_api_key) + _=Depends(verify_api_key), + enriched: bool = True, ): """ Returns a list of available models in the OpenAI-compatible format. + + Query Parameters: + enriched: If True (default), returns detailed model info with pricing and capabilities. + If False, returns minimal OpenAI-compatible response. """ model_ids = await client.get_all_available_models(grouped=False) - model_cards = [ModelCard(id=model_id) for model_id in model_ids] - return ModelList(data=model_cards) + + if enriched and hasattr(request.app.state, 'model_info_service'): + model_info_service = request.app.state.model_info_service + if model_info_service.is_ready: + # Return enriched model data + enriched_data = model_info_service.enrich_model_list(model_ids) + return {"object": "list", "data": enriched_data} + + # Fallback to basic model cards + model_cards = [{"id": model_id, "object": "model", "created": int(time.time()), "owned_by": "Mirro-Proxy"} for model_id in model_ids] + return {"object": "list", "data": model_cards} + + +@app.get("/v1/models/{model_id:path}") +async def get_model( + model_id: str, + request: Request, + _=Depends(verify_api_key), +): + """ + Returns detailed information about a specific model. 
+ + Path Parameters: + model_id: The model ID (e.g., "anthropic/claude-3-opus", "openrouter/openai/gpt-4") + """ + if hasattr(request.app.state, 'model_info_service'): + model_info_service = request.app.state.model_info_service + if model_info_service.is_ready: + info = model_info_service.get_model_info(model_id) + if info: + return info.to_dict() + + # Return basic info if service not ready or model not found + return { + "id": model_id, + "object": "model", + "created": int(time.time()), + "owned_by": model_id.split("/")[0] if "/" in model_id else "unknown", + } + + +@app.get("/v1/model-info/stats") +async def model_info_stats( + request: Request, + _=Depends(verify_api_key), +): + """ + Returns statistics about the model info service (for monitoring/debugging). + """ + if hasattr(request.app.state, 'model_info_service'): + return request.app.state.model_info_service.get_stats() + return {"error": "Model info service not initialized"} @app.get("/v1/providers") @@ -850,6 +1014,101 @@ async def token_count( logging.error(f"Token count failed: {e}") raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/v1/cost-estimate") +async def cost_estimate( + request: Request, + _=Depends(verify_api_key) +): + """ + Estimates the cost for a request based on token counts and model pricing. + + Request body: + { + "model": "anthropic/claude-3-opus", + "prompt_tokens": 1000, + "completion_tokens": 500, + "cache_read_tokens": 0, # optional + "cache_creation_tokens": 0 # optional + } + + Returns: + { + "model": "anthropic/claude-3-opus", + "cost": 0.0375, + "currency": "USD", + "pricing": { + "input_cost_per_token": 0.000015, + "output_cost_per_token": 0.000075 + }, + "source": "model_info_service" # or "litellm_fallback" + } + """ + try: + data = await request.json() + model = data.get("model") + prompt_tokens = data.get("prompt_tokens", 0) + completion_tokens = data.get("completion_tokens", 0) + cache_read_tokens = data.get("cache_read_tokens", 0) + cache_creation_tokens = data.get("cache_creation_tokens", 0) + + if not model: + raise HTTPException(status_code=400, detail="'model' is required.") + + result = { + "model": model, + "cost": None, + "currency": "USD", + "pricing": {}, + "source": None + } + + # Try model info service first + if hasattr(request.app.state, 'model_info_service'): + model_info_service = request.app.state.model_info_service + if model_info_service.is_ready: + cost = model_info_service.calculate_cost( + model, prompt_tokens, completion_tokens, + cache_read_tokens, cache_creation_tokens + ) + if cost is not None: + cost_info = model_info_service.get_cost_info(model) + result["cost"] = cost + result["pricing"] = cost_info or {} + result["source"] = "model_info_service" + return result + + # Fallback to litellm + try: + import litellm + # Create a mock response for cost calculation + model_info = litellm.get_model_info(model) + input_cost = model_info.get("input_cost_per_token", 0) + output_cost = model_info.get("output_cost_per_token", 0) + + if input_cost or output_cost: + cost = (prompt_tokens * input_cost) + (completion_tokens * output_cost) + result["cost"] = cost + result["pricing"] = { + "input_cost_per_token": input_cost, + "output_cost_per_token": output_cost + } + result["source"] = "litellm_fallback" + return result + except Exception: + pass + + result["source"] = "unknown" + result["error"] = "Pricing data not available for this model" + return result + + except HTTPException: + raise + except Exception as e: + logging.error(f"Cost estimate failed: {e}") + 
raise HTTPException(status_code=500, detail=str(e)) + + if __name__ == "__main__": # Define ENV_FILE for onboarding checks ENV_FILE = Path.cwd() / ".env" diff --git a/src/proxy_app/settings_tool.py b/src/proxy_app/settings_tool.py index 67ee0cb..71641f3 100644 --- a/src/proxy_app/settings_tool.py +++ b/src/proxy_app/settings_tool.py @@ -166,6 +166,184 @@ def remove_limit(self, provider: str): self.settings.remove(key) +# ============================================================================= +# PROVIDER-SPECIFIC SETTINGS DEFINITIONS +# ============================================================================= + +# Antigravity provider environment variables +ANTIGRAVITY_SETTINGS = { + "ANTIGRAVITY_SIGNATURE_CACHE_TTL": { + "type": "int", + "default": 3600, + "description": "Memory cache TTL for Gemini 3 thought signatures (seconds)", + }, + "ANTIGRAVITY_SIGNATURE_DISK_TTL": { + "type": "int", + "default": 86400, + "description": "Disk cache TTL for Gemini 3 thought signatures (seconds)", + }, + "ANTIGRAVITY_PRESERVE_THOUGHT_SIGNATURES": { + "type": "bool", + "default": True, + "description": "Preserve thought signatures in client responses", + }, + "ANTIGRAVITY_ENABLE_SIGNATURE_CACHE": { + "type": "bool", + "default": True, + "description": "Enable signature caching for multi-turn conversations", + }, + "ANTIGRAVITY_ENABLE_DYNAMIC_MODELS": { + "type": "bool", + "default": False, + "description": "Enable dynamic model discovery from API", + }, + "ANTIGRAVITY_GEMINI3_TOOL_FIX": { + "type": "bool", + "default": True, + "description": "Enable Gemini 3 tool hallucination prevention", + }, + "ANTIGRAVITY_CLAUDE_TOOL_FIX": { + "type": "bool", + "default": True, + "description": "Enable Claude tool hallucination prevention", + }, + "ANTIGRAVITY_CLAUDE_THINKING_SANITIZATION": { + "type": "bool", + "default": True, + "description": "Sanitize thinking blocks for Claude multi-turn conversations", + }, + "ANTIGRAVITY_GEMINI3_TOOL_PREFIX": { + "type": "str", + "default": "gemini3_", + "description": "Prefix added to tool names for Gemini 3 disambiguation", + }, + "ANTIGRAVITY_GEMINI3_DESCRIPTION_PROMPT": { + "type": "str", + "default": "\n\nSTRICT PARAMETERS: {params}.", + "description": "Template for strict parameter hints in tool descriptions", + }, + "ANTIGRAVITY_CLAUDE_DESCRIPTION_PROMPT": { + "type": "str", + "default": "\n\nSTRICT PARAMETERS: {params}.", + "description": "Template for Claude strict parameter hints in tool descriptions", + }, +} + +# Gemini CLI provider environment variables +GEMINI_CLI_SETTINGS = { + "GEMINI_CLI_SIGNATURE_CACHE_TTL": { + "type": "int", + "default": 3600, + "description": "Memory cache TTL for thought signatures (seconds)", + }, + "GEMINI_CLI_SIGNATURE_DISK_TTL": { + "type": "int", + "default": 86400, + "description": "Disk cache TTL for thought signatures (seconds)", + }, + "GEMINI_CLI_PRESERVE_THOUGHT_SIGNATURES": { + "type": "bool", + "default": True, + "description": "Preserve thought signatures in client responses", + }, + "GEMINI_CLI_ENABLE_SIGNATURE_CACHE": { + "type": "bool", + "default": True, + "description": "Enable signature caching for multi-turn conversations", + }, + "GEMINI_CLI_GEMINI3_TOOL_FIX": { + "type": "bool", + "default": True, + "description": "Enable Gemini 3 tool hallucination prevention", + }, + "GEMINI_CLI_GEMINI3_TOOL_PREFIX": { + "type": "str", + "default": "gemini3_", + "description": "Prefix added to tool names for Gemini 3 disambiguation", + }, + "GEMINI_CLI_GEMINI3_DESCRIPTION_PROMPT": { + "type": "str", + "default": 
"\n\nSTRICT PARAMETERS: {params}.", + "description": "Template for strict parameter hints in tool descriptions", + }, + "GEMINI_CLI_PROJECT_ID": { + "type": "str", + "default": "", + "description": "GCP Project ID for paid tier users (required for paid tiers)", + }, +} + +# Map provider names to their settings definitions +PROVIDER_SETTINGS_MAP = { + "antigravity": ANTIGRAVITY_SETTINGS, + "gemini_cli": GEMINI_CLI_SETTINGS, +} + + +class ProviderSettingsManager: + """Manages provider-specific configuration settings""" + + def __init__(self, settings: AdvancedSettings): + self.settings = settings + + def get_available_providers(self) -> List[str]: + """Get list of providers with specific settings available""" + return list(PROVIDER_SETTINGS_MAP.keys()) + + def get_provider_settings_definitions(self, provider: str) -> Dict[str, Dict[str, Any]]: + """Get settings definitions for a provider""" + return PROVIDER_SETTINGS_MAP.get(provider, {}) + + def get_current_value(self, key: str, definition: Dict[str, Any]) -> Any: + """Get current value of a setting from environment""" + env_value = os.getenv(key) + if env_value is None: + return definition.get("default") + + setting_type = definition.get("type", "str") + try: + if setting_type == "bool": + return env_value.lower() in ("true", "1", "yes") + elif setting_type == "int": + return int(env_value) + else: + return env_value + except (ValueError, AttributeError): + return definition.get("default") + + def get_all_current_values(self, provider: str) -> Dict[str, Any]: + """Get all current values for a provider""" + definitions = self.get_provider_settings_definitions(provider) + values = {} + for key, definition in definitions.items(): + values[key] = self.get_current_value(key, definition) + return values + + def set_value(self, key: str, value: Any, definition: Dict[str, Any]): + """Set a setting value, converting to string for .env storage""" + setting_type = definition.get("type", "str") + if setting_type == "bool": + str_value = "true" if value else "false" + else: + str_value = str(value) + self.settings.set(key, str_value) + + def reset_to_default(self, key: str): + """Remove a setting to reset it to default""" + self.settings.remove(key) + + def get_modified_settings(self, provider: str) -> Dict[str, Any]: + """Get settings that differ from defaults""" + definitions = self.get_provider_settings_definitions(provider) + modified = {} + for key, definition in definitions.items(): + current = self.get_current_value(key, definition) + default = definition.get("default") + if current != default: + modified[key] = current + return modified + + class SettingsTool: """Main settings tool TUI""" @@ -175,6 +353,7 @@ def __init__(self): self.provider_mgr = CustomProviderManager(self.settings) self.model_mgr = ModelDefinitionManager(self.settings) self.concurrency_mgr = ConcurrencyManager(self.settings) + self.provider_settings_mgr = ProviderSettingsManager(self.settings) self.running = True def get_available_providers(self) -> List[str]: @@ -223,8 +402,9 @@ def show_main_menu(self): self.console.print(" 1. 🌐 Custom Provider API Bases") self.console.print(" 2. 📦 Provider Model Definitions") self.console.print(" 3. ⚡ Concurrency Limits") - self.console.print(" 4. 💾 Save & Exit") - self.console.print(" 5. 🚫 Exit Without Saving") + self.console.print(" 4. 🔬 Provider-Specific Settings") + self.console.print(" 5. 💾 Save & Exit") + self.console.print(" 6. 
🚫 Exit Without Saving") self.console.print() self.console.print("━" * 70) @@ -238,7 +418,7 @@ def show_main_menu(self): self.console.print("[dim]⚠️ Model filters not supported - edit .env for IGNORE_MODELS_* / WHITELIST_MODELS_*[/dim]") self.console.print() - choice = Prompt.ask("Select option", choices=["1", "2", "3", "4", "5"], show_choices=False) + choice = Prompt.ask("Select option", choices=["1", "2", "3", "4", "5", "6"], show_choices=False) if choice == "1": self.manage_custom_providers() @@ -247,8 +427,10 @@ def show_main_menu(self): elif choice == "3": self.manage_concurrency_limits() elif choice == "4": - self.save_and_exit() + self.manage_provider_settings() elif choice == "5": + self.save_and_exit() + elif choice == "6": self.exit_without_saving() def manage_custom_providers(self): @@ -631,6 +813,195 @@ def view_model_definitions(self, providers: List[str]): input("Press Enter to return...") + def manage_provider_settings(self): + """Manage provider-specific settings (Antigravity, Gemini CLI)""" + while True: + self.console.clear() + + available_providers = self.provider_settings_mgr.get_available_providers() + + self.console.print(Panel.fit( + "[bold cyan]🔬 Provider-Specific Settings[/bold cyan]", + border_style="cyan" + )) + + self.console.print() + self.console.print("[bold]📋 Available Providers with Custom Settings[/bold]") + self.console.print("━" * 70) + + for provider in available_providers: + modified = self.provider_settings_mgr.get_modified_settings(provider) + status = f"[yellow]{len(modified)} modified[/yellow]" if modified else "[dim]defaults[/dim]" + display_name = provider.replace("_", " ").title() + self.console.print(f" • {display_name:20} {status}") + + self.console.print() + self.console.print("━" * 70) + self.console.print() + self.console.print("[bold]⚙️ Select Provider to Configure[/bold]") + self.console.print() + + for idx, provider in enumerate(available_providers, 1): + display_name = provider.replace("_", " ").title() + self.console.print(f" {idx}. {display_name}") + self.console.print(f" {len(available_providers) + 1}. 
↩️ Back to Settings Menu") + + self.console.print() + self.console.print("━" * 70) + self.console.print() + + choices = [str(i) for i in range(1, len(available_providers) + 2)] + choice = Prompt.ask("Select option", choices=choices, show_choices=False) + choice_idx = int(choice) + + if choice_idx == len(available_providers) + 1: + break + + provider = available_providers[choice_idx - 1] + self._manage_single_provider_settings(provider) + + def _manage_single_provider_settings(self, provider: str): + """Manage settings for a single provider""" + while True: + self.console.clear() + + display_name = provider.replace("_", " ").title() + definitions = self.provider_settings_mgr.get_provider_settings_definitions(provider) + current_values = self.provider_settings_mgr.get_all_current_values(provider) + + self.console.print(Panel.fit( + f"[bold cyan]🔬 {display_name} Settings[/bold cyan]", + border_style="cyan" + )) + + self.console.print() + self.console.print("[bold]📋 Current Settings[/bold]") + self.console.print("━" * 70) + + # Display all settings with current values + settings_list = list(definitions.keys()) + for idx, key in enumerate(settings_list, 1): + definition = definitions[key] + current = current_values.get(key) + default = definition.get("default") + setting_type = definition.get("type", "str") + description = definition.get("description", "") + + # Format value display + if setting_type == "bool": + value_display = "[green]✓ Enabled[/green]" if current else "[red]✗ Disabled[/red]" + elif setting_type == "int": + value_display = f"[cyan]{current}[/cyan]" + else: + value_display = f"[cyan]{current or '(not set)'}[/cyan]" if current else "[dim](not set)[/dim]" + + # Check if modified from default + modified = current != default + mod_marker = "[yellow]*[/yellow]" if modified else " " + + # Short key name for display (strip provider prefix) + short_key = key.replace(f"{provider.upper()}_", "") + + self.console.print(f" {mod_marker}{idx:2}. {short_key:35} {value_display}") + self.console.print(f" [dim]{description}[/dim]") + + self.console.print() + self.console.print("━" * 70) + self.console.print("[dim]* = modified from default[/dim]") + self.console.print() + self.console.print("[bold]⚙️ Actions[/bold]") + self.console.print() + self.console.print(" E. ✏️ Edit a Setting") + self.console.print(" R. 🔄 Reset Setting to Default") + self.console.print(" A. 🔄 Reset All to Defaults") + self.console.print(" B. 
↩️ Back to Provider Selection") + + self.console.print() + self.console.print("━" * 70) + self.console.print() + + choice = Prompt.ask("Select action", choices=["e", "r", "a", "b", "E", "R", "A", "B"], show_choices=False).lower() + + if choice == "b": + break + elif choice == "e": + self._edit_provider_setting(provider, settings_list, definitions) + elif choice == "r": + self._reset_provider_setting(provider, settings_list, definitions) + elif choice == "a": + self._reset_all_provider_settings(provider, settings_list) + + def _edit_provider_setting(self, provider: str, settings_list: List[str], definitions: Dict[str, Dict[str, Any]]): + """Edit a single provider setting""" + self.console.print("\n[bold]Select setting number to edit:[/bold]") + + choices = [str(i) for i in range(1, len(settings_list) + 1)] + choice = IntPrompt.ask("Setting number", choices=choices) + key = settings_list[choice - 1] + definition = definitions[key] + + current = self.provider_settings_mgr.get_current_value(key, definition) + default = definition.get("default") + setting_type = definition.get("type", "str") + short_key = key.replace(f"{provider.upper()}_", "") + + self.console.print(f"\n[bold]Editing: {short_key}[/bold]") + self.console.print(f"Current value: [cyan]{current}[/cyan]") + self.console.print(f"Default value: [dim]{default}[/dim]") + self.console.print(f"Type: {setting_type}") + + if setting_type == "bool": + new_value = Confirm.ask("\nEnable this setting?", default=current) + self.provider_settings_mgr.set_value(key, new_value, definition) + status = "enabled" if new_value else "disabled" + self.console.print(f"\n[green]✅ {short_key} {status}![/green]") + elif setting_type == "int": + new_value = IntPrompt.ask("\nNew value", default=current) + self.provider_settings_mgr.set_value(key, new_value, definition) + self.console.print(f"\n[green]✅ {short_key} set to {new_value}![/green]") + else: + new_value = Prompt.ask("\nNew value", default=str(current) if current else "").strip() + if new_value: + self.provider_settings_mgr.set_value(key, new_value, definition) + self.console.print(f"\n[green]✅ {short_key} updated![/green]") + else: + self.console.print("\n[yellow]No changes made[/yellow]") + + input("\nPress Enter to continue...") + + def _reset_provider_setting(self, provider: str, settings_list: List[str], definitions: Dict[str, Dict[str, Any]]): + """Reset a single provider setting to default""" + self.console.print("\n[bold]Select setting number to reset:[/bold]") + + choices = [str(i) for i in range(1, len(settings_list) + 1)] + choice = IntPrompt.ask("Setting number", choices=choices) + key = settings_list[choice - 1] + definition = definitions[key] + + default = definition.get("default") + short_key = key.replace(f"{provider.upper()}_", "") + + if Confirm.ask(f"\nReset {short_key} to default ({default})?"): + self.provider_settings_mgr.reset_to_default(key) + self.console.print(f"\n[green]✅ {short_key} reset to default![/green]") + else: + self.console.print("\n[yellow]No changes made[/yellow]") + + input("\nPress Enter to continue...") + + def _reset_all_provider_settings(self, provider: str, settings_list: List[str]): + """Reset all provider settings to defaults""" + display_name = provider.replace("_", " ").title() + + if Confirm.ask(f"\n[bold red]Reset ALL {display_name} settings to defaults?[/bold red]"): + for key in settings_list: + self.provider_settings_mgr.reset_to_default(key) + self.console.print(f"\n[green]✅ All {display_name} settings reset to defaults![/green]") + else: + 
self.console.print("\n[yellow]No changes made[/yellow]") + + input("\nPress Enter to continue...") + def manage_concurrency_limits(self): """Manage concurrency limits""" while True: diff --git a/src/rotator_library/README.md b/src/rotator_library/README.md index c020799..2050f1b 100644 --- a/src/rotator_library/README.md +++ b/src/rotator_library/README.md @@ -7,9 +7,11 @@ A robust, asynchronous, and thread-safe Python library for managing a pool of AP - **Asynchronous by Design**: Built with `asyncio` and `httpx` for high-performance, non-blocking I/O. - **Advanced Concurrency Control**: A single API key can be used for multiple concurrent requests. By default, it supports concurrent requests to *different* models. With configuration (`MAX_CONCURRENT_REQUESTS_PER_KEY_`), it can also support multiple concurrent requests to the *same* model using the same key. - **Smart Key Management**: Selects the optimal key for each request using a tiered, model-aware locking strategy to distribute load evenly and maximize availability. +- **Configurable Rotation Strategy**: Choose between deterministic least-used selection (perfect balance) or default weighted random selection (unpredictable, harder to fingerprint). - **Deadline-Driven Requests**: A global timeout ensures that no request, including all retries and key selections, exceeds a specified time limit. - **OAuth & API Key Support**: Built-in support for standard API keys and complex OAuth flows. - - **Gemini CLI**: Full OAuth 2.0 web flow with automatic project discovery and free-tier onboarding. + - **Gemini CLI**: Full OAuth 2.0 web flow with automatic project discovery, free-tier onboarding, and credential prioritization (paid vs free tier). + - **Antigravity**: Full OAuth 2.0 support for Gemini 3, Gemini 2.5, and Claude Sonnet 4.5 models with thought signature caching(Full support for Gemini 3 and Claude models). **First on the scene to provide full support for Gemini 3** via Antigravity with advanced features like thought signature caching and tool hallucination prevention. - **Qwen Code**: Device Code flow support. - **iFlow**: Authorization Code flow with local callback handling. - **Stateless Deployment Ready**: Can load complex OAuth credentials from environment variables, eliminating the need for physical credential files in containerized environments. @@ -17,11 +19,15 @@ A robust, asynchronous, and thread-safe Python library for managing a pool of AP - **Escalating Per-Model Cooldowns**: Failed keys are placed on a temporary, escalating cooldown for specific models. - **Key-Level Lockouts**: Keys failing across multiple models are temporarily removed from rotation. - **Stream Recovery**: The client detects mid-stream errors (like quota limits) and gracefully handles them. +- **Credential Prioritization**: Automatic tier detection and priority-based credential selection (e.g., paid tier credentials used first for models that require them). +- **Advanced Model Requirements**: Support for model-tier restrictions (e.g., Gemini 3 requires paid-tier credentials). - **Robust Streaming Support**: Includes a wrapper for streaming responses that reassembles fragmented JSON chunks. - **Detailed Usage Tracking**: Tracks daily and global usage for each key, persisted to a JSON file. - **Automatic Daily Resets**: Automatically resets cooldowns and archives stats daily. - **Provider Agnostic**: Works with any provider supported by `litellm`. - **Extensible**: Easily add support for new providers through a simple plugin-based architecture. 
+- **Temperature Override**: Global temperature=0 override to prevent tool hallucination with low-temperature settings. +- **Shared OAuth Base**: Refactored OAuth implementation with reusable [`GoogleOAuthBase`](providers/google_oauth_base.py) for multiple providers. ## Installation @@ -71,7 +77,8 @@ client = RotatingClient( ignore_models={}, whitelist_models={}, enable_request_logging=False, - max_concurrent_requests_per_key={} + max_concurrent_requests_per_key={}, + rotation_tolerance=2.0 # 0.0=deterministic, 2.0=recommended random ) ``` @@ -89,6 +96,17 @@ client = RotatingClient( - `whitelist_models` (`Optional[Dict[str, List[str]]]`, default: `None`): A dictionary where keys are provider names and values are lists of model names/patterns to always include, overriding `ignore_models`. - `enable_request_logging` (`bool`, default: `False`): If `True`, enables detailed per-request file logging (useful for debugging complex interactions). - `max_concurrent_requests_per_key` (`Optional[Dict[str, int]]`, default: `None`): A dictionary defining the maximum number of concurrent requests allowed for a single API key for a specific provider. Defaults to 1 if not specified. +- `rotation_tolerance` (`float`, default: `0.0`): Controls credential rotation strategy: + - `0.0`: **Deterministic** - Always selects the least-used credential for perfect load balance. + - `2.0` (default, recommended): **Weighted Random** - Randomly selects credentials with bias toward less-used ones. Provides unpredictability (harder to fingerprint) while maintaining good balance. + - `5.0+`: **High Randomness** - Even heavily-used credentials have significant selection probability. Maximum unpredictability. + + The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` + + **Use Cases:** + - `0.0`: When perfect load balance is critical + - `2.0`: When avoiding fingerprinting/rate limit detection is important + - `5.0+`: For stress testing or maximum unpredictability ### Concurrency and Resource Management @@ -185,9 +203,27 @@ Use this tool to: ### Google Gemini (CLI) - **Auth**: Simulates the Google Cloud CLI authentication flow. -- **Project Discovery**: Automatically discovers the default Google Cloud Project ID. +- **Project Discovery**: Automatically discovers the default Google Cloud Project ID with enhanced onboarding flow. +- **Credential Prioritization**: Automatic detection and prioritization of paid vs free tier credentials. +- **Model Tier Requirements**: Gemini 3 models automatically filtered to paid-tier credentials only. +- **Gemini 3 Support**: Full support for Gemini 3 models with: + - `thinkingLevel` configuration (low/high) + - Tool hallucination prevention via system instruction injection + - ThoughtSignature caching for multi-turn conversations + - Parameter signature injection into tool descriptions - **Rate Limits**: Implements smart fallback strategies (e.g., switching from `gemini-1.5-pro` to `gemini-1.5-pro-002`) when rate limits are hit. +### Antigravity +- **Auth**: Uses OAuth 2.0 flow similar to Gemini CLI, with Antigravity-specific credentials and scopes. +- **Models**: Supports Gemini 2.5 (Pro/Flash), Gemini 3 (Pro/Image), and Claude Sonnet 4.5 via Google's internal Antigravity API. +- **Thought Signature Caching**: Server-side caching of `thoughtSignature` data for multi-turn conversations with Gemini 3 models. +- **Tool Hallucination Prevention**: Automatic injection of system instructions and parameter signatures for Gemini 3 to prevent tool parameter hallucination. 
+- **Thinking Support**: + - Gemini 2.5: Uses `thinkingBudget` (integer tokens) + - Gemini 3: Uses `thinkingLevel` (string: "low"/"high") + - Claude: Uses `thinkingBudget` via Antigravity proxy +- **Base URL Fallback**: Automatic fallback between sandbox and production endpoints. + ## Error Handling and Cooldowns The client uses a sophisticated error handling mechanism: diff --git a/src/rotator_library/__init__.py b/src/rotator_library/__init__.py index 9a67812..f3ff0ec 100644 --- a/src/rotator_library/__init__.py +++ b/src/rotator_library/__init__.py @@ -7,12 +7,19 @@ if TYPE_CHECKING: from .providers import PROVIDER_PLUGINS from .providers.provider_interface import ProviderInterface + from .model_info_service import ModelInfoService, ModelInfo -__all__ = ["RotatingClient", "PROVIDER_PLUGINS"] +__all__ = ["RotatingClient", "PROVIDER_PLUGINS", "ModelInfoService", "ModelInfo"] def __getattr__(name): - """Lazy-load PROVIDER_PLUGINS to speed up module import.""" + """Lazy-load PROVIDER_PLUGINS and ModelInfoService to speed up module import.""" if name == "PROVIDER_PLUGINS": from .providers import PROVIDER_PLUGINS return PROVIDER_PLUGINS + if name == "ModelInfoService": + from .model_info_service import ModelInfoService + return ModelInfoService + if name == "ModelInfo": + from .model_info_service import ModelInfo + return ModelInfo raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/rotator_library/client.py b/src/rotator_library/client.py index 83a285f..e536aeb 100644 --- a/src/rotator_library/client.py +++ b/src/rotator_library/client.py @@ -63,7 +63,29 @@ def __init__( whitelist_models: Optional[Dict[str, List[str]]] = None, enable_request_logging: bool = False, max_concurrent_requests_per_key: Optional[Dict[str, int]] = None, + rotation_tolerance: float = 3.0, ): + """ + Initialize the RotatingClient with intelligent credential rotation. + + Args: + api_keys: Dictionary mapping provider names to lists of API keys + oauth_credentials: Dictionary mapping provider names to OAuth credential paths + max_retries: Maximum number of retry attempts per credential + usage_file_path: Path to store usage statistics + configure_logging: Whether to configure library logging + global_timeout: Global timeout for requests in seconds + abort_on_callback_error: Whether to abort on pre-request callback errors + litellm_provider_params: Provider-specific parameters for LiteLLM + ignore_models: Models to ignore/blacklist per provider + whitelist_models: Models to explicitly whitelist per provider + enable_request_logging: Whether to enable detailed request logging + max_concurrent_requests_per_key: Max concurrent requests per key by provider + rotation_tolerance: Tolerance for weighted random credential rotation. 
+ - 0.0: Deterministic, least-used credential always selected + - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max + - 5.0+: High randomness, more unpredictable selection patterns + """ os.environ["LITELLM_LOG"] = "ERROR" litellm.set_verbose = False litellm.drop_params = True @@ -93,8 +115,13 @@ def __init__( ) self.api_keys = api_keys - self.credential_manager = CredentialManager(oauth_credentials) - self.oauth_credentials = self.credential_manager.discover_and_prepare() + # Use provided oauth_credentials directly if available (already discovered by main.py) + # Only call discover_and_prepare() if no credentials were passed + if oauth_credentials: + self.oauth_credentials = oauth_credentials + else: + self.credential_manager = CredentialManager(os.environ) + self.oauth_credentials = self.credential_manager.discover_and_prepare() self.background_refresher = BackgroundRefresher(self) self.oauth_providers = set(self.oauth_credentials.keys()) @@ -108,7 +135,10 @@ def __init__( self.max_retries = max_retries self.global_timeout = global_timeout self.abort_on_callback_error = abort_on_callback_error - self.usage_manager = UsageManager(file_path=usage_file_path) + self.usage_manager = UsageManager( + file_path=usage_file_path, + rotation_tolerance=rotation_tolerance + ) self._model_list_cache = {} self._provider_plugins = PROVIDER_PLUGINS self._provider_instances = {} @@ -393,7 +423,23 @@ def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool: return os.getenv(api_base_env) is not None def _get_provider_instance(self, provider_name: str): - """Lazily initializes and returns a provider instance.""" + """ + Lazily initializes and returns a provider instance. + Only initializes providers that have configured credentials. + + Args: + provider_name: The name of the provider to get an instance for. + + Returns: + Provider instance if credentials exist, None otherwise. + """ + # Only initialize providers for which we have credentials + if provider_name not in self.all_credentials: + lib_logger.debug( + f"Skipping provider '{provider_name}' initialization: no credentials configured" + ) + return None + if provider_name not in self._provider_instances: if provider_name in self._provider_plugins: self._provider_instances[provider_name] = self._provider_plugins[ @@ -454,11 +500,19 @@ async def _safe_streaming_wrapper( """ A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully, and distinguishes between content and streamed errors. + + FINISH_REASON HANDLING: + Providers just translate chunks - this wrapper handles ALL finish_reason logic: + 1. Strip finish_reason from intermediate chunks (litellm defaults to "stop") + 2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop + 3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) """ last_usage = None stream_completed = False stream_iterator = stream.__aiter__() json_buffer = "" + accumulated_finish_reason = None # Track strongest finish_reason across chunks + has_tool_calls = False # Track if ANY tool calls were seen in stream try: while True: @@ -466,26 +520,64 @@ async def _safe_streaming_wrapper( lib_logger.info( f"Client disconnected. Aborting stream for credential ...{key[-6:]}." ) - # Do not yield [DONE] because the client is gone. - # The 'finally' block will handle key release. 
break try: chunk = await stream_iterator.__anext__() if json_buffer: - # If we are about to discard a buffer, it means data was likely lost. - # Log this as a warning to make it visible. lib_logger.warning( f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}" ) json_buffer = "" - yield f"data: {json.dumps(chunk.dict())}\n\n" + # Convert chunk to dict, handling both litellm.ModelResponse and raw dicts + if hasattr(chunk, "dict"): + chunk_dict = chunk.dict() + elif hasattr(chunk, "model_dump"): + chunk_dict = chunk.model_dump() + else: + chunk_dict = chunk + + # === FINISH_REASON LOGIC === + # Providers send raw chunks without finish_reason logic. + # This wrapper determines finish_reason based on accumulated state. + if "choices" in chunk_dict and chunk_dict["choices"]: + choice = chunk_dict["choices"][0] + delta = choice.get("delta", {}) + usage = chunk_dict.get("usage", {}) + + # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls + if delta.get("tool_calls"): + has_tool_calls = True + accumulated_finish_reason = "tool_calls" + + # Detect final chunk: has usage with completion_tokens > 0 + has_completion_tokens = ( + usage and + isinstance(usage, dict) and + usage.get("completion_tokens", 0) > 0 + ) + + if has_completion_tokens: + # FINAL CHUNK: Determine correct finish_reason + if has_tool_calls: + # Tool calls always win + choice["finish_reason"] = "tool_calls" + elif accumulated_finish_reason: + # Use accumulated reason (length, content_filter, etc.) + choice["finish_reason"] = accumulated_finish_reason + else: + # Default to stop + choice["finish_reason"] = "stop" + else: + # INTERMEDIATE CHUNK: Never emit finish_reason + # (litellm.ModelResponse defaults to "stop" which is wrong) + choice["finish_reason"] = None + + yield f"data: {json.dumps(chunk_dict)}\n\n" if hasattr(chunk, "usage") and chunk.usage: - last_usage = ( - chunk.usage - ) # Overwrite with the latest (cumulative) + last_usage = chunk.usage except StopAsyncIteration: stream_completed = True @@ -656,6 +748,73 @@ async def _execute_with_retry( lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") model = resolved_model kwargs["model"] = model # Ensure kwargs has the resolved model for litellm + + # [NEW] Filter by model tier requirement and build priority map + credential_priorities = None + if provider_plugin and hasattr(provider_plugin, 'get_model_tier_requirement'): + required_tier = provider_plugin.get_model_tier_requirement(model) + if required_tier is not None: + # Filter OUT only credentials we KNOW are too low priority + # Keep credentials with unknown priority (None) - they might be high priority + incompatible_creds = [] + compatible_creds = [] + unknown_creds = [] + + for cred in credentials_for_provider: + if hasattr(provider_plugin, 'get_credential_priority'): + priority = provider_plugin.get_credential_priority(cred) + if priority is None: + # Unknown priority - keep it, will be discovered on first use + unknown_creds.append(cred) + elif priority <= required_tier: + # Known compatible priority + compatible_creds.append(cred) + else: + # Known incompatible priority (too low) + incompatible_creds.append(cred) + else: + # Provider doesn't support priorities - keep all + unknown_creds.append(cred) + + # If we have any known-compatible or unknown credentials, use them + tier_compatible_creds = compatible_creds + unknown_creds + if tier_compatible_creds: + credentials_for_provider = tier_compatible_creds + if compatible_creds and unknown_creds: + 
lib_logger.info( + f"Model {model} requires priority <= {required_tier}. " + f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." + ) + elif compatible_creds: + lib_logger.info( + f"Model {model} requires priority <= {required_tier}. " + f"Using {len(compatible_creds)} known-compatible credentials." + ) + else: + lib_logger.info( + f"Model {model} requires priority <= {required_tier}. " + f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." + ) + elif incompatible_creds: + # Only known-incompatible credentials remain + lib_logger.warning( + f"Model {model} requires priority <= {required_tier} credentials, " + f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " + f"Request will likely fail." + ) + + # Build priority map for usage_manager + if provider_plugin and hasattr(provider_plugin, 'get_credential_priority'): + credential_priorities = {} + for cred in credentials_for_provider: + priority = provider_plugin.get_credential_priority(cred) + if priority is not None: + credential_priorities[cred] = priority + + if credential_priorities: + lib_logger.debug( + f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c)==p])}' for p in sorted(set(credential_priorities.values())))}" + ) while ( len(tried_creds) < len(credentials_for_provider) and time.time() < deadline @@ -694,7 +853,8 @@ async def _execute_with_retry( max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) current_cred = await self.usage_manager.acquire_key( available_keys=creds_to_try, model=model, deadline=deadline, - max_concurrent=max_concurrent + max_concurrent=max_concurrent, + credential_priorities=credential_priorities ) key_acquired = True tried_creds.add(current_cred) @@ -1031,6 +1191,73 @@ async def _streaming_acompletion_with_retry( lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") model = resolved_model kwargs["model"] = model # Ensure kwargs has the resolved model for litellm + + # [NEW] Filter by model tier requirement and build priority map + credential_priorities = None + if provider_plugin and hasattr(provider_plugin, 'get_model_tier_requirement'): + required_tier = provider_plugin.get_model_tier_requirement(model) + if required_tier is not None: + # Filter OUT only credentials we KNOW are too low priority + # Keep credentials with unknown priority (None) - they might be high priority + incompatible_creds = [] + compatible_creds = [] + unknown_creds = [] + + for cred in credentials_for_provider: + if hasattr(provider_plugin, 'get_credential_priority'): + priority = provider_plugin.get_credential_priority(cred) + if priority is None: + # Unknown priority - keep it, will be discovered on first use + unknown_creds.append(cred) + elif priority <= required_tier: + # Known compatible priority + compatible_creds.append(cred) + else: + # Known incompatible priority (too low) + incompatible_creds.append(cred) + else: + # Provider doesn't support priorities - keep all + unknown_creds.append(cred) + + # If we have any known-compatible or unknown credentials, use them + tier_compatible_creds = compatible_creds + unknown_creds + if tier_compatible_creds: + credentials_for_provider = tier_compatible_creds + if compatible_creds and unknown_creds: + lib_logger.info( + f"Model {model} requires priority <= {required_tier}. 
" + f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." + ) + elif compatible_creds: + lib_logger.info( + f"Model {model} requires priority <= {required_tier}. " + f"Using {len(compatible_creds)} known-compatible credentials." + ) + else: + lib_logger.info( + f"Model {model} requires priority <= {required_tier}. " + f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." + ) + elif incompatible_creds: + # Only known-incompatible credentials remain + lib_logger.warning( + f"Model {model} requires priority <= {required_tier} credentials, " + f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " + f"Request will likely fail." + ) + + # Build priority map for usage_manager + if provider_plugin and hasattr(provider_plugin, 'get_credential_priority'): + credential_priorities = {} + for cred in credentials_for_provider: + priority = provider_plugin.get_credential_priority(cred) + if priority is not None: + credential_priorities[cred] = priority + + if credential_priorities: + lib_logger.debug( + f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c)==p])}' for p in sorted(set(credential_priorities.values())))}" + ) try: while ( @@ -1070,7 +1297,8 @@ async def _streaming_acompletion_with_retry( max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1) current_cred = await self.usage_manager.acquire_key( available_keys=creds_to_try, model=model, deadline=deadline, - max_concurrent=max_concurrent + max_concurrent=max_concurrent, + credential_priorities=credential_priorities ) key_acquired = True tried_creds.add(current_cred) diff --git a/src/rotator_library/credential_manager.py b/src/rotator_library/credential_manager.py index c5426d7..16be41c 100644 --- a/src/rotator_library/credential_manager.py +++ b/src/rotator_library/credential_manager.py @@ -1,8 +1,9 @@ import os +import re import shutil import logging from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set lib_logger = logging.getLogger('rotator_library') @@ -14,22 +15,100 @@ "gemini_cli": Path.home() / ".gemini", "qwen_code": Path.home() / ".qwen", "iflow": Path.home() / ".iflow", + "antigravity": Path.home() / ".antigravity", # Add other providers like 'claude' here if they have a standard CLI path } +# OAuth providers that support environment variable-based credentials +# Maps provider name to the ENV_PREFIX used by the provider +ENV_OAUTH_PROVIDERS = { + "gemini_cli": "GEMINI_CLI", + "antigravity": "ANTIGRAVITY", + "qwen_code": "QWEN_CODE", + "iflow": "IFLOW", +} + + class CredentialManager: """ Discovers OAuth credential files from standard locations, copies them locally, and updates the configuration to use the local paths. + + Also discovers environment variable-based OAuth credentials for stateless deployments. + Supports two env var formats: + + 1. Single credential (legacy): PROVIDER_ACCESS_TOKEN, PROVIDER_REFRESH_TOKEN + 2. Multiple credentials (numbered): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN, etc. + + When env-based credentials are detected, virtual paths like "env://provider/1" are created. """ def __init__(self, env_vars: Dict[str, str]): self.env_vars = env_vars + def _discover_env_oauth_credentials(self) -> Dict[str, List[str]]: + """ + Discover OAuth credentials defined via environment variables. + + Supports two formats: + 1. 
Single credential: ANTIGRAVITY_ACCESS_TOKEN + ANTIGRAVITY_REFRESH_TOKEN + 2. Multiple credentials: ANTIGRAVITY_1_ACCESS_TOKEN + ANTIGRAVITY_1_REFRESH_TOKEN, etc. + + Returns: + Dict mapping provider name to list of virtual paths (e.g., "env://antigravity/1") + """ + env_credentials: Dict[str, Set[str]] = {} + + for provider, env_prefix in ENV_OAUTH_PROVIDERS.items(): + found_indices: Set[str] = set() + + # Check for numbered credentials (PROVIDER_N_ACCESS_TOKEN pattern) + # Pattern: ANTIGRAVITY_1_ACCESS_TOKEN, ANTIGRAVITY_2_ACCESS_TOKEN, etc. + numbered_pattern = re.compile(rf"^{env_prefix}_(\d+)_ACCESS_TOKEN$") + + for key in self.env_vars.keys(): + match = numbered_pattern.match(key) + if match: + index = match.group(1) + # Verify refresh token also exists + refresh_key = f"{env_prefix}_{index}_REFRESH_TOKEN" + if refresh_key in self.env_vars and self.env_vars[refresh_key]: + found_indices.add(index) + + # Check for legacy single credential (PROVIDER_ACCESS_TOKEN pattern) + # Only use this if no numbered credentials exist + if not found_indices: + access_key = f"{env_prefix}_ACCESS_TOKEN" + refresh_key = f"{env_prefix}_REFRESH_TOKEN" + if (access_key in self.env_vars and self.env_vars[access_key] and + refresh_key in self.env_vars and self.env_vars[refresh_key]): + # Use "0" as the index for legacy single credential + found_indices.add("0") + + if found_indices: + env_credentials[provider] = found_indices + lib_logger.info(f"Found {len(found_indices)} env-based credential(s) for {provider}") + + # Convert to virtual paths + result: Dict[str, List[str]] = {} + for provider, indices in env_credentials.items(): + # Sort indices numerically for consistent ordering + sorted_indices = sorted(indices, key=lambda x: int(x)) + result[provider] = [f"env://{provider}/{idx}" for idx in sorted_indices] + + return result + def discover_and_prepare(self) -> Dict[str, List[str]]: lib_logger.info("Starting automated OAuth credential discovery...") final_config = {} - # Extract OAuth paths from environment variables first + # PHASE 1: Discover environment variable-based OAuth credentials + # These take priority for stateless deployments + env_oauth_creds = self._discover_env_oauth_credentials() + for provider, virtual_paths in env_oauth_creds.items(): + lib_logger.info(f"Using {len(virtual_paths)} env-based credential(s) for {provider}") + final_config[provider] = virtual_paths + + # Extract OAuth file paths from environment variables env_oauth_paths = {} for key, value in self.env_vars.items(): if "_OAUTH_" in key: @@ -39,7 +118,13 @@ def discover_and_prepare(self) -> Dict[str, List[str]]: if value: # Only consider non-empty values env_oauth_paths[provider].append(value) + # PHASE 2: Discover file-based OAuth credentials for provider, default_dir in DEFAULT_OAUTH_DIRS.items(): + # Skip if already discovered from environment variables + if provider in final_config: + lib_logger.debug(f"Skipping file discovery for {provider} - using env-based credentials") + continue + # Check for existing local credentials first. If found, use them and skip discovery. 
local_provider_creds = sorted(list(OAUTH_BASE_DIR.glob(f"{provider}_oauth_*.json"))) if local_provider_creds: diff --git a/src/rotator_library/credential_tool.py b/src/rotator_library/credential_tool.py index 82c8b05..1949f13 100644 --- a/src/rotator_library/credential_tool.py +++ b/src/rotator_library/credential_tool.py @@ -36,6 +36,77 @@ def _ensure_providers_loaded(): _provider_plugins = pp return _provider_factory, _provider_plugins + +def _get_credential_number_from_filename(filename: str) -> int: + """ + Extract credential number from filename like 'provider_oauth_1.json' -> 1 + """ + match = re.search(r'_oauth_(\d+)\.json$', filename) + if match: + return int(match.group(1)) + return 1 + + +def _build_env_export_content( + provider_prefix: str, + cred_number: int, + creds: dict, + email: str, + extra_fields: dict = None, + include_client_creds: bool = True +) -> tuple[list[str], str]: + """ + Build .env content for OAuth credential export with numbered format. + Exports all fields from the JSON file as a 1-to-1 mirror. + + Args: + provider_prefix: Environment variable prefix (e.g., "ANTIGRAVITY", "GEMINI_CLI") + cred_number: Credential number for this export (1, 2, 3, etc.) + creds: The credential dictionary loaded from JSON + email: User email for comments + extra_fields: Optional dict of additional fields to include + include_client_creds: Whether to include client_id/secret (Google OAuth providers) + + Returns: + Tuple of (env_lines list, numbered_prefix string for display) + """ + # Use numbered format: PROVIDER_N_ACCESS_TOKEN + numbered_prefix = f"{provider_prefix}_{cred_number}" + + env_lines = [ + f"# {provider_prefix} Credential #{cred_number} for: {email}", + f"# Exported from: {provider_prefix.lower()}_oauth_{cred_number}.json", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + f"# ", + f"# To combine multiple credentials into one .env file, copy these lines", + f"# and ensure each credential has a unique number (1, 2, 3, etc.)", + "", + f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{numbered_prefix}_SCOPE={creds.get('scope', '')}", + f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}", + f"{numbered_prefix}_ID_TOKEN={creds.get('id_token', '')}", + f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}", + ] + + if include_client_creds: + env_lines.extend([ + f"{numbered_prefix}_CLIENT_ID={creds.get('client_id', '')}", + f"{numbered_prefix}_CLIENT_SECRET={creds.get('client_secret', '')}", + f"{numbered_prefix}_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}", + f"{numbered_prefix}_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}", + ]) + + env_lines.append(f"{numbered_prefix}_EMAIL={email}") + + # Add extra provider-specific fields + if extra_fields: + for key, value in extra_fields.items(): + if value: # Only add non-empty values + env_lines.append(f"{numbered_prefix}_{key}={value}") + + return env_lines, numbered_prefix + def ensure_env_defaults(): """ Ensures the .env file exists and contains essential default values like PROXY_API_KEY. 
@@ -98,7 +169,7 @@ async def setup_api_key(): # Discover custom providers and add them to the list # Note: gemini_cli is OAuth-only, but qwen_code and iflow support both OAuth and API keys _, PROVIDER_PLUGINS = _ensure_providers_loaded() - oauth_only_providers = {'gemini_cli'} + oauth_only_providers = {'gemini_cli', 'antigravity'} discovered_providers = { p.replace('_', ' ').title(): p.upper() + "_API_KEY" for p in PROVIDER_PLUGINS.keys() @@ -195,7 +266,8 @@ async def setup_new_credential(provider_name: str): oauth_friendly_names = { "gemini_cli": "Gemini CLI (OAuth)", "qwen_code": "Qwen Code (OAuth - also supports API keys)", - "iflow": "iFlow (OAuth - also supports API keys)" + "iflow": "iFlow (OAuth - also supports API keys)", + "antigravity": "Antigravity (OAuth)" } display_name = oauth_friendly_names.get(provider_name, provider_name.replace('_', ' ').title()) @@ -255,12 +327,12 @@ async def setup_new_credential(provider_name: str): async def export_gemini_cli_to_env(): """ Export a Gemini CLI credential JSON file to .env format. - Generates one .env file per credential. + Uses numbered format (GEMINI_CLI_1_*, GEMINI_CLI_2_*) for multiple credential support. """ console.print(Panel("[bold cyan]Export Gemini CLI Credential to .env[/bold cyan]", expand=False)) # Find all gemini_cli credentials - gemini_cli_files = list(OAUTH_BASE_DIR.glob("gemini_cli_oauth_*.json")) + gemini_cli_files = sorted(list(OAUTH_BASE_DIR.glob("gemini_cli_oauth_*.json"))) if not gemini_cli_files: console.print(Panel("No Gemini CLI credentials found. Please add one first using 'Add OAuth Credential'.", @@ -303,34 +375,30 @@ async def export_gemini_cli_to_env(): project_id = creds.get("_proxy_metadata", {}).get("project_id", "") tier = creds.get("_proxy_metadata", {}).get("tier", "") - # Generate .env file name + # Get credential number from filename + cred_number = _get_credential_number_from_filename(cred_file.name) + + # Generate .env file name with credential number safe_email = email.replace("@", "_at_").replace(".", "_") - env_filename = f"gemini_cli_{safe_email}.env" + env_filename = f"gemini_cli_{cred_number}_{safe_email}.env" env_filepath = OAUTH_BASE_DIR / env_filename - # Build .env content - env_lines = [ - f"# Gemini CLI Credential for: {email}", - f"# Generated from: {cred_file.name}", - f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", - "", - f"GEMINI_CLI_ACCESS_TOKEN={creds.get('access_token', '')}", - f"GEMINI_CLI_REFRESH_TOKEN={creds.get('refresh_token', '')}", - f"GEMINI_CLI_EXPIRY_DATE={creds.get('expiry_date', 0)}", - f"GEMINI_CLI_CLIENT_ID={creds.get('client_id', '')}", - f"GEMINI_CLI_CLIENT_SECRET={creds.get('client_secret', '')}", - f"GEMINI_CLI_TOKEN_URI={creds.get('token_uri', 'https://oauth2.googleapis.com/token')}", - f"GEMINI_CLI_UNIVERSE_DOMAIN={creds.get('universe_domain', 'googleapis.com')}", - f"GEMINI_CLI_EMAIL={email}", - ] - - # Add project_id if present + # Build extra fields + extra_fields = {} if project_id: - env_lines.append(f"GEMINI_CLI_PROJECT_ID={project_id}") - - # Add tier if present + extra_fields["PROJECT_ID"] = project_id if tier: - env_lines.append(f"GEMINI_CLI_TIER={tier}") + extra_fields["TIER"] = tier + + # Build .env content using helper + env_lines, numbered_prefix = _build_env_export_content( + provider_prefix="GEMINI_CLI", + cred_number=cred_number, + creds=creds, + email=email, + extra_fields=extra_fields, + include_client_creds=True + ) # Write to .env file with open(env_filepath, 'w') as f: @@ -338,11 +406,14 @@ async def 
export_gemini_cli_to_env(): success_text = Text.from_markup( f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n" - f"To use this credential:\n" - f"1. Copy [bold yellow]{env_filepath.name}[/bold yellow] to your deployment environment\n" - f"2. Load the variables: [bold cyan]export $(cat {env_filepath.name} | grep -v '^#' | xargs)[/bold cyan]\n" - f"3. Or source it: [bold cyan]source {env_filepath.name}[/bold cyan]\n" - f"4. The Gemini CLI provider will automatically use these environment variables" + f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n" + f"[bold]To use this credential:[/bold]\n" + f"1. Copy the contents to your main .env file, OR\n" + f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n" + f"3. Or on Windows: [bold cyan]Get-Content {env_filepath.name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n" + f"[bold]To combine multiple credentials:[/bold]\n" + f"Copy lines from multiple .env files into one file.\n" + f"Each credential uses a unique number ({numbered_prefix}_*)." ) console.print(Panel(success_text, style="bold green", title="Success")) else: @@ -402,22 +473,30 @@ async def export_qwen_code_to_env(): # Extract metadata email = creds.get("_proxy_metadata", {}).get("email", "unknown") - # Generate .env file name + # Get credential number from filename + cred_number = _get_credential_number_from_filename(cred_file.name) + + # Generate .env file name with credential number safe_email = email.replace("@", "_at_").replace(".", "_") - env_filename = f"qwen_code_{safe_email}.env" + env_filename = f"qwen_code_{cred_number}_{safe_email}.env" env_filepath = OAUTH_BASE_DIR / env_filename - # Build .env content + # Use numbered format: QWEN_CODE_N_* + numbered_prefix = f"QWEN_CODE_{cred_number}" + + # Build .env content (Qwen has different structure) env_lines = [ - f"# Qwen Code Credential for: {email}", - f"# Generated from: {cred_file.name}", + f"# QWEN_CODE Credential #{cred_number} for: {email}", f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + f"# ", + f"# To combine multiple credentials into one .env file, copy these lines", + f"# and ensure each credential has a unique number (1, 2, 3, etc.)", "", - f"QWEN_CODE_ACCESS_TOKEN={creds.get('access_token', '')}", - f"QWEN_CODE_REFRESH_TOKEN={creds.get('refresh_token', '')}", - f"QWEN_CODE_EXPIRY_DATE={creds.get('expiry_date', 0)}", - f"QWEN_CODE_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}", - f"QWEN_CODE_EMAIL={email}", + f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}", + f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}", + f"{numbered_prefix}_EMAIL={email}", ] # Write to .env file @@ -426,11 +505,13 @@ async def export_qwen_code_to_env(): success_text = Text.from_markup( f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n" - f"To use this credential:\n" - f"1. Copy [bold yellow]{env_filepath.name}[/bold yellow] to your deployment environment\n" - f"2. Load the variables: [bold cyan]export $(cat {env_filepath.name} | grep -v '^#' | xargs)[/bold cyan]\n" - f"3. Or source it: [bold cyan]source {env_filepath.name}[/bold cyan]\n" - f"4. 
The Qwen Code provider will automatically use these environment variables" + f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n" + f"[bold]To use this credential:[/bold]\n" + f"1. Copy the contents to your main .env file, OR\n" + f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n" + f"[bold]To combine multiple credentials:[/bold]\n" + f"Copy lines from multiple .env files into one file.\n" + f"Each credential uses a unique number ({numbered_prefix}_*)." ) console.print(Panel(success_text, style="bold green", title="Success")) else: @@ -444,12 +525,12 @@ async def export_qwen_code_to_env(): async def export_iflow_to_env(): """ Export an iFlow credential JSON file to .env format. - Generates one .env file per credential. + Uses numbered format (IFLOW_1_*, IFLOW_2_*) for multiple credential support. """ console.print(Panel("[bold cyan]Export iFlow Credential to .env[/bold cyan]", expand=False)) # Find all iflow credentials - iflow_files = list(OAUTH_BASE_DIR.glob("iflow_oauth_*.json")) + iflow_files = sorted(list(OAUTH_BASE_DIR.glob("iflow_oauth_*.json"))) if not iflow_files: console.print(Panel("No iFlow credentials found. Please add one first using 'Add OAuth Credential'.", @@ -490,25 +571,32 @@ async def export_iflow_to_env(): # Extract metadata email = creds.get("_proxy_metadata", {}).get("email", "unknown") - # Generate .env file name + # Get credential number from filename + cred_number = _get_credential_number_from_filename(cred_file.name) + + # Generate .env file name with credential number safe_email = email.replace("@", "_at_").replace(".", "_") - env_filename = f"iflow_{safe_email}.env" + env_filename = f"iflow_{cred_number}_{safe_email}.env" env_filepath = OAUTH_BASE_DIR / env_filename - # Build .env content - # IMPORTANT: iFlow requires BOTH OAuth tokens AND the API key for API requests + # Use numbered format: IFLOW_N_* + numbered_prefix = f"IFLOW_{cred_number}" + + # Build .env content (iFlow has different structure with API key) env_lines = [ - f"# iFlow Credential for: {email}", - f"# Generated from: {cred_file.name}", + f"# IFLOW Credential #{cred_number} for: {email}", f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + f"# ", + f"# To combine multiple credentials into one .env file, copy these lines", + f"# and ensure each credential has a unique number (1, 2, 3, etc.)", "", - f"IFLOW_ACCESS_TOKEN={creds.get('access_token', '')}", - f"IFLOW_REFRESH_TOKEN={creds.get('refresh_token', '')}", - f"IFLOW_API_KEY={creds.get('api_key', '')}", - f"IFLOW_EXPIRY_DATE={creds.get('expiry_date', '')}", - f"IFLOW_EMAIL={email}", - f"IFLOW_TOKEN_TYPE={creds.get('token_type', 'Bearer')}", - f"IFLOW_SCOPE={creds.get('scope', 'read write')}", + f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}", + f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}", + f"{numbered_prefix}_EMAIL={email}", + f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}", + f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}", ] # Write to .env file @@ -517,11 +605,13 @@ async def export_iflow_to_env(): success_text = Text.from_markup( f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n" - f"To use this credential:\n" - f"1. Copy [bold yellow]{env_filepath.name}[/bold yellow] to your deployment environment\n" - f"2. 
Load the variables: [bold cyan]export $(cat {env_filepath.name} | grep -v '^#' | xargs)[/bold cyan]\n" - f"3. Or source it: [bold cyan]source {env_filepath.name}[/bold cyan]\n" - f"4. The iFlow provider will automatically use these environment variables" + f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n" + f"[bold]To use this credential:[/bold]\n" + f"1. Copy the contents to your main .env file, OR\n" + f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n\n" + f"[bold]To combine multiple credentials:[/bold]\n" + f"Copy lines from multiple .env files into one file.\n" + f"Each credential uses a unique number ({numbered_prefix}_*)." ) console.print(Panel(success_text, style="bold green", title="Success")) else: @@ -532,6 +622,479 @@ async def export_iflow_to_env(): console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error")) +async def export_antigravity_to_env(): + """ + Export an Antigravity credential JSON file to .env format. + Uses numbered format (ANTIGRAVITY_1_*, ANTIGRAVITY_2_*) for multiple credential support. + """ + console.print(Panel("[bold cyan]Export Antigravity Credential to .env[/bold cyan]", expand=False)) + + # Find all antigravity credentials + antigravity_files = sorted(list(OAUTH_BASE_DIR.glob("antigravity_oauth_*.json"))) + + if not antigravity_files: + console.print(Panel("No Antigravity credentials found. Please add one first using 'Add OAuth Credential'.", + style="bold red", title="No Credentials")) + return + + # Display available credentials + cred_text = Text() + for i, cred_file in enumerate(antigravity_files): + try: + with open(cred_file, 'r') as f: + creds = json.load(f) + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + cred_text.append(f" {i + 1}. {cred_file.name} ({email})\n") + except Exception as e: + cred_text.append(f" {i + 1}. {cred_file.name} (error reading: {e})\n") + + console.print(Panel(cred_text, title="Available Antigravity Credentials", style="bold blue")) + + choice = Prompt.ask( + Text.from_markup("[bold]Please select a credential to export or type [red]'b'[/red] to go back[/bold]"), + choices=[str(i + 1) for i in range(len(antigravity_files))] + ["b"], + show_choices=False + ) + + if choice.lower() == 'b': + return + + try: + choice_index = int(choice) - 1 + if 0 <= choice_index < len(antigravity_files): + cred_file = antigravity_files[choice_index] + + # Load the credential + with open(cred_file, 'r') as f: + creds = json.load(f) + + # Extract metadata + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + + # Get credential number from filename + cred_number = _get_credential_number_from_filename(cred_file.name) + + # Generate .env file name with credential number + safe_email = email.replace("@", "_at_").replace(".", "_") + env_filename = f"antigravity_{cred_number}_{safe_email}.env" + env_filepath = OAUTH_BASE_DIR / env_filename + + # Build .env content using helper + env_lines, numbered_prefix = _build_env_export_content( + provider_prefix="ANTIGRAVITY", + cred_number=cred_number, + creds=creds, + email=email, + extra_fields=None, + include_client_creds=True + ) + + # Write to .env file + with open(env_filepath, 'w') as f: + f.write('\n'.join(env_lines)) + + success_text = Text.from_markup( + f"Successfully exported credential to [bold yellow]'{env_filepath}'[/bold yellow]\n\n" + f"[bold]Environment variable prefix:[/bold] [cyan]{numbered_prefix}_*[/cyan]\n\n" + f"[bold]To use this credential:[/bold]\n" + f"1. 
Copy the contents to your main .env file, OR\n" + f"2. Source it: [bold cyan]source {env_filepath.name}[/bold cyan] (Linux/Mac)\n" + f"3. Or on Windows: [bold cyan]Get-Content {env_filepath.name} | ForEach-Object {{ $_ -replace '^([^#].*)$', 'set $1' }} | cmd[/bold cyan]\n\n" + f"[bold]To combine multiple credentials:[/bold]\n" + f"Copy lines from multiple .env files into one file.\n" + f"Each credential uses a unique number ({numbered_prefix}_*)." + ) + console.print(Panel(success_text, style="bold green", title="Success")) + else: + console.print("[bold red]Invalid choice. Please try again.[/bold red]") + except ValueError: + console.print("[bold red]Invalid input. Please enter a number or 'b'.[/bold red]") + except Exception as e: + console.print(Panel(f"An error occurred during export: {e}", style="bold red", title="Error")) + + +def _build_gemini_cli_env_lines(creds: dict, cred_number: int) -> list[str]: + """Build .env lines for a Gemini CLI credential.""" + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + project_id = creds.get("_proxy_metadata", {}).get("project_id", "") + tier = creds.get("_proxy_metadata", {}).get("tier", "") + + extra_fields = {} + if project_id: + extra_fields["PROJECT_ID"] = project_id + if tier: + extra_fields["TIER"] = tier + + env_lines, _ = _build_env_export_content( + provider_prefix="GEMINI_CLI", + cred_number=cred_number, + creds=creds, + email=email, + extra_fields=extra_fields, + include_client_creds=True + ) + return env_lines + + +def _build_qwen_code_env_lines(creds: dict, cred_number: int) -> list[str]: + """Build .env lines for a Qwen Code credential.""" + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + numbered_prefix = f"QWEN_CODE_{cred_number}" + + env_lines = [ + f"# QWEN_CODE Credential #{cred_number} for: {email}", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + "", + f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', 0)}", + f"{numbered_prefix}_RESOURCE_URL={creds.get('resource_url', 'https://portal.qwen.ai/v1')}", + f"{numbered_prefix}_EMAIL={email}", + ] + return env_lines + + +def _build_iflow_env_lines(creds: dict, cred_number: int) -> list[str]: + """Build .env lines for an iFlow credential.""" + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + numbered_prefix = f"IFLOW_{cred_number}" + + env_lines = [ + f"# IFLOW Credential #{cred_number} for: {email}", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + "", + f"{numbered_prefix}_ACCESS_TOKEN={creds.get('access_token', '')}", + f"{numbered_prefix}_REFRESH_TOKEN={creds.get('refresh_token', '')}", + f"{numbered_prefix}_API_KEY={creds.get('api_key', '')}", + f"{numbered_prefix}_EXPIRY_DATE={creds.get('expiry_date', '')}", + f"{numbered_prefix}_EMAIL={email}", + f"{numbered_prefix}_TOKEN_TYPE={creds.get('token_type', 'Bearer')}", + f"{numbered_prefix}_SCOPE={creds.get('scope', 'read write')}", + ] + return env_lines + + +def _build_antigravity_env_lines(creds: dict, cred_number: int) -> list[str]: + """Build .env lines for an Antigravity credential.""" + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + + env_lines, _ = _build_env_export_content( + provider_prefix="ANTIGRAVITY", + cred_number=cred_number, + creds=creds, + email=email, + extra_fields=None, + include_client_creds=True + ) + return env_lines + + +async def 
export_all_provider_credentials(provider_name: str): + """ + Export all credentials for a specific provider to individual .env files. + """ + provider_config = { + "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines), + "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines), + "iflow": ("IFLOW", _build_iflow_env_lines), + "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines), + } + + if provider_name not in provider_config: + console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]") + return + + prefix, build_func = provider_config[provider_name] + display_name = prefix.replace("_", " ").title() + + console.print(Panel(f"[bold cyan]Export All {display_name} Credentials[/bold cyan]", expand=False)) + + # Find all credentials for this provider + cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"))) + + if not cred_files: + console.print(Panel(f"No {display_name} credentials found.", style="bold red", title="No Credentials")) + return + + exported_count = 0 + for cred_file in cred_files: + try: + with open(cred_file, 'r') as f: + creds = json.load(f) + + email = creds.get("_proxy_metadata", {}).get("email", "unknown") + cred_number = _get_credential_number_from_filename(cred_file.name) + + # Generate .env file name + safe_email = email.replace("@", "_at_").replace(".", "_") + env_filename = f"{provider_name}_{cred_number}_{safe_email}.env" + env_filepath = OAUTH_BASE_DIR / env_filename + + # Build and write .env content + env_lines = build_func(creds, cred_number) + with open(env_filepath, 'w') as f: + f.write('\n'.join(env_lines)) + + console.print(f" ✓ Exported [cyan]{cred_file.name}[/cyan] → [yellow]{env_filename}[/yellow]") + exported_count += 1 + + except Exception as e: + console.print(f" ✗ Failed to export {cred_file.name}: {e}") + + console.print(Panel( + f"Successfully exported {exported_count}/{len(cred_files)} {display_name} credentials to individual .env files.", + style="bold green", title="Export Complete" + )) + + +async def combine_provider_credentials(provider_name: str): + """ + Combine all credentials for a specific provider into a single .env file. 
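+
+    Illustrative example (a sketch, not output from real stored credentials;
+    the provider name is just the one used elsewhere in this tool):
+
+        await combine_provider_credentials("gemini_cli")
+        # -> writes gemini_cli_all_combined.env to OAUTH_BASE_DIR, containing the
+        #    numbered GEMINI_CLI_1_*, GEMINI_CLI_2_*, ... variable blocks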
+ """ + provider_config = { + "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines), + "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines), + "iflow": ("IFLOW", _build_iflow_env_lines), + "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines), + } + + if provider_name not in provider_config: + console.print(f"[bold red]Unknown provider: {provider_name}[/bold red]") + return + + prefix, build_func = provider_config[provider_name] + display_name = prefix.replace("_", " ").title() + + console.print(Panel(f"[bold cyan]Combine All {display_name} Credentials[/bold cyan]", expand=False)) + + # Find all credentials for this provider + cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"))) + + if not cred_files: + console.print(Panel(f"No {display_name} credentials found.", style="bold red", title="No Credentials")) + return + + combined_lines = [ + f"# Combined {display_name} Credentials", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + f"# Total credentials: {len(cred_files)}", + "#", + "# Copy all lines below into your main .env file", + "", + ] + + combined_count = 0 + for cred_file in cred_files: + try: + with open(cred_file, 'r') as f: + creds = json.load(f) + + cred_number = _get_credential_number_from_filename(cred_file.name) + env_lines = build_func(creds, cred_number) + + combined_lines.extend(env_lines) + combined_lines.append("") # Blank line between credentials + combined_count += 1 + + except Exception as e: + console.print(f" ✗ Failed to process {cred_file.name}: {e}") + + # Write combined file + combined_filename = f"{provider_name}_all_combined.env" + combined_filepath = OAUTH_BASE_DIR / combined_filename + + with open(combined_filepath, 'w') as f: + f.write('\n'.join(combined_lines)) + + console.print(Panel( + Text.from_markup( + f"Successfully combined {combined_count} {display_name} credentials into:\n" + f"[bold yellow]{combined_filepath}[/bold yellow]\n\n" + f"[bold]To use:[/bold] Copy the contents into your main .env file." + ), + style="bold green", title="Combine Complete" + )) + + +async def combine_all_credentials(): + """ + Combine ALL credentials from ALL providers into a single .env file. 
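+
+    Illustrative output layout (token values elided; section headers follow the
+    display-name convention built below from each provider prefix):
+
+        # ===== Gemini Cli Credentials =====
+        GEMINI_CLI_1_ACCESS_TOKEN=...
+        GEMINI_CLI_1_REFRESH_TOKEN=...
+
+        # ===== Antigravity Credentials =====
+        ANTIGRAVITY_1_ACCESS_TOKEN=...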
+ """ + console.print(Panel("[bold cyan]Combine All Provider Credentials[/bold cyan]", expand=False)) + + provider_config = { + "gemini_cli": ("GEMINI_CLI", _build_gemini_cli_env_lines), + "qwen_code": ("QWEN_CODE", _build_qwen_code_env_lines), + "iflow": ("IFLOW", _build_iflow_env_lines), + "antigravity": ("ANTIGRAVITY", _build_antigravity_env_lines), + } + + combined_lines = [ + "# Combined All Provider Credentials", + f"# Generated at: {time.strftime('%Y-%m-%d %H:%M:%S')}", + "#", + "# Copy all lines below into your main .env file", + "", + ] + + total_count = 0 + provider_counts = {} + + for provider_name, (prefix, build_func) in provider_config.items(): + cred_files = sorted(list(OAUTH_BASE_DIR.glob(f"{provider_name}_oauth_*.json"))) + + if not cred_files: + continue + + display_name = prefix.replace("_", " ").title() + combined_lines.append(f"# ===== {display_name} Credentials =====") + combined_lines.append("") + + provider_count = 0 + for cred_file in cred_files: + try: + with open(cred_file, 'r') as f: + creds = json.load(f) + + cred_number = _get_credential_number_from_filename(cred_file.name) + env_lines = build_func(creds, cred_number) + + combined_lines.extend(env_lines) + combined_lines.append("") + provider_count += 1 + total_count += 1 + + except Exception as e: + console.print(f" ✗ Failed to process {cred_file.name}: {e}") + + provider_counts[display_name] = provider_count + + if total_count == 0: + console.print(Panel("No credentials found to combine.", style="bold red", title="No Credentials")) + return + + # Write combined file + combined_filename = "all_providers_combined.env" + combined_filepath = OAUTH_BASE_DIR / combined_filename + + with open(combined_filepath, 'w') as f: + f.write('\n'.join(combined_lines)) + + # Build summary + summary_lines = [f" • {name}: {count} credential(s)" for name, count in provider_counts.items()] + summary = "\n".join(summary_lines) + + console.print(Panel( + Text.from_markup( + f"Successfully combined {total_count} credentials from {len(provider_counts)} providers:\n" + f"{summary}\n\n" + f"[bold]Output file:[/bold] [yellow]{combined_filepath}[/yellow]\n\n" + f"[bold]To use:[/bold] Copy the contents into your main .env file." + ), + style="bold green", title="Combine Complete" + )) + + +async def export_credentials_submenu(): + """ + Submenu for credential export options. + """ + while True: + console.clear() + console.print(Panel("[bold cyan]Export Credentials to .env[/bold cyan]", title="--- API Key Proxy ---", expand=False)) + + console.print(Panel( + Text.from_markup( + "[bold]Individual Exports:[/bold]\n" + "1. Export Gemini CLI credential\n" + "2. Export Qwen Code credential\n" + "3. Export iFlow credential\n" + "4. Export Antigravity credential\n" + "\n" + "[bold]Bulk Exports (per provider):[/bold]\n" + "5. Export ALL Gemini CLI credentials\n" + "6. Export ALL Qwen Code credentials\n" + "7. Export ALL iFlow credentials\n" + "8. Export ALL Antigravity credentials\n" + "\n" + "[bold]Combine Credentials:[/bold]\n" + "9. Combine all Gemini CLI into one file\n" + "10. Combine all Qwen Code into one file\n" + "11. Combine all iFlow into one file\n" + "12. Combine all Antigravity into one file\n" + "13. 
Combine ALL providers into one file" + ), + title="Choose export option", + style="bold blue" + )) + + export_choice = Prompt.ask( + Text.from_markup("[bold]Please select an option or type [red]'b'[/red] to go back[/bold]"), + choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "b"], + show_choices=False + ) + + if export_choice.lower() == 'b': + break + + # Individual exports + if export_choice == "1": + await export_gemini_cli_to_env() + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "2": + await export_qwen_code_to_env() + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "3": + await export_iflow_to_env() + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "4": + await export_antigravity_to_env() + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + # Bulk exports (all credentials for a provider) + elif export_choice == "5": + await export_all_provider_credentials("gemini_cli") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "6": + await export_all_provider_credentials("qwen_code") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "7": + await export_all_provider_credentials("iflow") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "8": + await export_all_provider_credentials("antigravity") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + # Combine per provider + elif export_choice == "9": + await combine_provider_credentials("gemini_cli") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "10": + await combine_provider_credentials("qwen_code") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "11": + await combine_provider_credentials("iflow") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + elif export_choice == "12": + await combine_provider_credentials("antigravity") + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + # Combine all providers + elif export_choice == "13": + await combine_all_credentials() + console.print("\n[dim]Press Enter to return to export menu...[/dim]") + input() + + async def main(clear_on_start=True): """ An interactive CLI tool to add new credentials. @@ -555,9 +1118,7 @@ async def main(clear_on_start=True): Text.from_markup( "1. Add OAuth Credential\n" "2. Add API Key\n" - "3. Export Gemini CLI credential to .env\n" - "4. Export Qwen Code credential to .env\n" - "5. Export iFlow credential to .env" + "3. 
Export Credentials" ), title="Choose credential type", style="bold blue" @@ -565,7 +1126,7 @@ async def main(clear_on_start=True): setup_type = Prompt.ask( Text.from_markup("[bold]Please select an option or type [red]'q'[/red] to quit[/bold]"), - choices=["1", "2", "3", "4", "5", "q"], + choices=["1", "2", "3", "q"], show_choices=False ) @@ -578,7 +1139,8 @@ async def main(clear_on_start=True): oauth_friendly_names = { "gemini_cli": "Gemini CLI (OAuth)", "qwen_code": "Qwen Code (OAuth - also supports API keys)", - "iflow": "iFlow (OAuth - also supports API keys)" + "iflow": "iFlow (OAuth - also supports API keys)", + "antigravity": "Antigravity (OAuth)", } provider_text = Text() @@ -620,19 +1182,7 @@ async def main(clear_on_start=True): input() elif setup_type == "3": - await export_gemini_cli_to_env() - console.print("\n[dim]Press Enter to return to main menu...[/dim]") - input() - - elif setup_type == "4": - await export_qwen_code_to_env() - console.print("\n[dim]Press Enter to return to main menu...[/dim]") - input() - - elif setup_type == "5": - await export_iflow_to_env() - console.print("\n[dim]Press Enter to return to main menu...[/dim]") - input() + await export_credentials_submenu() def run_credential_tool(from_launcher=False): """ diff --git a/src/rotator_library/error_handler.py b/src/rotator_library/error_handler.py index 5298aec..a3775f7 100644 --- a/src/rotator_library/error_handler.py +++ b/src/rotator_library/error_handler.py @@ -17,6 +17,42 @@ ) +def extract_retry_after_from_body(error_body: Optional[str]) -> Optional[int]: + """ + Extract the retry-after time from an API error response body. + + Handles various error formats including: + - Gemini CLI: "Your quota will reset after 39s." + - Generic: "quota will reset after 120s", "retry after 60s" + + Args: + error_body: The raw error response body + + Returns: + The retry time in seconds, or None if not found + """ + if not error_body: + return None + + # Pattern to match various "reset after Xs" or "retry after Xs" formats + patterns = [ + r"quota will reset after\s*(\d+)s", + r"reset after\s*(\d+)s", + r"retry after\s*(\d+)s", + r"try again in\s*(\d+)\s*seconds?", + ] + + for pattern in patterns: + match = re.search(pattern, error_body, re.IGNORECASE) + if match: + try: + return int(match.group(1)) + except (ValueError, IndexError): + continue + + return None + + class NoAvailableKeysError(Exception): """Raised when no API keys are available for a request after waiting.""" @@ -106,6 +142,8 @@ def get_retry_after(error: Exception) -> Optional[int]: r"wait for\s*(\d+)\s*seconds?", r'"retryDelay":\s*"(\d+)s"', r"x-ratelimit-reset:?\s*(\d+)", + r"quota will reset after\s*(\d+)s", # Gemini CLI rate limit format + r"reset after\s*(\d+)s", # Generic reset after format ] for pattern in patterns: diff --git a/src/rotator_library/model_info_service.py b/src/rotator_library/model_info_service.py new file mode 100644 index 0000000..0c577bc --- /dev/null +++ b/src/rotator_library/model_info_service.py @@ -0,0 +1,946 @@ +""" +Unified Model Registry + +Provides aggregated model metadata from external catalogs (OpenRouter, Models.dev) +for pricing calculations and the /v1/models endpoint. + +Data retrieval happens asynchronously post-startup to keep initialization fast. 
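+
+Usage sketch (illustrative only; call from inside an async function, and note that
+the model ID below is an example, not a guaranteed catalog entry):
+
+    registry = ModelRegistry()
+    await registry.start()         # spawn the background refresh worker
+    await registry.await_ready()   # wait for the first catalog load
+    meta = registry.lookup("anthropic/claude-sonnet-4-5")
+    if meta:
+        print(meta.display_name, meta.limits.context_window)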
+""" + +import asyncio +import json +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple +from urllib.request import Request, urlopen +from urllib.error import URLError + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Data Structures +# ============================================================================ + +@dataclass +class ModelPricing: + """Token-level pricing information.""" + prompt: Optional[float] = None + completion: Optional[float] = None + cached_input: Optional[float] = None + cache_write: Optional[float] = None + + +@dataclass +class ModelLimits: + """Context and output token limits.""" + context_window: Optional[int] = None + max_output: Optional[int] = None + + +@dataclass +class ModelCapabilities: + """Feature flags for model capabilities.""" + tools: bool = False + functions: bool = False + reasoning: bool = False + vision: bool = False + system_prompt: bool = True + caching: bool = False + prefill: bool = False + + +@dataclass +class ModelMetadata: + """Complete model information record.""" + + model_id: str + display_name: str = "" + provider: str = "" + category: str = "chat" # chat, embedding, image, audio + + pricing: ModelPricing = field(default_factory=ModelPricing) + limits: ModelLimits = field(default_factory=ModelLimits) + capabilities: ModelCapabilities = field(default_factory=ModelCapabilities) + + input_types: List[str] = field(default_factory=lambda: ["text"]) + output_types: List[str] = field(default_factory=lambda: ["text"]) + + timestamp: int = field(default_factory=lambda: int(time.time())) + origin: str = "" + match_quality: str = "unknown" + + def as_api_response(self) -> Dict[str, Any]: + """Format for OpenAI-compatible /v1/models response.""" + response = { + "id": self.model_id, + "object": "model", + "created": self.timestamp, + "owned_by": self.provider or "proxy", + } + + # Pricing fields + if self.pricing.prompt is not None: + response["input_cost_per_token"] = self.pricing.prompt + if self.pricing.completion is not None: + response["output_cost_per_token"] = self.pricing.completion + if self.pricing.cached_input is not None: + response["cache_read_input_token_cost"] = self.pricing.cached_input + if self.pricing.cache_write is not None: + response["cache_creation_input_token_cost"] = self.pricing.cache_write + + # Limits + if self.limits.context_window: + response["max_input_tokens"] = self.limits.context_window + response["context_window"] = self.limits.context_window + if self.limits.max_output: + response["max_output_tokens"] = self.limits.max_output + + # Category and modalities + response["mode"] = self.category + response["supported_modalities"] = self.input_types + response["supported_output_modalities"] = self.output_types + + # Capability flags + response["capabilities"] = { + "tool_choice": self.capabilities.tools, + "function_calling": self.capabilities.functions, + "reasoning": self.capabilities.reasoning, + "vision": self.capabilities.vision, + "system_messages": self.capabilities.system_prompt, + "prompt_caching": self.capabilities.caching, + "assistant_prefill": self.capabilities.prefill, + } + + # Debug metadata + if self.origin: + response["_sources"] = [self.origin] + response["_match_type"] = self.match_quality + + return response + + def as_minimal(self) -> Dict[str, Any]: + """Minimal OpenAI format.""" + return { + "id": self.model_id, + "object": "model", + 
"created": self.timestamp, + "owned_by": self.provider or "proxy", + } + + def to_dict(self) -> Dict[str, Any]: + """Alias for as_api_response() - backward compatibility.""" + return self.as_api_response() + + def to_openai_format(self) -> Dict[str, Any]: + """Alias for as_minimal() - backward compatibility.""" + return self.as_minimal() + + # Backward-compatible property aliases + @property + def id(self) -> str: + return self.model_id + + @property + def name(self) -> str: + return self.display_name + + @property + def input_cost_per_token(self) -> Optional[float]: + return self.pricing.prompt + + @property + def output_cost_per_token(self) -> Optional[float]: + return self.pricing.completion + + @property + def cache_read_input_token_cost(self) -> Optional[float]: + return self.pricing.cached_input + + @property + def cache_creation_input_token_cost(self) -> Optional[float]: + return self.pricing.cache_write + + @property + def max_input_tokens(self) -> Optional[int]: + return self.limits.context_window + + @property + def max_output_tokens(self) -> Optional[int]: + return self.limits.max_output + + @property + def mode(self) -> str: + return self.category + + @property + def supported_modalities(self) -> List[str]: + return self.input_types + + @property + def supported_output_modalities(self) -> List[str]: + return self.output_types + + @property + def supports_tool_choice(self) -> bool: + return self.capabilities.tools + + @property + def supports_function_calling(self) -> bool: + return self.capabilities.functions + + @property + def supports_reasoning(self) -> bool: + return self.capabilities.reasoning + + @property + def supports_vision(self) -> bool: + return self.capabilities.vision + + @property + def supports_system_messages(self) -> bool: + return self.capabilities.system_prompt + + @property + def supports_prompt_caching(self) -> bool: + return self.capabilities.caching + + @property + def supports_assistant_prefill(self) -> bool: + return self.capabilities.prefill + + @property + def litellm_provider(self) -> str: + return self.provider + + @property + def created(self) -> int: + return self.timestamp + + @property + def _sources(self) -> List[str]: + return [self.origin] if self.origin else [] + + @property + def _match_type(self) -> str: + return self.match_quality + + +# ============================================================================ +# Data Source Adapters +# ============================================================================ + +class DataSourceAdapter: + """Base interface for external data sources.""" + + source_name: str = "unknown" + endpoint: str = "" + + def fetch(self) -> Dict[str, Dict]: + """Retrieve and normalize data. 
Returns {model_id: raw_data}.""" + raise NotImplementedError + + def _http_get(self, url: str, timeout: int = 30) -> Any: + """Execute HTTP GET with standard headers.""" + req = Request(url, headers={"User-Agent": "ModelRegistry/1.0"}) + with urlopen(req, timeout=timeout) as resp: + return json.loads(resp.read().decode("utf-8")) + + +class OpenRouterAdapter(DataSourceAdapter): + """Fetches model data from OpenRouter's public API.""" + + source_name = "openrouter" + endpoint = "https://openrouter.ai/api/v1/models" + + def fetch(self) -> Dict[str, Dict]: + try: + raw = self._http_get(self.endpoint) + entries = raw.get("data", []) + + catalog = {} + for entry in entries: + mid = entry.get("id") + if not mid: + continue + + full_id = f"openrouter/{mid}" + catalog[full_id] = self._normalize(entry) + + return catalog + except (URLError, json.JSONDecodeError, TimeoutError) as err: + raise ConnectionError(f"OpenRouter unavailable: {err}") from err + + def _normalize(self, raw: Dict) -> Dict: + """Transform OpenRouter schema to internal format.""" + prices = raw.get("pricing", {}) + arch = raw.get("architecture", {}) + top = raw.get("top_provider", {}) + params = raw.get("supported_parameters", []) + + tokenizer = arch.get("tokenizer", "") + category = "embedding" if "embedding" in tokenizer.lower() else "chat" + + return { + "name": raw.get("name", ""), + "prompt_cost": float(prices.get("prompt", 0)), + "completion_cost": float(prices.get("completion", 0)), + "cache_read_cost": float(prices.get("input_cache_read", 0)) or None, + "context": top.get("context_length", 0), + "max_out": top.get("max_completion_tokens", 0), + "category": category, + "inputs": arch.get("input_modalities", ["text"]), + "outputs": arch.get("output_modalities", ["text"]), + "has_tools": "tool_choice" in params or "tools" in params, + "has_functions": "tools" in params or "function_calling" in params, + "has_reasoning": "reasoning" in params, + "has_vision": "image" in arch.get("input_modalities", []), + "provider": "openrouter", + "source": "openrouter", + } + + +class ModelsDevAdapter(DataSourceAdapter): + """Fetches model data from Models.dev catalog.""" + + source_name = "modelsdev" + endpoint = "https://models.dev/api.json" + + def __init__(self, skip_providers: Optional[List[str]] = None): + self.skip_providers = skip_providers or [] + + def fetch(self) -> Dict[str, Dict]: + try: + raw = self._http_get(self.endpoint) + + catalog = {} + for provider_key, provider_block in raw.items(): + if not isinstance(provider_block, dict): + continue + if provider_key in self.skip_providers: + continue + + models_block = provider_block.get("models", {}) + if not isinstance(models_block, dict): + continue + + for model_key, model_data in models_block.items(): + if not isinstance(model_data, dict): + continue + + full_id = f"{provider_key}/{model_key}" + catalog[full_id] = self._normalize(model_data, provider_key) + + return catalog + except (URLError, json.JSONDecodeError, TimeoutError) as err: + raise ConnectionError(f"Models.dev unavailable: {err}") from err + + def _normalize(self, raw: Dict, provider_key: str) -> Dict: + """Transform Models.dev schema to internal format.""" + costs = raw.get("cost", {}) + mods = raw.get("modalities", {}) + lims = raw.get("limit", {}) + + outputs = mods.get("output", ["text"]) + if "image" in outputs: + category = "image" + elif "audio" in outputs: + category = "audio" + else: + category = "chat" + + # Models.dev uses per-million pricing, convert to per-token + divisor = 1_000_000 + + cache_read = 
costs.get("cache_read") + cache_write = costs.get("cache_write") + + return { + "name": raw.get("name", ""), + "prompt_cost": float(costs.get("input", 0)) / divisor, + "completion_cost": float(costs.get("output", 0)) / divisor, + "cache_read_cost": float(cache_read) / divisor if cache_read else None, + "cache_write_cost": float(cache_write) / divisor if cache_write else None, + "context": lims.get("context", 0), + "max_out": lims.get("output", 0), + "category": category, + "inputs": mods.get("input", ["text"]), + "outputs": outputs, + "has_tools": raw.get("tool_call", False), + "has_functions": raw.get("tool_call", False), + "has_reasoning": raw.get("reasoning", False), + "has_vision": "image" in mods.get("input", []), + "provider": provider_key, + "source": "modelsdev", + } + + +# ============================================================================ +# Lookup Index +# ============================================================================ + +class ModelIndex: + """Fast lookup structure for model ID resolution.""" + + def __init__(self): + self._by_full_id: Dict[str, str] = {} # normalized_id -> canonical_id + self._by_suffix: Dict[str, List[str]] = {} # short_name -> [canonical_ids] + + def clear(self): + """Reset the index.""" + self._by_full_id.clear() + self._by_suffix.clear() + + def entry_count(self) -> int: + """Return total number of suffix index entries.""" + return sum(len(v) for v in self._by_suffix.values()) + + def add(self, canonical_id: str): + """Index a canonical model ID for various lookup patterns.""" + self._by_full_id[canonical_id] = canonical_id + + segments = canonical_id.split("/") + if len(segments) >= 2: + # Index by everything after first segment + partial = "/".join(segments[1:]) + self._by_suffix.setdefault(partial, []).append(canonical_id) + + # Index by final segment only + if len(segments) >= 3: + tail = segments[-1] + self._by_suffix.setdefault(tail, []).append(canonical_id) + + def resolve(self, query: str) -> List[str]: + """Find all canonical IDs matching a query.""" + # Direct match + if query in self._by_full_id: + return [self._by_full_id[query]] + + # Try with openrouter prefix + prefixed = f"openrouter/{query}" + if prefixed in self._by_full_id: + return [self._by_full_id[prefixed]] + + # Extract search terms from query + search_keys = [] + parts = query.split("/") + if len(parts) >= 2: + search_keys.append("/".join(parts[1:])) + search_keys.append(parts[-1]) + else: + search_keys.append(query) + # Find matches + matches = [] + seen = set() + for key in search_keys: + for cid in self._by_suffix.get(key, []): + if cid not in seen: + seen.add(cid) + matches.append(cid) + + return matches + + +# ============================================================================ +# Data Merger +# ============================================================================ + +class DataMerger: + """Combines data from multiple sources into unified ModelMetadata.""" + + @staticmethod + def single(model_id: str, data: Dict, origin: str, quality: str) -> ModelMetadata: + """Create ModelMetadata from a single source record.""" + return ModelMetadata( + model_id=model_id, + display_name=data.get("name", model_id), + provider=data.get("provider", ""), + category=data.get("category", "chat"), + pricing=ModelPricing( + prompt=data.get("prompt_cost"), + completion=data.get("completion_cost"), + cached_input=data.get("cache_read_cost"), + cache_write=data.get("cache_write_cost"), + ), + limits=ModelLimits( + context_window=data.get("context") or None, + 
max_output=data.get("max_out") or None, + ), + capabilities=ModelCapabilities( + tools=data.get("has_tools", False), + functions=data.get("has_functions", False), + reasoning=data.get("has_reasoning", False), + vision=data.get("has_vision", False), + ), + input_types=data.get("inputs", ["text"]), + output_types=data.get("outputs", ["text"]), + origin=origin, + match_quality=quality, + ) + + @staticmethod + def combine(model_id: str, records: List[Tuple[Dict, str]], quality: str) -> ModelMetadata: + """Merge multiple source records into one ModelMetadata.""" + if len(records) == 1: + data, origin = records[0] + return DataMerger.single(model_id, data, origin, quality) + + # Aggregate pricing - use average + prompt_costs = [r[0]["prompt_cost"] for r in records if r[0].get("prompt_cost")] + comp_costs = [r[0]["completion_cost"] for r in records if r[0].get("completion_cost")] + cache_costs = [r[0]["cache_read_cost"] for r in records if r[0].get("cache_read_cost")] + + # Aggregate limits - use most common value + contexts = [r[0]["context"] for r in records if r[0].get("context")] + max_outs = [r[0]["max_out"] for r in records if r[0].get("max_out")] + + # Capabilities - OR logic (any source supporting = supported) + has_tools = any(r[0].get("has_tools") for r in records) + has_funcs = any(r[0].get("has_functions") for r in records) + has_reason = any(r[0].get("has_reasoning") for r in records) + has_vis = any(r[0].get("has_vision") for r in records) + + # Modalities - union + all_inputs = set() + all_outputs = set() + for r in records: + all_inputs.update(r[0].get("inputs", ["text"])) + all_outputs.update(r[0].get("outputs", ["text"])) + + # Category - majority vote + categories = [r[0].get("category", "chat") for r in records] + category = max(set(categories), key=categories.count) + + # Name - first non-empty + name = model_id + for r in records: + if r[0].get("name"): + name = r[0]["name"] + break + + origins = [r[1] for r in records] + + return ModelMetadata( + model_id=model_id, + display_name=name, + provider=records[0][0].get("provider", ""), + category=category, + pricing=ModelPricing( + prompt=sum(prompt_costs) / len(prompt_costs) if prompt_costs else None, + completion=sum(comp_costs) / len(comp_costs) if comp_costs else None, + cached_input=sum(cache_costs) / len(cache_costs) if cache_costs else None, + ), + limits=ModelLimits( + context_window=DataMerger._mode(contexts), + max_output=DataMerger._mode(max_outs), + ), + capabilities=ModelCapabilities( + tools=has_tools, + functions=has_funcs, + reasoning=has_reason, + vision=has_vis, + ), + input_types=list(all_inputs) or ["text"], + output_types=list(all_outputs) or ["text"], + origin=",".join(origins), + match_quality=quality, + ) + + @staticmethod + def _mode(values: List[int]) -> Optional[int]: + """Return most frequent value.""" + if not values: + return None + return max(set(values), key=values.count) + + +# ============================================================================ +# Main Registry Service +# ============================================================================ + +class ModelRegistry: + """ + Central registry for model metadata from external catalogs. + + Manages background data refresh and provides lookup/pricing APIs. 
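+
+    Cost arithmetic sketch (illustrative rates, not real pricing): if a model's
+    catalog entry carries an input rate of 2e-6 and an output rate of 6e-6 per
+    token, then compute_cost(model_id, input_tokens=1000, output_tokens=200)
+    returns 1000 * 2e-6 + 200 * 6e-6 = 0.0032, plus cache-read/cache-write terms
+    when those rates and token counts are available.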
+ """ + + REFRESH_INTERVAL_DEFAULT = 6 * 60 * 60 # 6 hours + + def __init__( + self, + refresh_seconds: Optional[int] = None, + skip_modelsdev_providers: Optional[List[str]] = None, + ): + interval_env = os.getenv("MODEL_INFO_REFRESH_INTERVAL") + self._refresh_interval = refresh_seconds or ( + int(interval_env) if interval_env else self.REFRESH_INTERVAL_DEFAULT + ) + + # Configure adapters + self._adapters: List[DataSourceAdapter] = [ + OpenRouterAdapter(), + ModelsDevAdapter(skip_providers=skip_modelsdev_providers or []), + ] + + # Raw data stores + self._openrouter_store: Dict[str, Dict] = {} + self._modelsdev_store: Dict[str, Dict] = {} + + # Lookup infrastructure + self._index = ModelIndex() + self._result_cache: Dict[str, ModelMetadata] = {} + + # Async coordination + self._ready = asyncio.Event() + self._mutex = asyncio.Lock() + self._worker: Optional[asyncio.Task] = None + self._last_refresh: float = 0 + + # ---------- Lifecycle ---------- + + async def start(self): + """Begin background refresh worker.""" + if self._worker is None: + self._worker = asyncio.create_task(self._refresh_worker()) + logger.info( + "ModelRegistry started (refresh every %ds)", + self._refresh_interval + ) + + async def stop(self): + """Halt background worker.""" + if self._worker: + self._worker.cancel() + try: + await self._worker + except asyncio.CancelledError: + pass + self._worker = None + logger.info("ModelRegistry stopped") + + async def await_ready(self, timeout_secs: float = 30.0) -> bool: + """Block until initial data load completes.""" + try: + await asyncio.wait_for(self._ready.wait(), timeout=timeout_secs) + return True + except asyncio.TimeoutError: + logger.warning("ModelRegistry ready timeout after %.1fs", timeout_secs) + return False + + @property + def is_ready(self) -> bool: + return self._ready.is_set() + + # ---------- Background Worker ---------- + + async def _refresh_worker(self): + """Periodic refresh loop.""" + await self._load_all_sources() + self._ready.set() + + while True: + try: + await asyncio.sleep(self._refresh_interval) + logger.info("Scheduled registry refresh...") + await self._load_all_sources() + logger.info("Registry refresh complete") + except asyncio.CancelledError: + break + except Exception as ex: + logger.error("Registry refresh error: %s", ex) + + async def _load_all_sources(self): + """Fetch from all adapters concurrently.""" + loop = asyncio.get_event_loop() + + tasks = [ + loop.run_in_executor(None, adapter.fetch) + for adapter in self._adapters + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + async with self._mutex: + for adapter, result in zip(self._adapters, results): + if isinstance(result, Exception): + logger.error("%s fetch failed: %s", adapter.source_name, result) + continue + + if adapter.source_name == "openrouter": + self._openrouter_store = result + logger.info("OpenRouter: %d models loaded", len(result)) + elif adapter.source_name == "modelsdev": + self._modelsdev_store = result + logger.info("Models.dev: %d models loaded", len(result)) + + self._rebuild_index() + self._last_refresh = time.time() + + def _rebuild_index(self): + """Reconstruct lookup index from current stores.""" + self._index.clear() + self._result_cache.clear() + + for model_id in self._openrouter_store: + self._index.add(model_id) + + for model_id in self._modelsdev_store: + self._index.add(model_id) + + # ---------- Query API ---------- + + def lookup(self, model_id: str) -> Optional[ModelMetadata]: + """ + Retrieve model metadata by ID. 
+ + Matching strategy: + 1. Exact match against known IDs + 2. Fuzzy match by model name suffix + 3. Aggregate if multiple sources match + """ + if model_id in self._result_cache: + return self._result_cache[model_id] + + metadata = self._resolve_model(model_id) + if metadata: + self._result_cache[model_id] = metadata + return metadata + + def _resolve_model(self, model_id: str) -> Optional[ModelMetadata]: + """Build ModelMetadata by matching source data.""" + records: List[Tuple[Dict, str]] = [] + quality = "none" + + # Check exact matches first + or_key = f"openrouter/{model_id}" if not model_id.startswith("openrouter/") else model_id + if or_key in self._openrouter_store: + records.append((self._openrouter_store[or_key], f"openrouter:exact:{or_key}")) + quality = "exact" + + if model_id in self._modelsdev_store: + records.append((self._modelsdev_store[model_id], f"modelsdev:exact:{model_id}")) + quality = "exact" + + # Fall back to index search + if not records: + candidates = self._index.resolve(model_id) + for cid in candidates: + if cid in self._openrouter_store: + records.append((self._openrouter_store[cid], f"openrouter:fuzzy:{cid}")) + elif cid in self._modelsdev_store: + records.append((self._modelsdev_store[cid], f"modelsdev:fuzzy:{cid}")) + + if records: + quality = "fuzzy" + + if not records: + return None + + return DataMerger.combine(model_id, records, quality) + + def get_pricing(self, model_id: str) -> Optional[Dict[str, float]]: + """Extract just pricing info for cost calculations.""" + meta = self.lookup(model_id) + if not meta: + return None + + result = {} + if meta.pricing.prompt is not None: + result["input_cost_per_token"] = meta.pricing.prompt + if meta.pricing.completion is not None: + result["output_cost_per_token"] = meta.pricing.completion + if meta.pricing.cached_input is not None: + result["cache_read_input_token_cost"] = meta.pricing.cached_input + if meta.pricing.cache_write is not None: + result["cache_creation_input_token_cost"] = meta.pricing.cache_write + + return result if result else None + + def compute_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + cache_hit_tokens: int = 0, + cache_miss_tokens: int = 0, + ) -> Optional[float]: + """ + Calculate total request cost. + + Returns None if pricing unavailable. + """ + pricing = self.get_pricing(model_id) + if not pricing: + return None + + in_rate = pricing.get("input_cost_per_token") + out_rate = pricing.get("output_cost_per_token") + + if in_rate is None or out_rate is None: + return None + + total = (input_tokens * in_rate) + (output_tokens * out_rate) + + cache_read_rate = pricing.get("cache_read_input_token_cost") + if cache_read_rate and cache_hit_tokens: + total += cache_hit_tokens * cache_read_rate + + cache_write_rate = pricing.get("cache_creation_input_token_cost") + if cache_write_rate and cache_miss_tokens: + total += cache_miss_tokens * cache_write_rate + + return total + + def enrich_models(self, model_ids: List[str]) -> List[Dict[str, Any]]: + """ + Attach metadata to a list of model IDs. + + Used by /v1/models endpoint. 
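+
+        Illustrative shapes: a catalogued ID is returned via as_api_response();
+        an unknown ID gets a minimal fallback entry such as
+        {"id": "some/model", "object": "model", "created": <ts>, "owned_by": "some"}.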
+ """ + enriched = [] + for mid in model_ids: + meta = self.lookup(mid) + if meta: + enriched.append(meta.as_api_response()) + else: + # Fallback minimal entry + enriched.append({ + "id": mid, + "object": "model", + "created": int(time.time()), + "owned_by": mid.split("/")[0] if "/" in mid else "unknown", + }) + return enriched + + def all_raw_models(self) -> Dict[str, Dict]: + """Return all raw source data (for debugging).""" + combined = {} + combined.update(self._openrouter_store) + combined.update(self._modelsdev_store) + return combined + + def diagnostics(self) -> Dict[str, Any]: + """Return service health/stats.""" + return { + "ready": self._ready.is_set(), + "last_refresh": self._last_refresh, + "openrouter_count": len(self._openrouter_store), + "modelsdev_count": len(self._modelsdev_store), + "cached_lookups": len(self._result_cache), + "index_entries": self._index.entry_count(), + "refresh_interval": self._refresh_interval, + } + + # ---------- Backward Compatibility Methods ---------- + + def get_model_info(self, model_id: str) -> Optional[ModelMetadata]: + """Alias for lookup() - backward compatibility.""" + return self.lookup(model_id) + + def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]: + """Alias for get_pricing() - backward compatibility.""" + return self.get_pricing(model_id) + + def calculate_cost( + self, + model_id: str, + prompt_tokens: int, + completion_tokens: int, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, + ) -> Optional[float]: + """Alias for compute_cost() - backward compatibility.""" + return self.compute_cost( + model_id, prompt_tokens, completion_tokens, + cache_read_tokens, cache_creation_tokens + ) + + def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]: + """Alias for enrich_models() - backward compatibility.""" + return self.enrich_models(model_ids) + + def get_all_source_models(self) -> Dict[str, Dict]: + """Alias for all_raw_models() - backward compatibility.""" + return self.all_raw_models() + + def get_stats(self) -> Dict[str, Any]: + """Alias for diagnostics() - backward compatibility.""" + return self.diagnostics() + + def wait_for_ready(self, timeout: float = 30.0): + """Sync wrapper for await_ready() - for compatibility.""" + return self.await_ready(timeout) + + +# ============================================================================ +# Backward Compatibility Layer +# ============================================================================ + +# Alias for backward compatibility +ModelInfo = ModelMetadata +ModelInfoService = ModelRegistry + +# Global singleton +_registry_instance: Optional[ModelRegistry] = None + + +def get_model_info_service() -> ModelRegistry: + """Get or create the global registry instance.""" + global _registry_instance + if _registry_instance is None: + _registry_instance = ModelRegistry() + return _registry_instance + + +async def init_model_info_service() -> ModelRegistry: + """Initialize and start the global registry.""" + registry = get_model_info_service() + await registry.start() + return registry + + +# Compatibility shim - map old method names to new +class _CompatibilityWrapper: + """Provides old API method names for gradual migration.""" + + def __init__(self, registry: ModelRegistry): + self._reg = registry + + def get_model_info(self, model_id: str) -> Optional[ModelMetadata]: + return self._reg.lookup(model_id) + + def get_cost_info(self, model_id: str) -> Optional[Dict[str, float]]: + return self._reg.get_pricing(model_id) + + def 
calculate_cost( + self, model_id: str, prompt_tokens: int, completion_tokens: int, + cache_read_tokens: int = 0, cache_creation_tokens: int = 0 + ) -> Optional[float]: + return self._reg.compute_cost( + model_id, prompt_tokens, completion_tokens, + cache_read_tokens, cache_creation_tokens + ) + + def enrich_model_list(self, model_ids: List[str]) -> List[Dict[str, Any]]: + return self._reg.enrich_models(model_ids) + + def get_all_source_models(self) -> Dict[str, Dict]: + return self._reg.all_raw_models() + + def get_stats(self) -> Dict[str, Any]: + return self._reg.diagnostics() + + async def start(self): + await self._reg.start() + + async def stop(self): + await self._reg.stop() + + async def wait_for_ready(self, timeout: float = 30.0) -> bool: + return await self._reg.await_ready(timeout) + + def is_ready(self) -> bool: + return self._reg.is_ready diff --git a/src/rotator_library/provider_factory.py b/src/rotator_library/provider_factory.py index f53eabd..f13d16a 100644 --- a/src/rotator_library/provider_factory.py +++ b/src/rotator_library/provider_factory.py @@ -3,11 +3,13 @@ from .providers.gemini_auth_base import GeminiAuthBase from .providers.qwen_auth_base import QwenAuthBase from .providers.iflow_auth_base import IFlowAuthBase +from .providers.antigravity_auth_base import AntigravityAuthBase PROVIDER_MAP = { "gemini_cli": GeminiAuthBase, "qwen_code": QwenAuthBase, "iflow": IFlowAuthBase, + "antigravity": AntigravityAuthBase, } def get_provider_auth_class(provider_name: str): diff --git a/src/rotator_library/providers/__init__.py b/src/rotator_library/providers/__init__.py index 3541d11..c6bee07 100644 --- a/src/rotator_library/providers/__init__.py +++ b/src/rotator_library/providers/__init__.py @@ -112,6 +112,8 @@ def _register_providers(): "chutes", "iflow", "qwen_code", + "gemini_cli", + "antigravity", ]: continue diff --git a/src/rotator_library/providers/antigravity_auth_base.py b/src/rotator_library/providers/antigravity_auth_base.py new file mode 100644 index 0000000..7240304 --- /dev/null +++ b/src/rotator_library/providers/antigravity_auth_base.py @@ -0,0 +1,24 @@ +# src/rotator_library/providers/antigravity_auth_base.py + +from .google_oauth_base import GoogleOAuthBase + +class AntigravityAuthBase(GoogleOAuthBase): + """ + Antigravity OAuth2 authentication implementation. + + Inherits all OAuth functionality from GoogleOAuthBase with Antigravity-specific configuration. + Uses Antigravity's OAuth credentials and includes additional scopes for cclog and experimentsandconfigs. 
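+
+    Resolution sketch (an assumption based on the PROVIDER_MAP entry added in
+    provider_factory.py): get_provider_auth_class("antigravity") is expected to
+    resolve to this class, mirroring the existing gemini_cli/qwen_code/iflow entries.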
+ """ + + CLIENT_ID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + CLIENT_SECRET = "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" + OAUTH_SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", # Antigravity-specific + "https://www.googleapis.com/auth/experimentsandconfigs", # Antigravity-specific + ] + ENV_PREFIX = "ANTIGRAVITY" + CALLBACK_PORT = 51121 + CALLBACK_PATH = "/oauthcallback" diff --git a/src/rotator_library/providers/antigravity_provider.py b/src/rotator_library/providers/antigravity_provider.py new file mode 100644 index 0000000..3f06b19 --- /dev/null +++ b/src/rotator_library/providers/antigravity_provider.py @@ -0,0 +1,2341 @@ +# src/rotator_library/providers/antigravity_provider_v2.py +""" +Antigravity Provider - Refactored Implementation + +A clean, well-structured provider for Google's Antigravity API, supporting: +- Gemini 2.5 (Pro/Flash) with thinkingBudget +- Gemini 3 (Pro/Image) with thinkingLevel +- Claude (Sonnet 4.5) via Antigravity proxy + +Key Features: +- Unified streaming/non-streaming handling +- Server-side thought signature caching +- Automatic base URL fallback +- Gemini 3 tool hallucination prevention +""" + +from __future__ import annotations + +import copy +import hashlib +import json +import logging +import os +import random +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from urllib.parse import urlparse + +import httpx +import litellm + +from .provider_interface import ProviderInterface +from .antigravity_auth_base import AntigravityAuthBase +from .provider_cache import ProviderCache +from ..model_definitions import ModelDefinitions + + +# ============================================================================= +# CONFIGURATION CONSTANTS +# ============================================================================= + +lib_logger = logging.getLogger('rotator_library') + +# Antigravity base URLs with fallback order +# Priority: daily (sandbox) → autopush (sandbox) → production +BASE_URLS = [ + "https://daily-cloudcode-pa.sandbox.googleapis.com/v1internal", + "https://autopush-cloudcode-pa.sandbox.googleapis.com/v1internal", + "https://cloudcode-pa.googleapis.com/v1internal", # Production fallback +] + +# Available models via Antigravity +AVAILABLE_MODELS = [ + #"gemini-2.5-pro", + #"gemini-2.5-flash", + #"gemini-2.5-flash-lite", + "gemini-3-pro-preview", # Internally mapped to -low/-high variant based on thinkingLevel + #"gemini-3-pro-image-preview", + #"gemini-2.5-computer-use-preview-10-2025", + "claude-sonnet-4-5", # Internally mapped to -thinking variant when reasoning_effort is provided +] + +# Default max output tokens (including thinking) - can be overridden per request +DEFAULT_MAX_OUTPUT_TOKENS = 32384 + +# Model alias mappings (internal ↔ public) +MODEL_ALIAS_MAP = { + "rev19-uic3-1p": "gemini-2.5-computer-use-preview-10-2025", + "gemini-3-pro-image": "gemini-3-pro-image-preview", + "gemini-3-pro-low": "gemini-3-pro-preview", + "gemini-3-pro-high": "gemini-3-pro-preview", +} +MODEL_ALIAS_REVERSE = {v: k for k, v in MODEL_ALIAS_MAP.items()} + +# Models to exclude from dynamic discovery +EXCLUDED_MODELS = {"chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-2.5-pro"} + +# Gemini finish reason mapping +FINISH_REASON_MAP = { + "STOP": 
"stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "RECITATION": "content_filter", + "OTHER": "stop", +} + +# Default safety settings - disable content filtering for all categories +# Per CLIProxyAPI: these are attached to prevent safety blocks during API calls +DEFAULT_SAFETY_SETTINGS = [ + {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, + {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, +] + +# Directory paths +_BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent +LOGS_DIR = _BASE_DIR / "logs" / "antigravity_logs" +CACHE_DIR = _BASE_DIR / "cache" / "antigravity" +GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json" +CLAUDE_THINKING_CACHE_FILE = CACHE_DIR / "claude_thinking.json" + +# Gemini 3 tool fix system instruction (prevents hallucination) +DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """ +You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data. +VIOLATION OF THESE RULES WILL CAUSE IMMEDIATE SYSTEM FAILURE. + +## ABSOLUTE RULES - NO EXCEPTIONS + +1. **SCHEMA IS LAW**: The JSON schema in each tool definition is the ONLY source of truth. + - Your pre-trained knowledge about tools like 'read_file', 'apply_diff', 'write_to_file', 'bash', etc. is INVALID here. + - Every tool has been REDEFINED with different parameters than what you learned during training. + +2. **PARAMETER NAMES ARE EXACT**: Use ONLY the parameter names from the schema. + - WRONG: 'suggested_answers', 'file_path', 'files_to_read', 'command_to_run' + - RIGHT: Check the 'properties' field in the schema for the exact names + - The schema's 'required' array tells you which parameters are mandatory + +3. **ARRAY PARAMETERS**: When a parameter has "type": "array", check the 'items' field: + - If items.type is "object", you MUST provide an array of objects with the EXACT properties listed + - If items.type is "string", you MUST provide an array of strings + - NEVER provide a single object when an array is expected + - NEVER provide an array when a single value is expected + +4. **NESTED OBJECTS**: When items.type is "object": + - Check items.properties for the EXACT field names required + - Check items.required for which nested fields are mandatory + - Include ALL required nested fields in EVERY array element + +5. **STRICT PARAMETERS HINT**: Tool descriptions contain "STRICT PARAMETERS: ..." which lists: + - Parameter name, type, and whether REQUIRED + - For arrays of objects: the nested structure in brackets like [field: type REQUIRED, ...] + - USE THIS as your quick reference, but the JSON schema is authoritative + +6. **BEFORE EVERY TOOL CALL**: + a. Read the tool's 'parametersJsonSchema' or 'parameters' field completely + b. Identify ALL required parameters + c. Verify your parameter names match EXACTLY (case-sensitive) + d. For arrays, verify you're providing the correct item structure + e. 
Do NOT add parameters that don't exist in the schema + +## COMMON FAILURE PATTERNS TO AVOID + +- Using 'path' when schema says 'filePath' (or vice versa) +- Using 'content' when schema says 'text' (or vice versa) +- Providing {"file": "..."} when schema wants [{"path": "...", "line_ranges": [...]}] +- Omitting required nested fields in array items +- Adding 'additionalProperties' that the schema doesn't define +- Guessing parameter names from similar tools you know from training + +## REMEMBER +Your training data about function calling is OUTDATED for this environment. +The tool names may look familiar, but the schemas are DIFFERENT. +When in doubt, RE-READ THE SCHEMA before making the call. + +""" + +# Claude tool fix system instruction (prevents hallucination) +DEFAULT_CLAUDE_SYSTEM_INSTRUCTION = """CRITICAL TOOL USAGE INSTRUCTIONS: +You are operating in a custom environment where tool definitions differ from your training data. +You MUST follow these rules strictly: + +1. DO NOT use your internal training data to guess tool parameters +2. ONLY use the exact parameter structure defined in the tool schema +3. Parameter names in schemas are EXACT - do not substitute with similar names from your training (e.g., use 'follow_up' not 'suggested_answers') +4. Array parameters have specific item types - check the schema's 'items' field for the exact structure +5. When you see "STRICT PARAMETERS" in a tool description, those type definitions override any assumptions +6. Tool use in agentic workflows is REQUIRED - you must call tools with the exact parameters specified in the schema + +If you are unsure about a tool's parameters, YOU MUST read the schema definition carefully. +""" + + +# ============================================================================= +# HELPER FUNCTIONS +# ============================================================================= + +def _env_bool(key: str, default: bool = False) -> bool: + """Get boolean from environment variable.""" + return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes") + + +def _env_int(key: str, default: int) -> int: + """Get integer from environment variable.""" + return int(os.getenv(key, str(default))) + + +def _generate_request_id() -> str: + """Generate Antigravity request ID: agent-{uuid}""" + return f"agent-{uuid.uuid4()}" + + +def _generate_session_id() -> str: + """Generate Antigravity session ID: -{random_number}""" + n = random.randint(1_000_000_000_000_000_000, 9_999_999_999_999_999_999) + return f"-{n}" + + +def _generate_project_id() -> str: + """Generate fake project ID: {adj}-{noun}-{random}""" + adjectives = ["useful", "bright", "swift", "calm", "bold"] + nouns = ["fuze", "wave", "spark", "flow", "core"] + return f"{random.choice(adjectives)}-{random.choice(nouns)}-{uuid.uuid4().hex[:5]}" + + +def _normalize_type_arrays(schema: Any) -> Any: + """ + Normalize type arrays in JSON Schema for Proto-based Antigravity API. + Converts `"type": ["string", "null"]` → `"type": "string"`. 
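+
+    Example (illustrative):
+        >>> _normalize_type_arrays({"type": ["string", "null"], "description": "optional name"})
+        {'type': 'string', 'description': 'optional name'}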
+ """ + if isinstance(schema, dict): + normalized = {} + for key, value in schema.items(): + if key == "type" and isinstance(value, list): + non_null = [t for t in value if t != "null"] + normalized[key] = non_null[0] if non_null else value[0] + else: + normalized[key] = _normalize_type_arrays(value) + return normalized + elif isinstance(schema, list): + return [_normalize_type_arrays(item) for item in schema] + return schema + + +def _recursively_parse_json_strings(obj: Any) -> Any: + """ + Recursively parse JSON strings in nested data structures. + + Antigravity sometimes returns tool arguments with JSON-stringified values: + {"files": "[{...}]"} instead of {"files": [{...}]}. + + Additionally handles: + - Malformed double-encoded JSON (extra trailing '}' or ']') + - Escaped string content (\n, \t, \", etc.) + """ + if isinstance(obj, dict): + return {k: _recursively_parse_json_strings(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_recursively_parse_json_strings(item) for item in obj] + elif isinstance(obj, str): + stripped = obj.strip() + + # Check if string contains common escape sequences that need unescaping + # This handles cases where diff content or other text has literal \n instead of newlines + if '\\n' in obj or '\\t' in obj or '\\"' in obj or '\\\\' in obj: + try: + # Use json.loads with quotes to properly unescape the string + # This converts \n -> newline, \t -> tab, \" -> quote, etc. + unescaped = json.loads(f'"{obj}"') + lib_logger.debug( + f"[Antigravity] Unescaped string content: " + f"{len(obj) - len(unescaped)} chars changed" + ) + return unescaped + except (json.JSONDecodeError, ValueError): + # If unescaping fails, continue with original processing + pass + + # Check if it looks like JSON (starts with { or [) + if stripped and stripped[0] in ('{', '['): + # Try standard parsing first + if (stripped.startswith('{') and stripped.endswith('}')) or \ + (stripped.startswith('[') and stripped.endswith(']')): + try: + parsed = json.loads(obj) + return _recursively_parse_json_strings(parsed) + except (json.JSONDecodeError, ValueError): + pass + + # Handle malformed JSON: array that doesn't end with ] + # e.g., '[{"path": "..."}]}' instead of '[{"path": "..."}]' + if stripped.startswith('[') and not stripped.endswith(']'): + try: + # Find the last ] and truncate there + last_bracket = stripped.rfind(']') + if last_bracket > 0: + cleaned = stripped[:last_bracket+1] + parsed = json.loads(cleaned) + lib_logger.warning( + f"[Antigravity] Auto-corrected malformed JSON string: " + f"truncated {len(stripped) - len(cleaned)} extra chars" + ) + return _recursively_parse_json_strings(parsed) + except (json.JSONDecodeError, ValueError): + pass + + # Handle malformed JSON: object that doesn't end with } + if stripped.startswith('{') and not stripped.endswith('}'): + try: + # Find the last } and truncate there + last_brace = stripped.rfind('}') + if last_brace > 0: + cleaned = stripped[:last_brace+1] + parsed = json.loads(cleaned) + lib_logger.warning( + f"[Antigravity] Auto-corrected malformed JSON string: " + f"truncated {len(stripped) - len(cleaned)} extra chars" + ) + return _recursively_parse_json_strings(parsed) + except (json.JSONDecodeError, ValueError): + pass + return obj + + +def _clean_claude_schema(schema: Any) -> Any: + """ + Recursively clean JSON Schema for Antigravity/Google's Proto-based API. + - Removes unsupported fields ($schema, additionalProperties, etc.) 
+ - Converts 'const' to 'enum' with single value (supported equivalent) + - Converts 'anyOf'/'oneOf' to the first option (Claude doesn't support these) + """ + if not isinstance(schema, dict): + return schema + + # Fields not supported by Antigravity/Google's Proto-based API + # Note: Claude via Antigravity rejects JSON Schema draft 2020-12 validation keywords + incompatible = { + '$schema', 'additionalProperties', 'minItems', 'maxItems', 'pattern', + 'minLength', 'maxLength', 'minimum', 'maximum', 'default', + 'exclusiveMinimum', 'exclusiveMaximum', 'multipleOf', 'format', + 'minProperties', 'maxProperties', 'uniqueItems', 'contentEncoding', + 'contentMediaType', 'contentSchema', 'deprecated', 'readOnly', 'writeOnly', + 'examples', '$id', '$ref', '$defs', 'definitions', 'title', + } + + # Handle 'anyOf' by taking the first option (Claude doesn't support anyOf) + if 'anyOf' in schema and isinstance(schema['anyOf'], list) and schema['anyOf']: + first_option = _clean_claude_schema(schema['anyOf'][0]) + if isinstance(first_option, dict): + return first_option + + # Handle 'oneOf' similarly + if 'oneOf' in schema and isinstance(schema['oneOf'], list) and schema['oneOf']: + first_option = _clean_claude_schema(schema['oneOf'][0]) + if isinstance(first_option, dict): + return first_option + + + cleaned = {} + + # Handle 'const' by converting to 'enum' with single value + if 'const' in schema: + const_value = schema['const'] + cleaned['enum'] = [const_value] + + for key, value in schema.items(): + if key in incompatible or key == 'const': + continue + if isinstance(value, dict): + cleaned[key] = _clean_claude_schema(value) + elif isinstance(value, list): + cleaned[key] = [_clean_claude_schema(item) if isinstance(item, dict) else item for item in value] + else: + cleaned[key] = value + + return cleaned + + +# ============================================================================= +# FILE LOGGER +# ============================================================================= + +class AntigravityFileLogger: + """Transaction file logger for debugging Antigravity requests/responses.""" + + __slots__ = ('enabled', 'log_dir') + + def __init__(self, model_name: str, enabled: bool = True): + self.enabled = enabled + self.log_dir: Optional[Path] = None + + if not enabled: + return + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + safe_model = model_name.replace('/', '_').replace(':', '_') + self.log_dir = LOGS_DIR / f"{timestamp}_{safe_model}_{uuid.uuid4()}" + + try: + self.log_dir.mkdir(parents=True, exist_ok=True) + except Exception as e: + lib_logger.error(f"Failed to create log directory: {e}") + self.enabled = False + + def log_request(self, payload: Dict[str, Any]) -> None: + """Log the request payload.""" + self._write_json("request_payload.json", payload) + + def log_response_chunk(self, chunk: str) -> None: + """Append a raw chunk to the response stream log.""" + self._append_text("response_stream.log", chunk) + + def log_error(self, error_message: str) -> None: + """Log an error message.""" + self._append_text("error.log", f"[{datetime.utcnow().isoformat()}] {error_message}") + + def log_final_response(self, response: Dict[str, Any]) -> None: + """Log the final response.""" + self._write_json("final_response.json", response) + + def _write_json(self, filename: str, data: Dict[str, Any]) -> None: + if not self.enabled or not self.log_dir: + return + try: + with open(self.log_dir / filename, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + except 
Exception as e: + lib_logger.error(f"Failed to write {filename}: {e}") + + def _append_text(self, filename: str, text: str) -> None: + if not self.enabled or not self.log_dir: + return + try: + with open(self.log_dir / filename, "a", encoding="utf-8") as f: + f.write(text + "\n") + except Exception as e: + lib_logger.error(f"Failed to append to {filename}: {e}") + + + + +# ============================================================================= +# MAIN PROVIDER CLASS +# ============================================================================= + +class AntigravityProvider(AntigravityAuthBase, ProviderInterface): + """ + Antigravity provider for Gemini and Claude models via Google's internal API. + + Supports: + - Gemini 2.5 (Pro/Flash) with thinkingBudget + - Gemini 3 (Pro/Image) with thinkingLevel + - Claude Sonnet 4.5 via Antigravity proxy + + Features: + - Unified streaming/non-streaming handling + - ThoughtSignature caching for multi-turn conversations + - Automatic base URL fallback + - Gemini 3 tool hallucination prevention + """ + + skip_cost_calculation = True + + def __init__(self): + super().__init__() + self.model_definitions = ModelDefinitions() + + # Base URL management + self._base_url_index = 0 + self._current_base_url = BASE_URLS[0] + + # Configuration from environment + memory_ttl = _env_int("ANTIGRAVITY_SIGNATURE_CACHE_TTL", 3600) + disk_ttl = _env_int("ANTIGRAVITY_SIGNATURE_DISK_TTL", 86400) + + # Initialize caches using shared ProviderCache + self._signature_cache = ProviderCache( + GEMINI3_SIGNATURE_CACHE_FILE, memory_ttl, disk_ttl, + env_prefix="ANTIGRAVITY_SIGNATURE" + ) + self._thinking_cache = ProviderCache( + CLAUDE_THINKING_CACHE_FILE, memory_ttl, disk_ttl, + env_prefix="ANTIGRAVITY_THINKING" + ) + + # Feature flags + self._preserve_signatures_in_client = _env_bool("ANTIGRAVITY_PRESERVE_THOUGHT_SIGNATURES", True) + self._enable_signature_cache = _env_bool("ANTIGRAVITY_ENABLE_SIGNATURE_CACHE", True) + self._enable_dynamic_models = _env_bool("ANTIGRAVITY_ENABLE_DYNAMIC_MODELS", False) + self._enable_gemini3_tool_fix = _env_bool("ANTIGRAVITY_GEMINI3_TOOL_FIX", True) + self._enable_claude_tool_fix = _env_bool("ANTIGRAVITY_CLAUDE_TOOL_FIX", True) + self._enable_thinking_sanitization = _env_bool("ANTIGRAVITY_CLAUDE_THINKING_SANITIZATION", True) + + # Gemini 3 tool fix configuration + self._gemini3_tool_prefix = os.getenv("ANTIGRAVITY_GEMINI3_TOOL_PREFIX", "gemini3_") + self._gemini3_description_prompt = os.getenv( + "ANTIGRAVITY_GEMINI3_DESCRIPTION_PROMPT", + "\n\n⚠️ STRICT PARAMETERS (use EXACTLY as shown): {params}. Do NOT use parameters from your training data - use ONLY these parameter names." + ) + self._gemini3_enforce_strict_schema = _env_bool("ANTIGRAVITY_GEMINI3_STRICT_SCHEMA", True) + self._gemini3_system_instruction = os.getenv( + "ANTIGRAVITY_GEMINI3_SYSTEM_INSTRUCTION", + DEFAULT_GEMINI3_SYSTEM_INSTRUCTION + ) + + # Claude tool fix configuration (separate from Gemini 3) + self._claude_description_prompt = os.getenv( + "ANTIGRAVITY_CLAUDE_DESCRIPTION_PROMPT", + "\n\nSTRICT PARAMETERS: {params}." 
+ ) + self._claude_system_instruction = os.getenv( + "ANTIGRAVITY_CLAUDE_SYSTEM_INSTRUCTION", + DEFAULT_CLAUDE_SYSTEM_INSTRUCTION + ) + + # Log configuration + self._log_config() + + def _log_config(self) -> None: + """Log provider configuration.""" + lib_logger.debug( + f"Antigravity config: signatures_in_client={self._preserve_signatures_in_client}, " + f"cache={self._enable_signature_cache}, dynamic_models={self._enable_dynamic_models}, " + f"gemini3_fix={self._enable_gemini3_tool_fix}, gemini3_strict_schema={self._gemini3_enforce_strict_schema}, " + f"claude_fix={self._enable_claude_tool_fix}, thinking_sanitization={self._enable_thinking_sanitization}" + ) + + # ========================================================================= + # MODEL UTILITIES + # ========================================================================= + + def _alias_to_internal(self, alias: str) -> str: + """Convert public alias to internal model name.""" + return MODEL_ALIAS_REVERSE.get(alias, alias) + + def _internal_to_alias(self, internal: str) -> str: + """Convert internal model name to public alias.""" + if internal in EXCLUDED_MODELS: + return "" + return MODEL_ALIAS_MAP.get(internal, internal) + + def _is_gemini_3(self, model: str) -> bool: + """Check if model is Gemini 3 (requires special handling).""" + internal = self._alias_to_internal(model) + return internal.startswith("gemini-3-") or model.startswith("gemini-3-") + + def _is_claude(self, model: str) -> bool: + """Check if model is Claude.""" + return "claude" in model.lower() + + def _strip_provider_prefix(self, model: str) -> str: + """Strip provider prefix from model name.""" + return model.split("/")[-1] if "/" in model else model + + # ========================================================================= + # BASE URL MANAGEMENT + # ========================================================================= + + def _get_base_url(self) -> str: + """Get current base URL.""" + return self._current_base_url + + def _try_next_base_url(self) -> bool: + """Switch to next base URL in fallback list. Returns True if successful.""" + if self._base_url_index < len(BASE_URLS) - 1: + self._base_url_index += 1 + self._current_base_url = BASE_URLS[self._base_url_index] + lib_logger.info(f"Switching to fallback URL: {self._current_base_url}") + return True + return False + + def _reset_base_url(self) -> None: + """Reset to primary base URL.""" + self._base_url_index = 0 + self._current_base_url = BASE_URLS[0] + + # ========================================================================= + # THINKING CACHE KEY GENERATION + # ========================================================================= + + def _generate_thinking_cache_key( + self, + text_content: str, + tool_calls: List[Dict] + ) -> Optional[str]: + """ + Generate stable cache key from response content for Claude thinking preservation. 
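+
+        For example (illustrative values): a response whose first tool call id is
+        "call_abc123" and whose leading text hashes to "9f86d081884c7d65" yields
+        the key "thinking_tool_abc123_text_9f86d081884c7d65".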
+ + Uses composite key: + - Tool call IDs (most stable) + - Text hash (for text-only responses) + """ + key_parts = [] + + if tool_calls: + first_id = tool_calls[0].get("id", "") + if first_id: + key_parts.append(f"tool_{first_id.replace('call_', '')}") + + if text_content: + text_hash = hashlib.md5(text_content[:200].encode()).hexdigest()[:16] + key_parts.append(f"text_{text_hash}") + + return "thinking_" + "_".join(key_parts) if key_parts else None + + # ========================================================================= + # THINKING MODE SANITIZATION + # ========================================================================= + + def _analyze_conversation_state( + self, + messages: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + Analyze conversation state to detect tool use loops and thinking mode issues. + + Returns: + { + "in_tool_loop": bool - True if we're in an incomplete tool use loop + "last_assistant_idx": int - Index of last assistant message + "last_assistant_has_thinking": bool - Whether last assistant msg has thinking + "last_assistant_has_tool_calls": bool - Whether last assistant msg has tool calls + "pending_tool_results": bool - Whether there are tool results after last assistant + "thinking_block_indices": List[int] - Indices of messages with thinking/reasoning + } + """ + state = { + "in_tool_loop": False, + "last_assistant_idx": -1, + "last_assistant_has_thinking": False, + "last_assistant_has_tool_calls": False, + "pending_tool_results": False, + "thinking_block_indices": [], + } + + # Find last assistant message and analyze the conversation + for i, msg in enumerate(messages): + role = msg.get("role") + + if role == "assistant": + state["last_assistant_idx"] = i + state["last_assistant_has_tool_calls"] = bool(msg.get("tool_calls")) + # Check for thinking/reasoning content + has_thinking = bool(msg.get("reasoning_content")) + # Also check for thinking in content array (some formats) + content = msg.get("content") + if isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "thinking": + has_thinking = True + break + state["last_assistant_has_thinking"] = has_thinking + if has_thinking: + state["thinking_block_indices"].append(i) + elif role == "tool": + # Tool result after an assistant message with tool calls = in tool loop + if state["last_assistant_has_tool_calls"]: + state["pending_tool_results"] = True + + # We're in a tool loop if: + # 1. Last assistant message had tool calls + # 2. There are tool results after it + # 3. There's no final text response yet (the conversation ends with tool results) + if state["pending_tool_results"] and messages: + last_msg = messages[-1] + if last_msg.get("role") == "tool": + state["in_tool_loop"] = True + + return state + + def _sanitize_thinking_for_claude( + self, + messages: List[Dict[str, Any]], + thinking_enabled: bool + ) -> Tuple[List[Dict[str, Any]], bool]: + """ + Sanitize thinking blocks in conversation history for Claude compatibility. + + Handles the following scenarios per Claude docs: + 1. If thinking is disabled, remove all thinking blocks from conversation + 2. If thinking is enabled: + a. In a tool use loop WITH thinking: preserve it (same mode continues) + b. In a tool use loop WITHOUT thinking: this is INVALID toggle - force disable + c. 
Not in tool loop: strip old thinking, new response adds thinking naturally + + Per Claude docs: + - "If thinking is enabled, the final assistant turn must start with a thinking block" + - "If thinking is disabled, the final assistant turn must not contain any thinking blocks" + - Tool use loops are part of a single assistant turn + - You CANNOT toggle thinking mid-turn + + The key insight: We only force-disable thinking when TOGGLING it ON mid-turn. + If thinking was already enabled (assistant has thinking), we preserve. + If thinking was disabled (assistant has no thinking), enabling it now is invalid. + + Returns: + Tuple of (sanitized_messages, force_disable_thinking) + - sanitized_messages: The cleaned message list + - force_disable_thinking: If True, thinking must be disabled for this request + """ + messages = copy.deepcopy(messages) + state = self._analyze_conversation_state(messages) + + lib_logger.debug( + f"[Thinking Sanitization] thinking_enabled={thinking_enabled}, " + f"in_tool_loop={state['in_tool_loop']}, " + f"last_assistant_has_thinking={state['last_assistant_has_thinking']}, " + f"last_assistant_has_tool_calls={state['last_assistant_has_tool_calls']}" + ) + + if not thinking_enabled: + # CASE 1: Thinking is disabled - strip ALL thinking blocks + return self._strip_all_thinking_blocks(messages), False + + # CASE 2: Thinking is enabled + if state["in_tool_loop"]: + # We're in a tool use loop (conversation ends with tool_result) + # Per Claude docs: entire assistant turn must operate in single thinking mode + + if state["last_assistant_has_thinking"]: + # Last assistant turn HAD thinking - this is valid! + # Thinking was enabled when tool was called, continue with thinking enabled. + # Only keep thinking for the current turn (last assistant + following tools) + lib_logger.debug( + "[Thinking Sanitization] Tool loop with existing thinking - preserving." + ) + return self._preserve_current_turn_thinking( + messages, state["last_assistant_idx"] + ), False + else: + # Last assistant turn DID NOT have thinking, but thinking is NOW enabled + # This is the INVALID case: toggling thinking ON mid-turn + # + # Per Claude docs, this causes: + # "Expected `thinking` or `redacted_thinking`, but found `tool_use`." + # + # SOLUTION: Inject a synthetic assistant message to CLOSE the tool loop. + # This allows Claude to start a fresh turn WITH thinking. + # + # The synthetic message summarizes the tool results, allowing the model + # to respond naturally with thinking enabled on what is now a "new" turn. + lib_logger.info( + "[Thinking Sanitization] Closing tool loop with synthetic response. " + "This allows thinking to be enabled on the new turn." + ) + return self._close_tool_loop_for_thinking(messages), False + else: + # Not in a tool loop - this is the simple case + # The conversation doesn't end with tool_result, so we're starting fresh. + # Strip thinking from old turns (API ignores them anyway). + # New response will include thinking naturally. + + if state["last_assistant_idx"] >= 0 and not state["last_assistant_has_thinking"]: + if state["last_assistant_has_tool_calls"]: + # Last assistant made tool calls but no thinking + # This could be from context compression, model switch, or + # the assistant responded after tool results (completing the turn) + lib_logger.debug( + "[Thinking Sanitization] Last assistant has completed tool_calls but no thinking. " + "This is likely from context compression or completed tool loop. " + "New response will include thinking." 
+ ) + + # Strip thinking from old turns, let new response add thinking naturally + return self._strip_old_turn_thinking(messages, state["last_assistant_idx"]), False + + def _strip_all_thinking_blocks( + self, + messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Remove all thinking/reasoning content from messages.""" + for msg in messages: + if msg.get("role") == "assistant": + # Remove reasoning_content field + msg.pop("reasoning_content", None) + + # Remove thinking blocks from content array + content = msg.get("content") + if isinstance(content, list): + filtered = [ + item for item in content + if not (isinstance(item, dict) and item.get("type") == "thinking") + ] + # If filtering leaves empty list, we need to preserve message structure + # to maintain user/assistant alternation. Use empty string as placeholder + # (will result in empty "text" part which is valid). + if not filtered: + # Only if there are no tool_calls either - otherwise message is valid + if not msg.get("tool_calls"): + msg["content"] = "" + else: + msg["content"] = None # tool_calls exist, content not needed + else: + msg["content"] = filtered + return messages + + def _strip_old_turn_thinking( + self, + messages: List[Dict[str, Any]], + last_assistant_idx: int + ) -> List[Dict[str, Any]]: + """ + Strip thinking from old turns but preserve for the last assistant turn. + + Per Claude docs: "thinking blocks from previous turns are removed from context" + This mimics the API behavior and prevents issues. + """ + for i, msg in enumerate(messages): + if msg.get("role") == "assistant" and i < last_assistant_idx: + # Old turn - strip thinking + msg.pop("reasoning_content", None) + content = msg.get("content") + if isinstance(content, list): + filtered = [ + item for item in content + if not (isinstance(item, dict) and item.get("type") == "thinking") + ] + # Preserve message structure with empty string if needed + if not filtered: + msg["content"] = "" if not msg.get("tool_calls") else None + else: + msg["content"] = filtered + return messages + + def _preserve_current_turn_thinking( + self, + messages: List[Dict[str, Any]], + last_assistant_idx: int + ) -> List[Dict[str, Any]]: + """ + Preserve thinking only for the current (last) assistant turn. + Strip from all previous turns. + """ + # Same as strip_old_turn_thinking - we keep the last turn intact + return self._strip_old_turn_thinking(messages, last_assistant_idx) + + def _close_tool_loop_for_thinking( + self, + messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Close an incomplete tool loop by injecting a synthetic assistant response. + + This is used when: + - We're in a tool loop (conversation ends with tool_result) + - The tool call was made WITHOUT thinking (e.g., by Gemini or non-thinking Claude) + - We NOW want to enable thinking + + By injecting a synthetic response that "closes" the previous turn, + Claude can start a fresh turn with thinking enabled. + + The synthetic message is minimal and factual - it just acknowledges + the tool results were received, allowing the model to process them + with thinking on the new turn. 
+ """ + # Strip any old thinking first + messages = self._strip_all_thinking_blocks(messages) + + # Collect tool results from the end of the conversation + tool_results = [] + for msg in reversed(messages): + if msg.get("role") == "tool": + tool_results.append(msg) + elif msg.get("role") == "assistant": + break # Stop at the assistant that made the tool calls + + tool_results.reverse() # Put back in order + + # Safety check: if no tool results found, this shouldn't have been called + # But handle gracefully with a generic message + if not tool_results: + lib_logger.warning( + "[Thinking Sanitization] _close_tool_loop_for_thinking called but no tool results found. " + "This may indicate malformed conversation history." + ) + synthetic_content = "[Processing previous context.]" + elif len(tool_results) == 1: + synthetic_content = "[Tool execution completed. Processing results.]" + else: + synthetic_content = f"[{len(tool_results)} tool executions completed. Processing results.]" + + # Inject the synthetic assistant message to close the loop + synthetic_msg = { + "role": "assistant", + "content": synthetic_content + } + messages.append(synthetic_msg) + + lib_logger.debug( + f"[Thinking Sanitization] Injected synthetic closure: '{synthetic_content}'" + ) + + return messages + + # ========================================================================= + # REASONING CONFIGURATION + # ========================================================================= + + def _get_thinking_config( + self, + reasoning_effort: Optional[str], + model: str, + custom_budget: bool = False + ) -> Optional[Dict[str, Any]]: + """ + Map reasoning_effort to thinking configuration. + + - Gemini 2.5 & Claude: thinkingBudget (integer tokens) + - Gemini 3: thinkingLevel (string: "low"/"high") + """ + internal = self._alias_to_internal(model) + is_gemini_25 = "gemini-2.5" in model + is_gemini_3 = internal.startswith("gemini-3-") + is_claude = self._is_claude(model) + + if not (is_gemini_25 or is_gemini_3 or is_claude): + return None + + # Gemini 3: String-based thinkingLevel + if is_gemini_3: + if reasoning_effort == "low": + return {"thinkingLevel": "low", "include_thoughts": True} + return {"thinkingLevel": "high", "include_thoughts": True} + + # Gemini 2.5 & Claude: Integer thinkingBudget + if not reasoning_effort: + return {"thinkingBudget": -1, "include_thoughts": True} # Auto + + if reasoning_effort == "disable": + return {"thinkingBudget": 0, "include_thoughts": False} + + # Model-specific budgets + if "gemini-2.5-pro" in model or is_claude: + budgets = {"low": 8192, "medium": 16384, "high": 32768} + elif "gemini-2.5-flash" in model: + budgets = {"low": 6144, "medium": 12288, "high": 24576} + else: + budgets = {"low": 1024, "medium": 2048, "high": 4096} + + budget = budgets.get(reasoning_effort, -1) + if not custom_budget: + budget = budget // 4 # Default to 25% of max output tokens + + return {"thinkingBudget": budget, "include_thoughts": True} + + # ========================================================================= + # MESSAGE TRANSFORMATION (OpenAI → Gemini) + # ========================================================================= + + def _transform_messages( + self, + messages: List[Dict[str, Any]], + model: str + ) -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Transform OpenAI messages to Gemini CLI format. 
+ + Handles: + - System instruction extraction + - Multi-part content (text, images) + - Tool calls and responses + - Claude thinking injection from cache + - Gemini 3 thoughtSignature preservation + """ + messages = copy.deepcopy(messages) + system_instruction = None + gemini_contents = [] + + # Extract system prompt + if messages and messages[0].get('role') == 'system': + system_content = messages.pop(0).get('content', '') + if system_content: + system_parts = self._parse_content_parts(system_content, _strip_cache_control=True) + if system_parts: + system_instruction = {"role": "user", "parts": system_parts} + + # Build tool_call_id → name mapping + tool_id_to_name = {} + for msg in messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + if tc.get("type") == "function": + tc_id = tc["id"] + tc_name = tc["function"]["name"] + tool_id_to_name[tc_id] = tc_name + #lib_logger.debug(f"[ID Mapping] Registered tool_call: id={tc_id}, name={tc_name}") + + # Convert each message, consolidating consecutive tool responses + # Per Gemini docs: parallel function responses must be in a single user message + pending_tool_parts = [] + + for msg in messages: + role = msg.get("role") + content = msg.get("content") + parts = [] + + # Flush pending tool parts before non-tool message + if pending_tool_parts and role != "tool": + gemini_contents.append({"role": "user", "parts": pending_tool_parts}) + pending_tool_parts = [] + + if role == "user": + parts = self._transform_user_message(content) + elif role == "assistant": + parts = self._transform_assistant_message(msg, model, tool_id_to_name) + elif role == "tool": + tool_parts = self._transform_tool_message(msg, model, tool_id_to_name) + # Accumulate tool responses instead of adding individually + pending_tool_parts.extend(tool_parts) + continue + + if parts: + gemini_role = "model" if role == "assistant" else "user" + gemini_contents.append({"role": gemini_role, "parts": parts}) + + # Flush any remaining tool parts + if pending_tool_parts: + gemini_contents.append({"role": "user", "parts": pending_tool_parts}) + + return system_instruction, gemini_contents + + def _parse_content_parts( + self, + content: Any, + _strip_cache_control: bool = False + ) -> List[Dict[str, Any]]: + """Parse content into Gemini parts format.""" + parts = [] + + if isinstance(content, str): + if content: + parts.append({"text": content}) + elif isinstance(content, list): + for item in content: + if item.get("type") == "text": + text = item.get("text", "") + if text: + parts.append({"text": text}) + elif item.get("type") == "image_url": + image_part = self._parse_image_url(item.get("image_url", {})) + if image_part: + parts.append(image_part) + + return parts + + def _parse_image_url(self, image_url: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Parse image URL into Gemini inlineData format.""" + url = image_url.get("url", "") + if not url.startswith("data:"): + return None + + try: + header, data = url.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] + return {"inlineData": {"mimeType": mime_type, "data": data}} + except Exception as e: + lib_logger.warning(f"Failed to parse image URL: {e}") + return None + + def _transform_user_message(self, content: Any) -> List[Dict[str, Any]]: + """Transform user message content to Gemini parts.""" + return self._parse_content_parts(content) + + def _transform_assistant_message( + self, + msg: Dict[str, Any], + model: str, + _tool_id_to_name: Dict[str, str] + ) -> List[Dict[str, 
Any]]: + """Transform assistant message including tool calls and thinking injection.""" + parts = [] + content = msg.get("content") + tool_calls = msg.get("tool_calls", []) + reasoning_content = msg.get("reasoning_content") + + # Handle reasoning_content if present (from original Claude response with thinking) + if reasoning_content and self._is_claude(model): + # Add thinking part with cached signature + thinking_part = { + "text": reasoning_content, + "thought": True, + } + # Try to get signature from cache + cache_key = self._generate_thinking_cache_key( + content if isinstance(content, str) else "", + tool_calls + ) + cached_sig = None + if cache_key: + cached_json = self._thinking_cache.retrieve(cache_key) + if cached_json: + try: + cached_data = json.loads(cached_json) + cached_sig = cached_data.get("thought_signature", "") + except json.JSONDecodeError: + pass + + if cached_sig: + thinking_part["thoughtSignature"] = cached_sig + parts.append(thinking_part) + lib_logger.debug(f"Added reasoning_content with cached signature ({len(reasoning_content)} chars)") + else: + # No cached signature - skip the thinking block + # This can happen if context was compressed and signature was lost + lib_logger.warning( + f"Skipping reasoning_content - no valid signature found. " + f"This may cause issues if thinking is enabled." + ) + elif self._is_claude(model) and self._enable_signature_cache and not reasoning_content: + # Fallback: Try to inject cached thinking for Claude (original behavior) + thinking_parts = self._get_cached_thinking(content, tool_calls) + parts.extend(thinking_parts) + + # Add regular content + if isinstance(content, str) and content: + parts.append({"text": content}) + + # Add tool calls + # Track if we've seen the first function call in this message + # Per Gemini docs: Only the FIRST parallel function call gets a signature + first_func_in_msg = True + for tc in tool_calls: + if tc.get("type") != "function": + continue + + try: + args = json.loads(tc["function"]["arguments"]) + except (json.JSONDecodeError, TypeError): + args = {} + + tool_id = tc.get("id", "") + func_name = tc["function"]["name"] + + #lib_logger.debug( + # f"[ID Transform] Converting assistant tool_call to functionCall: " + # f"id={tool_id}, name={func_name}" + #) + + # Add prefix for Gemini 3 + if self._is_gemini_3(model) and self._enable_gemini3_tool_fix: + func_name = f"{self._gemini3_tool_prefix}{func_name}" + + func_part = { + "functionCall": { + "name": func_name, + "args": args, + "id": tool_id + } + } + + # Add thoughtSignature for Gemini 3 + # Per Gemini docs: Only the FIRST parallel function call gets a signature. + # Subsequent parallel calls should NOT have a thoughtSignature field. 
+ if self._is_gemini_3(model): + sig = tc.get("thought_signature") + if not sig and tool_id and self._enable_signature_cache: + sig = self._signature_cache.retrieve(tool_id) + + if sig: + func_part["thoughtSignature"] = sig + elif first_func_in_msg: + # Only add bypass to the first function call if no sig available + func_part["thoughtSignature"] = "skip_thought_signature_validator" + lib_logger.warning(f"Missing thoughtSignature for first func call {tool_id}, using bypass") + # Subsequent parallel calls: no signature field at all + + first_func_in_msg = False + + parts.append(func_part) + + # Safety: ensure we return at least one part to maintain role alternation + # This handles edge cases like assistant messages that had only thinking content + # which got stripped, leaving the message otherwise empty + if not parts: + # Use a minimal text part - can happen after thinking is stripped + parts.append({"text": ""}) + lib_logger.debug( + "[Transform] Added empty text part to maintain role alternation" + ) + + return parts + + def _get_cached_thinking( + self, + content: Any, + tool_calls: List[Dict] + ) -> List[Dict[str, Any]]: + """Retrieve and format cached thinking content for Claude.""" + parts = [] + msg_text = content if isinstance(content, str) else "" + cache_key = self._generate_thinking_cache_key(msg_text, tool_calls) + + if not cache_key: + return parts + + cached_json = self._thinking_cache.retrieve(cache_key) + if not cached_json: + return parts + + try: + thinking_data = json.loads(cached_json) + thinking_text = thinking_data.get("thinking_text", "") + sig = thinking_data.get("thought_signature", "") + + if thinking_text: + thinking_part = { + "text": thinking_text, + "thought": True, + "thoughtSignature": sig or "skip_thought_signature_validator" + } + parts.append(thinking_part) + lib_logger.debug(f"Injected {len(thinking_text)} chars of thinking") + except json.JSONDecodeError: + lib_logger.warning(f"Failed to parse cached thinking: {cache_key}") + + return parts + + def _transform_tool_message( + self, + msg: Dict[str, Any], + model: str, + tool_id_to_name: Dict[str, str] + ) -> List[Dict[str, Any]]: + """Transform tool response message.""" + tool_id = msg.get("tool_call_id", "") + func_name = tool_id_to_name.get(tool_id, "unknown_function") + content = msg.get("content", "{}") + + # Log ID lookup + if tool_id not in tool_id_to_name: + lib_logger.warning( + f"[ID Mismatch] Tool response has ID '{tool_id}' which was not found in tool_id_to_name map. " + f"Available IDs: {list(tool_id_to_name.keys())}" + ) + #else: + #lib_logger.debug(f"[ID Mapping] Tool response matched: id={tool_id}, name={func_name}") + + # Add prefix for Gemini 3 + if self._is_gemini_3(model) and self._enable_gemini3_tool_fix: + func_name = f"{self._gemini3_tool_prefix}{func_name}" + + try: + parsed_content = json.loads(content) + except (json.JSONDecodeError, TypeError): + parsed_content = content + + return [{ + "functionResponse": { + "name": func_name, + "response": {"result": parsed_content}, + "id": tool_id + } + }] + + # ========================================================================= + # TOOL RESPONSE GROUPING + # ========================================================================= + + def _fix_tool_response_grouping( + self, + contents: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Group function calls with their responses for Antigravity compatibility. 
+ + Converts linear format (call, response, call, response) + to grouped format (model with calls, user with all responses). + + IMPORTANT: Preserves ID-based pairing to prevent mismatches. + """ + new_contents = [] + pending_groups = [] # List of {"ids": [id1, id2, ...], "call_indices": [...]} + collected_responses = {} # Dict mapping ID -> response_part + + for content in contents: + role = content.get("role") + parts = content.get("parts", []) + + response_parts = [p for p in parts if "functionResponse" in p] + + if response_parts: + # Collect responses by ID (ignore duplicates - keep first occurrence) + for resp in response_parts: + resp_id = resp.get("functionResponse", {}).get("id", "") + if resp_id: + if resp_id in collected_responses: + lib_logger.warning( + f"[Grouping] Duplicate response ID detected: {resp_id}. " + f"Ignoring duplicate - this may indicate malformed conversation history." + ) + continue + #lib_logger.debug(f"[Grouping] Collected response for ID: {resp_id}") + collected_responses[resp_id] = resp + + # Try to satisfy pending groups (newest first) + for i in range(len(pending_groups) - 1, -1, -1): + group = pending_groups[i] + group_ids = group["ids"] + + # Check if we have ALL responses for this group + if all(gid in collected_responses for gid in group_ids): + # Extract responses in the same order as the function calls + group_responses = [collected_responses.pop(gid) for gid in group_ids] + new_contents.append({"parts": group_responses, "role": "user"}) + #lib_logger.debug( + # f"[Grouping] Satisfied group with {len(group_responses)} responses: " + # f"ids={group_ids}" + #) + pending_groups.pop(i) + break + continue + + if role == "model": + func_calls = [p for p in parts if "functionCall" in p] + new_contents.append(content) + if func_calls: + call_ids = [fc.get("functionCall", {}).get("id", "") for fc in func_calls] + call_ids = [cid for cid in call_ids if cid] # Filter empty IDs + if call_ids: + lib_logger.debug(f"[Grouping] Created pending group expecting {len(call_ids)} responses: ids={call_ids}") + pending_groups.append({"ids": call_ids, "call_indices": list(range(len(func_calls)))}) + else: + new_contents.append(content) + + # Handle remaining groups (shouldn't happen in well-formed conversations) + for group in pending_groups: + group_ids = group["ids"] + available_ids = [gid for gid in group_ids if gid in collected_responses] + if available_ids: + group_responses = [collected_responses.pop(gid) for gid in available_ids] + new_contents.append({"parts": group_responses, "role": "user"}) + lib_logger.warning( + f"[Grouping] Partial group satisfaction: expected {len(group_ids)}, " + f"got {len(available_ids)} responses" + ) + + # Warn about unmatched responses + if collected_responses: + lib_logger.warning( + f"[Grouping] {len(collected_responses)} unmatched responses remaining: " + f"ids={list(collected_responses.keys())}" + ) + + return new_contents + + # ========================================================================= + # GEMINI 3 TOOL TRANSFORMATIONS + # ========================================================================= + + def _apply_gemini3_namespace( + self, + tools: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Add namespace prefix to tool names for Gemini 3.""" + if not tools: + return tools + + modified = copy.deepcopy(tools) + for tool in modified: + for func_decl in tool.get("functionDeclarations", []): + name = func_decl.get("name", "") + if name: + func_decl["name"] = f"{self._gemini3_tool_prefix}{name}" + + return 
modified + + def _enforce_strict_schema(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Enforce strict JSON schema for Gemini 3 to prevent hallucinated parameters. + + Adds 'additionalProperties: false' recursively to all object schemas, + which tells the model it CANNOT add properties not in the schema. + """ + if not tools: + return tools + + def enforce_strict(schema: Any) -> Any: + if not isinstance(schema, dict): + return schema + + result = {} + for key, value in schema.items(): + if isinstance(value, dict): + result[key] = enforce_strict(value) + elif isinstance(value, list): + result[key] = [enforce_strict(item) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + # Add additionalProperties: false to object schemas + if result.get("type") == "object" and "properties" in result: + result["additionalProperties"] = False + + return result + + modified = copy.deepcopy(tools) + for tool in modified: + for func_decl in tool.get("functionDeclarations", []): + if "parametersJsonSchema" in func_decl: + func_decl["parametersJsonSchema"] = enforce_strict(func_decl["parametersJsonSchema"]) + + return modified + + def _inject_signature_into_descriptions( + self, + tools: List[Dict[str, Any]], + description_prompt: Optional[str] = None + ) -> List[Dict[str, Any]]: + """Inject parameter signatures into tool descriptions for Gemini 3 & Claude.""" + if not tools: + return tools + + # Use provided prompt or default to Gemini 3 prompt + prompt_template = description_prompt or self._gemini3_description_prompt + + modified = copy.deepcopy(tools) + for tool in modified: + for func_decl in tool.get("functionDeclarations", []): + schema = func_decl.get("parametersJsonSchema", {}) + if not schema: + continue + + required = schema.get("required", []) + properties = schema.get("properties", {}) + + if not properties: + continue + + param_list = [] + for prop_name, prop_data in properties.items(): + if not isinstance(prop_data, dict): + continue + + type_hint = self._format_type_hint(prop_data) + is_required = prop_name in required + param_list.append( + f"{prop_name} ({type_hint}{', REQUIRED' if is_required else ''})" + ) + + if param_list: + sig_str = prompt_template.replace( + "{params}", ", ".join(param_list) + ) + func_decl["description"] = func_decl.get("description", "") + sig_str + + return modified + + def _format_type_hint(self, prop_data: Dict[str, Any], depth: int = 0) -> str: + """Format a detailed type hint for a property schema.""" + type_hint = prop_data.get("type", "unknown") + + # Handle enum values - show allowed options + if "enum" in prop_data: + enum_vals = prop_data["enum"] + if len(enum_vals) <= 5: + return f"string ENUM[{', '.join(repr(v) for v in enum_vals)}]" + return f"string ENUM[{len(enum_vals)} options]" + + # Handle const values + if "const" in prop_data: + return f"string CONST={repr(prop_data['const'])}" + + if type_hint == "array": + items = prop_data.get("items", {}) + if isinstance(items, dict): + item_type = items.get("type", "unknown") + if item_type == "object": + nested_props = items.get("properties", {}) + nested_req = items.get("required", []) + if nested_props: + nested_list = [] + for n, d in nested_props.items(): + if isinstance(d, dict): + # Recursively format nested types (limit depth) + if depth < 1: + t = self._format_type_hint(d, depth + 1) + else: + t = d.get("type", "unknown") + req = " REQUIRED" if n in nested_req else "" + nested_list.append(f"{n}: {t}{req}") + return f"ARRAY_OF_OBJECTS[{', 
'.join(nested_list)}]" + return "ARRAY_OF_OBJECTS" + return f"ARRAY_OF_{item_type.upper()}" + return "ARRAY" + + if type_hint == "object": + nested_props = prop_data.get("properties", {}) + nested_req = prop_data.get("required", []) + if nested_props and depth < 1: + nested_list = [] + for n, d in nested_props.items(): + if isinstance(d, dict): + t = d.get("type", "unknown") + req = " REQUIRED" if n in nested_req else "" + nested_list.append(f"{n}: {t}{req}") + return f"object{{{', '.join(nested_list)}}}" + + return type_hint + + def _strip_gemini3_prefix(self, name: str) -> str: + """Strip the Gemini 3 namespace prefix from a tool name.""" + if name and name.startswith(self._gemini3_tool_prefix): + return name[len(self._gemini3_tool_prefix):] + return name + + def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]], model: str = "") -> Optional[Dict[str, Any]]: + """ + Translates OpenAI's `tool_choice` to Gemini's `toolConfig`. + Handles Gemini 3 namespace prefixes for specific tool selection. + """ + if not tool_choice: + return None + + config = {} + mode = "AUTO" # Default to auto + is_gemini_3 = self._is_gemini_3(model) + + if isinstance(tool_choice, str): + if tool_choice == "auto": + mode = "AUTO" + elif tool_choice == "none": + mode = "NONE" + elif tool_choice == "required": + mode = "ANY" + elif isinstance(tool_choice, dict) and tool_choice.get("type") == "function": + function_name = tool_choice.get("function", {}).get("name") + if function_name: + # Add Gemini 3 prefix if needed + if is_gemini_3 and self._enable_gemini3_tool_fix: + function_name = f"{self._gemini3_tool_prefix}{function_name}" + + mode = "ANY" # Force a call, but only to this function + config["functionCallingConfig"] = { + "mode": mode, + "allowedFunctionNames": [function_name] + } + return config + + config["functionCallingConfig"] = {"mode": mode} + return config + + # ========================================================================= + # REQUEST TRANSFORMATION + # ========================================================================= + + def _build_tools_payload( + self, + tools: Optional[List[Dict[str, Any]]], + _model: str + ) -> Optional[List[Dict[str, Any]]]: + """Build Gemini-format tools from OpenAI tools.""" + if not tools: + return None + + gemini_tools = [] + for tool in tools: + if tool.get("type") != "function": + continue + + func = tool.get("function", {}) + params = func.get("parameters") + + func_decl = { + "name": func.get("name", ""), + "description": func.get("description", "") + } + + if params and isinstance(params, dict): + schema = dict(params) + schema.pop("$schema", None) + schema.pop("strict", None) + schema = _normalize_type_arrays(schema) + func_decl["parametersJsonSchema"] = schema + else: + func_decl["parametersJsonSchema"] = {"type": "object", "properties": {}} + + gemini_tools.append({"functionDeclarations": [func_decl]}) + + return gemini_tools or None + + def _transform_to_antigravity_format( + self, + gemini_payload: Dict[str, Any], + model: str, + max_tokens: Optional[int] = None, + reasoning_effort: Optional[str] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + ) -> Dict[str, Any]: + """ + Transform Gemini CLI payload to complete Antigravity format. 
+ + Args: + gemini_payload: Request in Gemini CLI format + model: Model name (public alias) + max_tokens: Max output tokens (including thinking) + reasoning_effort: Reasoning effort level (determines -thinking variant for Claude) + """ + internal_model = self._alias_to_internal(model) + + # Map base Claude model to -thinking variant when reasoning_effort is provided + if self._is_claude(internal_model) and reasoning_effort: + if internal_model == "claude-sonnet-4-5" and not internal_model.endswith("-thinking"): + internal_model = "claude-sonnet-4-5-thinking" + + # Map gemini-3-pro-preview to -low/-high variant based on thinking config + if model == "gemini-3-pro-preview" or internal_model == "gemini-3-pro-preview": + # Check thinking config to determine variant + thinking_config = gemini_payload.get("generationConfig", {}).get("thinkingConfig", {}) + thinking_level = thinking_config.get("thinkingLevel", "high") + if thinking_level == "low": + internal_model = "gemini-3-pro-low" + else: + internal_model = "gemini-3-pro-high" + + # Wrap in Antigravity envelope + antigravity_payload = { + "project": _generate_project_id(), + "userAgent": "antigravity", + "requestId": _generate_request_id(), + "model": internal_model, + "request": copy.deepcopy(gemini_payload) + } + + # Add session ID + antigravity_payload["request"]["sessionId"] = _generate_session_id() + + # Add default safety settings to prevent content filtering + # Only add if not already present in the payload + if "safetySettings" not in antigravity_payload["request"]: + antigravity_payload["request"]["safetySettings"] = copy.deepcopy(DEFAULT_SAFETY_SETTINGS) + + # Handle max_tokens - only apply to Claude, or if explicitly set for others + gen_config = antigravity_payload["request"].get("generationConfig", {}) + is_claude = self._is_claude(model) + + if max_tokens is not None: + # Explicitly set in request - apply to all models + gen_config["maxOutputTokens"] = max_tokens + elif is_claude: + # Claude model without explicit max_tokens - use default + gen_config["maxOutputTokens"] = DEFAULT_MAX_OUTPUT_TOKENS + # For non-Claude models without explicit max_tokens, don't set it + + antigravity_payload["request"]["generationConfig"] = gen_config + + # Set toolConfig based on tool_choice parameter + tool_config_result = self._translate_tool_choice(tool_choice, model) + if tool_config_result: + antigravity_payload["request"]["toolConfig"] = tool_config_result + else: + # Default to AUTO if no tool_choice specified + tool_config = antigravity_payload["request"].setdefault("toolConfig", {}) + func_config = tool_config.setdefault("functionCallingConfig", {}) + func_config["mode"] = "AUTO" + + # Handle Gemini 3 thinking logic + if not internal_model.startswith("gemini-3-"): + thinking_config = gen_config.get("thinkingConfig", {}) + if "thinkingLevel" in thinking_config: + del thinking_config["thinkingLevel"] + thinking_config["thinkingBudget"] = -1 + + # Ensure first function call in each model message has a thoughtSignature for Gemini 3 + # Per Gemini docs: Only the FIRST parallel function call gets a signature + if internal_model.startswith("gemini-3-"): + for content in antigravity_payload["request"].get("contents", []): + if content.get("role") == "model": + first_func_seen = False + for part in content.get("parts", []): + if "functionCall" in part: + if not first_func_seen: + # First function call in this message - needs a signature + if "thoughtSignature" not in part: + part["thoughtSignature"] = "skip_thought_signature_validator" + 
first_func_seen = True + # Subsequent parallel calls: leave as-is (no signature) + + # Claude-specific tool schema transformation + if internal_model.startswith("claude-sonnet-"): + self._apply_claude_tool_transform(antigravity_payload) + + return antigravity_payload + + def _apply_claude_tool_transform(self, payload: Dict[str, Any]) -> None: + """Apply Claude-specific tool schema transformations.""" + tools = payload["request"].get("tools", []) + for tool in tools: + for func_decl in tool.get("functionDeclarations", []): + if "parametersJsonSchema" in func_decl: + params = func_decl["parametersJsonSchema"] + params = _clean_claude_schema(params) if isinstance(params, dict) else params + func_decl["parameters"] = params + del func_decl["parametersJsonSchema"] + + # ========================================================================= + # RESPONSE TRANSFORMATION + # ========================================================================= + + def _unwrap_response(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Extract Gemini response from Antigravity envelope.""" + return response.get("response", response) + + def _gemini_to_openai_chunk( + self, + chunk: Dict[str, Any], + model: str, + accumulator: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Convert Gemini response chunk to OpenAI streaming format. + + Args: + chunk: Gemini API response chunk + model: Model name + accumulator: Optional dict to accumulate data for post-processing + """ + candidates = chunk.get("candidates", []) + if not candidates: + return {} + + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + + text_content = "" + reasoning_content = "" + tool_calls = [] + # Use accumulator's tool_idx if available, otherwise use local counter + tool_idx = accumulator.get("tool_idx", 0) if accumulator else 0 + + for part in content_parts: + has_func = "functionCall" in part + has_text = "text" in part + has_sig = bool(part.get("thoughtSignature")) + is_thought = part.get("thought") is True or str(part.get("thought")).lower() == 'true' + + # Accumulate signature for Claude caching + if has_sig and is_thought and accumulator is not None: + accumulator["thought_signature"] = part["thoughtSignature"] + + # Skip standalone signature parts + if has_sig and not has_func and (not has_text or not part.get("text")): + continue + + if has_text: + text = part["text"] + if is_thought: + reasoning_content += text + if accumulator is not None: + accumulator["reasoning_content"] += text + else: + text_content += text + if accumulator is not None: + accumulator["text_content"] += text + + if has_func: + tool_call = self._extract_tool_call(part, model, tool_idx, accumulator) + + # Store signature for each tool call (needed for parallel tool calls) + if has_sig: + self._handle_tool_signature(tool_call, part["thoughtSignature"]) + + tool_calls.append(tool_call) + tool_idx += 1 + + # Build delta + delta = {} + if text_content: + delta["content"] = text_content + if reasoning_content: + delta["reasoning_content"] = reasoning_content + if tool_calls: + delta["tool_calls"] = tool_calls + delta["role"] = "assistant" + # Update tool_idx for next chunk + if accumulator is not None: + accumulator["tool_idx"] = tool_idx + elif text_content or reasoning_content: + delta["role"] = "assistant" + + # Build usage if present + usage = self._build_usage(chunk.get("usageMetadata", {})) + + # Mark completion when we see usageMetadata + if chunk.get("usageMetadata") and accumulator is not None: + 
accumulator["is_complete"] = True + + # Build choice - just translate, don't include finish_reason + # Client will handle finish_reason logic + choice = {"index": 0, "delta": delta} + + response = { + "id": chunk.get("responseId", f"chatcmpl-{uuid.uuid4().hex[:24]}"), + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [choice] + } + + if usage: + response["usage"] = usage + + return response + + def _gemini_to_openai_non_streaming( + self, + response: Dict[str, Any], + model: str + ) -> Dict[str, Any]: + """Convert Gemini response to OpenAI non-streaming format.""" + candidates = response.get("candidates", []) + if not candidates: + return {} + + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + + text_content = "" + reasoning_content = "" + tool_calls = [] + thought_sig = "" + + for part in content_parts: + has_func = "functionCall" in part + has_text = "text" in part + has_sig = bool(part.get("thoughtSignature")) + is_thought = part.get("thought") is True or str(part.get("thought")).lower() == 'true' + + if has_sig and is_thought: + thought_sig = part["thoughtSignature"] + + if has_sig and not has_func and (not has_text or not part.get("text")): + continue + + if has_text: + if is_thought: + reasoning_content += part["text"] + else: + text_content += part["text"] + + if has_func: + tool_call = self._extract_tool_call(part, model, len(tool_calls)) + + # Store signature for each tool call (needed for parallel tool calls) + if has_sig: + self._handle_tool_signature(tool_call, part["thoughtSignature"]) + + tool_calls.append(tool_call) + + # Cache Claude thinking + if reasoning_content and self._is_claude(model) and self._enable_signature_cache: + self._cache_thinking(reasoning_content, thought_sig, text_content, tool_calls) + + # Build message + message = {"role": "assistant"} + if text_content: + message["content"] = text_content + elif not tool_calls: + message["content"] = "" + if reasoning_content: + message["reasoning_content"] = reasoning_content + if tool_calls: + message["tool_calls"] = tool_calls + message.pop("content", None) + + finish_reason = self._map_finish_reason(candidate.get("finishReason"), bool(tool_calls)) + usage = self._build_usage(response.get("usageMetadata", {})) + + # For non-streaming, always include finish_reason (should always be present) + result = { + "id": response.get("responseId", f"chatcmpl-{uuid.uuid4().hex[:24]}"), + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "message": message, "finish_reason": finish_reason or "stop"}] + } + + if usage: + result["usage"] = usage + + return result + + def _extract_tool_call( + self, + part: Dict[str, Any], + model: str, + index: int, + accumulator: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Extract and format a tool call from a response part.""" + func_call = part["functionCall"] + tool_id = func_call.get("id") or f"call_{uuid.uuid4().hex[:24]}" + + #lib_logger.debug(f"[ID Extraction] Extracting tool call: id={tool_id}, raw_id={func_call.get('id')}") + + tool_name = func_call.get("name", "") + if self._is_gemini_3(model) and self._enable_gemini3_tool_fix: + tool_name = self._strip_gemini3_prefix(tool_name) + + raw_args = func_call.get("args", {}) + parsed_args = _recursively_parse_json_strings(raw_args) + + tool_call = { + "id": tool_id, + "type": "function", + "index": index, + "function": { + "name": tool_name, + "arguments": json.dumps(parsed_args) + 
} + } + + if accumulator is not None: + accumulator["tool_calls"].append(tool_call) + + return tool_call + + def _handle_tool_signature(self, tool_call: Dict, signature: str) -> None: + """Handle thoughtSignature for a tool call.""" + tool_id = tool_call["id"] + + if self._enable_signature_cache: + self._signature_cache.store(tool_id, signature) + lib_logger.debug(f"Stored signature for {tool_id}") + + if self._preserve_signatures_in_client: + tool_call["thought_signature"] = signature + + def _map_finish_reason( + self, + gemini_reason: Optional[str], + has_tool_calls: bool + ) -> Optional[str]: + """Map Gemini finish reason to OpenAI format.""" + if not gemini_reason: + return None + reason = FINISH_REASON_MAP.get(gemini_reason, "stop") + return "tool_calls" if has_tool_calls else reason + + def _build_usage(self, metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Build usage dict from Gemini usage metadata.""" + if not metadata: + return None + + prompt = metadata.get("promptTokenCount", 0) + thoughts = metadata.get("thoughtsTokenCount", 0) + completion = metadata.get("candidatesTokenCount", 0) + + usage = { + "prompt_tokens": prompt + thoughts, + "completion_tokens": completion, + "total_tokens": metadata.get("totalTokenCount", 0) + } + + if thoughts > 0: + usage["completion_tokens_details"] = {"reasoning_tokens": thoughts} + + return usage + + def _cache_thinking( + self, + reasoning: str, + signature: str, + text: str, + tool_calls: List[Dict] + ) -> None: + """Cache Claude thinking content.""" + cache_key = self._generate_thinking_cache_key(text, tool_calls) + if not cache_key: + return + + data = { + "thinking_text": reasoning, + "thought_signature": signature, + "text_preview": text[:100] if text else "", + "tool_ids": [tc.get("id", "") for tc in tool_calls], + "timestamp": time.time() + } + + self._thinking_cache.store(cache_key, json.dumps(data)) + lib_logger.info(f"Cached thinking: {cache_key[:50]}...") + + # ========================================================================= + # PROVIDER INTERFACE IMPLEMENTATION + # ========================================================================= + + async def get_valid_token(self, credential_identifier: str) -> str: + """Get a valid access token for the credential.""" + creds = await self._load_credentials(credential_identifier) + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_identifier, creds) + return creds['access_token'] + + def has_custom_logic(self) -> bool: + """Antigravity uses custom translation logic.""" + return True + + async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]: + """Get OAuth authorization header.""" + token = await self.get_valid_token(credential_identifier) + return {"Authorization": f"Bearer {token}"} + + async def get_models( + self, + api_key: str, + client: httpx.AsyncClient + ) -> List[str]: + """Fetch available models from Antigravity.""" + if not self._enable_dynamic_models: + lib_logger.debug("Using hardcoded model list") + return [f"antigravity/{m}" for m in AVAILABLE_MODELS] + + try: + token = await self.get_valid_token(api_key) + url = f"{self._get_base_url()}/fetchAvailableModels" + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + payload = { + "project": _generate_project_id(), + "requestId": _generate_request_id(), + "userAgent": "antigravity" + } + + response = await client.post(url, json=payload, headers=headers, timeout=30.0) + response.raise_for_status() + data = 
response.json() + + models = [] + for model_info in data.get("models", []): + internal = model_info.get("name", "").replace("models/", "") + if internal: + public = self._internal_to_alias(internal) + if public: + models.append(f"antigravity/{public}") + + if models: + lib_logger.info(f"Discovered {len(models)} models") + return models + except Exception as e: + lib_logger.warning(f"Dynamic model discovery failed: {e}") + + return [f"antigravity/{m}" for m in AVAILABLE_MODELS] + + async def acompletion( + self, + client: httpx.AsyncClient, + **kwargs + ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]: + """ + Handle completion requests for Antigravity. + + Main entry point that: + 1. Extracts parameters and transforms messages + 2. Builds Antigravity request payload + 3. Makes API call with fallback logic + 4. Transforms response to OpenAI format + """ + # Extract parameters + model = self._strip_provider_prefix(kwargs.get("model", "gemini-2.5-pro")) + messages = kwargs.get("messages", []) + stream = kwargs.get("stream", False) + credential_path = kwargs.pop("credential_identifier", kwargs.get("api_key", "")) + tools = kwargs.get("tools") + tool_choice = kwargs.get("tool_choice") + reasoning_effort = kwargs.get("reasoning_effort") + top_p = kwargs.get("top_p") + temperature = kwargs.get("temperature") + max_tokens = kwargs.get("max_tokens") + custom_budget = kwargs.get("custom_reasoning_budget", False) + enable_logging = kwargs.pop("enable_request_logging", False) + + # Create logger + file_logger = AntigravityFileLogger(model, enable_logging) + + # Determine if thinking is enabled for this request + # Thinking is enabled if reasoning_effort is set (and not "disable") for Claude + thinking_enabled = False + if self._is_claude(model): + # For Claude, thinking is enabled when reasoning_effort is provided and not "disable" + thinking_enabled = reasoning_effort is not None and reasoning_effort != "disable" + + # Sanitize thinking blocks for Claude to prevent 400 errors + # This handles: context compression, model switching, mid-turn thinking toggle + # Returns (sanitized_messages, force_disable_thinking) + force_disable_thinking = False + if self._is_claude(model) and self._enable_thinking_sanitization: + messages, force_disable_thinking = self._sanitize_thinking_for_claude(messages, thinking_enabled) + + # If we're in a mid-turn thinking toggle situation, we MUST disable thinking + # for this request. Thinking will naturally resume on the next turn. 
+ if force_disable_thinking: + thinking_enabled = False + reasoning_effort = "disable" # Force disable for this request + + # Transform messages + system_instruction, gemini_contents = self._transform_messages(messages, model) + gemini_contents = self._fix_tool_response_grouping(gemini_contents) + + # Build payload + gemini_payload = {"contents": gemini_contents} + + if system_instruction: + gemini_payload["system_instruction"] = system_instruction + + # Inject tool usage hardening system instructions + if tools: + if self._is_gemini_3(model) and self._enable_gemini3_tool_fix: + self._inject_tool_hardening_instruction(gemini_payload, self._gemini3_system_instruction) + elif self._is_claude(model) and self._enable_claude_tool_fix: + self._inject_tool_hardening_instruction(gemini_payload, self._claude_system_instruction) + + # Add generation config + gen_config = {} + if top_p is not None: + gen_config["topP"] = top_p + + # Handle temperature - Gemini 3 defaults to 1 if not explicitly set + if temperature is not None: + gen_config["temperature"] = temperature + elif self._is_gemini_3(model): + # Gemini 3 performs better with temperature=1 for tool use + gen_config["temperature"] = 1.0 + + thinking_config = self._get_thinking_config(reasoning_effort, model, custom_budget) + if thinking_config: + gen_config.setdefault("thinkingConfig", {}).update(thinking_config) + + if gen_config: + gemini_payload["generationConfig"] = gen_config + + # Add tools + gemini_tools = self._build_tools_payload(tools, model) + if gemini_tools: + gemini_payload["tools"] = gemini_tools + + # Apply tool transformations + if self._is_gemini_3(model) and self._enable_gemini3_tool_fix: + # Gemini 3: namespace prefix + strict schema + parameter signatures + gemini_payload["tools"] = self._apply_gemini3_namespace(gemini_payload["tools"]) + if self._gemini3_enforce_strict_schema: + gemini_payload["tools"] = self._enforce_strict_schema(gemini_payload["tools"]) + gemini_payload["tools"] = self._inject_signature_into_descriptions( + gemini_payload["tools"], + self._gemini3_description_prompt + ) + elif self._is_claude(model) and self._enable_claude_tool_fix: + # Claude: parameter signatures only (no namespace prefix) + gemini_payload["tools"] = self._inject_signature_into_descriptions( + gemini_payload["tools"], + self._claude_description_prompt + ) + + # Transform to Antigravity format + payload = self._transform_to_antigravity_format(gemini_payload, model, max_tokens, reasoning_effort, tool_choice) + file_logger.log_request(payload) + + # Make API call + token = await self.get_valid_token(credential_path) + base_url = self._get_base_url() + endpoint = ":streamGenerateContent" if stream else ":generateContent" + url = f"{base_url}{endpoint}" + + if stream: + url = f"{url}?alt=sse" + + parsed = urlparse(base_url) + host = parsed.netloc or base_url.replace("https://", "").replace("http://", "").rstrip("/") + + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + "Host": host, + "User-Agent": "antigravity/1.11.9", + "Accept": "text/event-stream" if stream else "application/json" + } + + try: + if stream: + return self._handle_streaming(client, url, headers, payload, model, file_logger) + else: + return await self._handle_non_streaming(client, url, headers, payload, model, file_logger) + except Exception as e: + if self._try_next_base_url(): + lib_logger.warning(f"Retrying with fallback URL: {e}") + url = f"{self._get_base_url()}{endpoint}" + if stream: + return self._handle_streaming(client, 
url, headers, payload, model, file_logger) + else: + return await self._handle_non_streaming(client, url, headers, payload, model, file_logger) + raise + + def _inject_tool_hardening_instruction(self, payload: Dict[str, Any], instruction_text: str) -> None: + """Inject tool usage hardening system instruction for Gemini 3 & Claude.""" + if not instruction_text: + return + + instruction_part = {"text": instruction_text} + + if "system_instruction" in payload: + existing = payload["system_instruction"] + if isinstance(existing, dict) and "parts" in existing: + existing["parts"].insert(0, instruction_part) + else: + payload["system_instruction"] = { + "role": "user", + "parts": [instruction_part, {"text": str(existing)}] + } + else: + payload["system_instruction"] = {"role": "user", "parts": [instruction_part]} + + async def _handle_non_streaming( + self, + client: httpx.AsyncClient, + url: str, + headers: Dict[str, str], + payload: Dict[str, Any], + model: str, + file_logger: Optional[AntigravityFileLogger] = None + ) -> litellm.ModelResponse: + """Handle non-streaming completion.""" + response = await client.post(url, headers=headers, json=payload, timeout=120.0) + response.raise_for_status() + + data = response.json() + if file_logger: + file_logger.log_final_response(data) + + gemini_response = self._unwrap_response(data) + openai_response = self._gemini_to_openai_non_streaming(gemini_response, model) + + return litellm.ModelResponse(**openai_response) + + async def _handle_streaming( + self, + client: httpx.AsyncClient, + url: str, + headers: Dict[str, str], + payload: Dict[str, Any], + model: str, + file_logger: Optional[AntigravityFileLogger] = None + ) -> AsyncGenerator[litellm.ModelResponse, None]: + """Handle streaming completion.""" + # Accumulator tracks state across chunks for caching and tool indexing + accumulator = { + "reasoning_content": "", + "thought_signature": "", + "text_content": "", + "tool_calls": [], + "tool_idx": 0, # Track tool call index across chunks + "is_complete": False # Track if we received usageMetadata + } + + async with client.stream("POST", url, headers=headers, json=payload, timeout=120.0) as response: + if response.status_code >= 400: + try: + error_body = await response.aread() + lib_logger.error(f"API error {response.status_code}: {error_body.decode()}") + except Exception: + pass + + response.raise_for_status() + + async for line in response.aiter_lines(): + if file_logger: + file_logger.log_response_chunk(line) + + if line.startswith("data: "): + data_str = line[6:] + if data_str == "[DONE]": + break + + try: + chunk = json.loads(data_str) + gemini_chunk = self._unwrap_response(chunk) + openai_chunk = self._gemini_to_openai_chunk(gemini_chunk, model, accumulator) + + yield litellm.ModelResponse(**openai_chunk) + except json.JSONDecodeError: + if file_logger: + file_logger.log_error(f"Parse error: {data_str[:100]}") + continue + + # If stream ended without usageMetadata chunk, emit a final chunk with finish_reason + # Emit final chunk if stream ended without usageMetadata + # Client will determine the correct finish_reason based on accumulated state + if not accumulator.get("is_complete"): + final_chunk = { + "id": f"chatcmpl-{uuid.uuid4().hex[:24]}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": None}], + # Include minimal usage to signal this is the final chunk + "usage": {"prompt_tokens": 0, "completion_tokens": 1, "total_tokens": 1} + } + yield 
litellm.ModelResponse(**final_chunk) + + # Cache Claude thinking after stream completes + if self._is_claude(model) and self._enable_signature_cache and accumulator.get("reasoning_content"): + self._cache_thinking( + accumulator["reasoning_content"], + accumulator["thought_signature"], + accumulator["text_content"], + accumulator["tool_calls"] + ) + + async def count_tokens( + self, + client: httpx.AsyncClient, + credential_path: str, + model: str, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + _litellm_params: Optional[Dict[str, Any]] = None + ) -> Dict[str, int]: + """Count tokens for the given prompt using Antigravity :countTokens endpoint.""" + try: + token = await self.get_valid_token(credential_path) + internal_model = self._alias_to_internal(model) + + system_instruction, contents = self._transform_messages(messages, internal_model) + + gemini_payload = {"contents": contents} + if system_instruction: + gemini_payload["systemInstruction"] = system_instruction + + gemini_tools = self._build_tools_payload(tools, model) + if gemini_tools: + gemini_payload["tools"] = gemini_tools + + antigravity_payload = { + "project": _generate_project_id(), + "userAgent": "antigravity", + "requestId": _generate_request_id(), + "model": internal_model, + "request": gemini_payload + } + + url = f"{self._get_base_url()}:countTokens" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + response = await client.post(url, headers=headers, json=antigravity_payload, timeout=30) + response.raise_for_status() + + data = response.json() + unwrapped = self._unwrap_response(data) + total = unwrapped.get('totalTokens', 0) + + return {'prompt_tokens': total, 'total_tokens': total} + except Exception as e: + lib_logger.error(f"Token counting failed: {e}") + return {'prompt_tokens': 0, 'total_tokens': 0} \ No newline at end of file diff --git a/src/rotator_library/providers/gemini_auth_base.py b/src/rotator_library/providers/gemini_auth_base.py index 6e8c1cc..90b9d9a 100644 --- a/src/rotator_library/providers/gemini_auth_base.py +++ b/src/rotator_library/providers/gemini_auth_base.py @@ -1,625 +1,21 @@ # src/rotator_library/providers/gemini_auth_base.py -import os -import webbrowser -from typing import Union, Optional -import json -import time -import asyncio -import logging -from pathlib import Path -from typing import Dict, Any -import tempfile -import shutil - -import httpx -from rich.console import Console -from rich.panel import Panel -from rich.text import Text - -from ..utils.headless_detection import is_headless_environment - -lib_logger = logging.getLogger('rotator_library') - -CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" #https://api.kilocode.ai/extension-config.json -CLIENT_SECRET = "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" #https://api.kilocode.ai/extension-config.json -TOKEN_URI = "https://oauth2.googleapis.com/token" -USER_INFO_URI = "https://www.googleapis.com/oauth2/v1/userinfo" -REFRESH_EXPIRY_BUFFER_SECONDS = 30 * 60 # 30 minutes buffer before expiry - -console = Console() - -class GeminiAuthBase: - def __init__(self): - self._credentials_cache: Dict[str, Dict[str, Any]] = {} - self._refresh_locks: Dict[str, asyncio.Lock] = {} - self._locks_lock = asyncio.Lock() # Protects the locks dict from race conditions - # [BACKOFF TRACKING] Track consecutive failures per credential - self._refresh_failures: Dict[str, int] = {} # Track consecutive failures per credential - 
self._next_refresh_after: Dict[str, float] = {} # Track backoff timers (Unix timestamp) - - # [QUEUE SYSTEM] Sequential refresh processing - self._refresh_queue: asyncio.Queue = asyncio.Queue() - self._queued_credentials: set = set() # Track credentials already in queue - self._unavailable_credentials: set = set() # Mark credentials unavailable during re-auth - self._queue_tracking_lock = asyncio.Lock() # Protects queue sets - self._queue_processor_task: Optional[asyncio.Task] = None # Background worker task - - def _load_from_env(self) -> Optional[Dict[str, Any]]: - """ - Load OAuth credentials from environment variables for stateless deployments. - - Expected environment variables: - - GEMINI_CLI_ACCESS_TOKEN (required) - - GEMINI_CLI_REFRESH_TOKEN (required) - - GEMINI_CLI_EXPIRY_DATE (optional, defaults to 0) - - GEMINI_CLI_CLIENT_ID (optional, uses default) - - GEMINI_CLI_CLIENT_SECRET (optional, uses default) - - GEMINI_CLI_TOKEN_URI (optional, uses default) - - GEMINI_CLI_UNIVERSE_DOMAIN (optional, defaults to googleapis.com) - - GEMINI_CLI_EMAIL (optional, defaults to "env-user") - - GEMINI_CLI_PROJECT_ID (optional) - - GEMINI_CLI_TIER (optional) - - Returns: - Dict with credential structure if env vars present, None otherwise - """ - access_token = os.getenv("GEMINI_CLI_ACCESS_TOKEN") - refresh_token = os.getenv("GEMINI_CLI_REFRESH_TOKEN") - - # Both access and refresh tokens are required - if not (access_token and refresh_token): - return None - - lib_logger.debug("Loading Gemini CLI credentials from environment variables") - - # Parse expiry_date as float, default to 0 if not present - expiry_str = os.getenv("GEMINI_CLI_EXPIRY_DATE", "0") - try: - expiry_date = float(expiry_str) - except ValueError: - lib_logger.warning(f"Invalid GEMINI_CLI_EXPIRY_DATE value: {expiry_str}, using 0") - expiry_date = 0 - - creds = { - "access_token": access_token, - "refresh_token": refresh_token, - "expiry_date": expiry_date, - "client_id": os.getenv("GEMINI_CLI_CLIENT_ID", CLIENT_ID), - "client_secret": os.getenv("GEMINI_CLI_CLIENT_SECRET", CLIENT_SECRET), - "token_uri": os.getenv("GEMINI_CLI_TOKEN_URI", TOKEN_URI), - "universe_domain": os.getenv("GEMINI_CLI_UNIVERSE_DOMAIN", "googleapis.com"), - "_proxy_metadata": { - "email": os.getenv("GEMINI_CLI_EMAIL", "env-user"), - "last_check_timestamp": time.time(), - "loaded_from_env": True # Flag to indicate env-based credentials - } - } - - # Add project_id if provided - project_id = os.getenv("GEMINI_CLI_PROJECT_ID") - if project_id: - creds["_proxy_metadata"]["project_id"] = project_id - - # Add tier if provided - tier = os.getenv("GEMINI_CLI_TIER") - if tier: - creds["_proxy_metadata"]["tier"] = tier - - return creds - - async def _load_credentials(self, path: str) -> Dict[str, Any]: - if path in self._credentials_cache: - return self._credentials_cache[path] - - async with await self._get_lock(path): - if path in self._credentials_cache: - return self._credentials_cache[path] - - # First, try loading from environment variables - env_creds = self._load_from_env() - if env_creds: - lib_logger.info("Using Gemini CLI credentials from environment variables") - # Cache env-based credentials using the path as key - self._credentials_cache[path] = env_creds - return env_creds - - # Fall back to file-based loading - try: - lib_logger.debug(f"Loading Gemini credentials from file: {path}") - with open(path, 'r') as f: - creds = json.load(f) - # Handle gcloud-style creds file which nest tokens under "credential" - if "credential" in creds: - creds = 
creds["credential"] - self._credentials_cache[path] = creds - return creds - except FileNotFoundError: - raise IOError(f"Gemini OAuth credential file not found at '{path}'") - except Exception as e: - raise IOError(f"Failed to load Gemini OAuth credentials from '{path}': {e}") - - async def _save_credentials(self, path: str, creds: Dict[str, Any]): - # Don't save to file if credentials were loaded from environment - if creds.get("_proxy_metadata", {}).get("loaded_from_env"): - lib_logger.debug("Credentials loaded from env, skipping file save") - # Still update cache for in-memory consistency - self._credentials_cache[path] = creds - return - - # [ATOMIC WRITE] Use tempfile + move pattern to ensure atomic writes - # This prevents credential corruption if the process is interrupted during write - parent_dir = os.path.dirname(os.path.abspath(path)) - os.makedirs(parent_dir, exist_ok=True) - - tmp_fd = None - tmp_path = None - try: - # Create temp file in same directory as target (ensures same filesystem) - tmp_fd, tmp_path = tempfile.mkstemp(dir=parent_dir, prefix='.tmp_', suffix='.json', text=True) - - # Write JSON to temp file - with os.fdopen(tmp_fd, 'w') as f: - json.dump(creds, f, indent=2) - tmp_fd = None # fdopen closes the fd - - # Set secure permissions (0600 = owner read/write only) - try: - os.chmod(tmp_path, 0o600) - except (OSError, AttributeError): - # Windows may not support chmod, ignore - pass - - # Atomic move (overwrites target if it exists) - shutil.move(tmp_path, path) - tmp_path = None # Successfully moved - - # Update cache AFTER successful file write (prevents cache/file inconsistency) - self._credentials_cache[path] = creds - lib_logger.debug(f"Saved updated Gemini OAuth credentials to '{path}' (atomic write).") - - except Exception as e: - lib_logger.error(f"Failed to save updated Gemini OAuth credentials to '{path}': {e}") - # Clean up temp file if it still exists - if tmp_fd is not None: - try: - os.close(tmp_fd) - except: - pass - if tmp_path and os.path.exists(tmp_path): - try: - os.unlink(tmp_path) - except: - pass - raise - - def _is_token_expired(self, creds: Dict[str, Any]) -> bool: - expiry = creds.get("token_expiry") # gcloud format - if not expiry: # gemini-cli format - expiry_timestamp = creds.get("expiry_date", 0) / 1000 - else: - expiry_timestamp = time.mktime(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ")) - return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS - - async def _refresh_token(self, path: str, creds: Dict[str, Any], force: bool = False) -> Dict[str, Any]: - async with await self._get_lock(path): - # Skip the expiry check if a refresh is being forced - if not force and not self._is_token_expired(self._credentials_cache.get(path, creds)): - return self._credentials_cache.get(path, creds) - - lib_logger.debug(f"Refreshing Gemini OAuth token for '{Path(path).name}' (forced: {force})...") - refresh_token = creds.get("refresh_token") - if not refresh_token: - raise ValueError("No refresh_token found in credentials file.") - - # [RETRY LOGIC] Implement exponential backoff for transient errors - max_retries = 3 - new_token_data = None - last_error = None - needs_reauth = False - - async with httpx.AsyncClient() as client: - for attempt in range(max_retries): - try: - response = await client.post(TOKEN_URI, data={ - "client_id": creds.get("client_id", CLIENT_ID), - "client_secret": creds.get("client_secret", CLIENT_SECRET), - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, timeout=30.0) - 
response.raise_for_status() - new_token_data = response.json() - break # Success, exit retry loop - - except httpx.HTTPStatusError as e: - last_error = e - status_code = e.response.status_code - - # [INVALID GRANT HANDLING] Handle 401/403 by triggering re-authentication - if status_code == 401 or status_code == 403: - lib_logger.warning( - f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). " - f"Token may have been revoked or expired. Starting re-authentication..." - ) - needs_reauth = True - break # Exit retry loop to trigger re-auth - - elif status_code == 429: - # Rate limit - honor Retry-After header if present - retry_after = int(e.response.headers.get("Retry-After", 60)) - lib_logger.warning(f"Rate limited (HTTP 429), retry after {retry_after}s") - if attempt < max_retries - 1: - await asyncio.sleep(retry_after) - continue - raise - - elif status_code >= 500 and status_code < 600: - # Server error - retry with exponential backoff - if attempt < max_retries - 1: - wait_time = 2 ** attempt # 1s, 2s, 4s - lib_logger.warning(f"Server error (HTTP {status_code}), retry {attempt + 1}/{max_retries} in {wait_time}s") - await asyncio.sleep(wait_time) - continue - raise # Final attempt failed - - else: - # Other errors - don't retry - raise - - except (httpx.RequestError, httpx.TimeoutException) as e: - # Network errors - retry with backoff - last_error = e - if attempt < max_retries - 1: - wait_time = 2 ** attempt - lib_logger.warning(f"Network error during refresh: {e}, retry {attempt + 1}/{max_retries} in {wait_time}s") - await asyncio.sleep(wait_time) - continue - raise - - # [INVALID GRANT RE-AUTH] Trigger OAuth flow if refresh token is invalid - if needs_reauth: - lib_logger.info(f"Starting re-authentication for '{Path(path).name}'...") - try: - # Call initialize_token to trigger OAuth flow - new_creds = await self.initialize_token(path) - return new_creds - except Exception as reauth_error: - lib_logger.error(f"Re-authentication failed for '{Path(path).name}': {reauth_error}") - raise ValueError(f"Refresh token invalid and re-authentication failed: {reauth_error}") - - # If we exhausted retries without success - if new_token_data is None: - raise last_error or Exception("Token refresh failed after all retries") - - # [FIX 1] Update OAuth token fields from response - creds["access_token"] = new_token_data["access_token"] - expiry_timestamp = time.time() + new_token_data["expires_in"] - creds["expiry_date"] = expiry_timestamp * 1000 # gemini-cli format - - # [FIX 2] Update refresh_token if server provided a new one (rare but possible with Google OAuth) - if "refresh_token" in new_token_data: - creds["refresh_token"] = new_token_data["refresh_token"] - - # [FIX 3] Ensure all required OAuth client fields are present (restore if missing) - if "client_id" not in creds or not creds["client_id"]: - creds["client_id"] = CLIENT_ID - if "client_secret" not in creds or not creds["client_secret"]: - creds["client_secret"] = CLIENT_SECRET - if "token_uri" not in creds or not creds["token_uri"]: - creds["token_uri"] = TOKEN_URI - if "universe_domain" not in creds or not creds["universe_domain"]: - creds["universe_domain"] = "googleapis.com" - - # [FIX 4] Add scopes array if missing - if "scopes" not in creds: - creds["scopes"] = [ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - ] - - # [FIX 5] Ensure _proxy_metadata exists and update timestamp - if "_proxy_metadata" not in 
creds: - creds["_proxy_metadata"] = {} - creds["_proxy_metadata"]["last_check_timestamp"] = time.time() - - # [VALIDATION] Verify refreshed credentials have all required fields - required_fields = ["access_token", "refresh_token", "client_id", "client_secret", "token_uri"] - missing_fields = [field for field in required_fields if not creds.get(field)] - if missing_fields: - raise ValueError(f"Refreshed credentials missing required fields: {missing_fields}") - - # [VALIDATION] Optional: Test that the refreshed token is actually usable - try: - async with httpx.AsyncClient() as client: - test_response = await client.get( - USER_INFO_URI, - headers={"Authorization": f"Bearer {creds['access_token']}"}, - timeout=5.0 - ) - test_response.raise_for_status() - lib_logger.debug(f"Token validation successful for '{Path(path).name}'") - except Exception as e: - lib_logger.warning(f"Refreshed token validation failed for '{Path(path).name}': {e}") - # Don't fail the refresh - the token might still work for other endpoints - # But log it for debugging purposes - - await self._save_credentials(path, creds) - lib_logger.debug(f"Successfully refreshed Gemini OAuth token for '{Path(path).name}'.") - return creds - - async def proactively_refresh(self, credential_path: str): - """Proactively refresh a credential by queueing it for refresh.""" - creds = await self._load_credentials(credential_path) - if self._is_token_expired(creds): - # Queue for refresh with needs_reauth=False (automated refresh) - await self._queue_refresh(credential_path, force=False, needs_reauth=False) - - async def _get_lock(self, path: str) -> asyncio.Lock: - # [FIX RACE CONDITION] Protect lock creation with a master lock - # This prevents TOCTOU bug where multiple coroutines check and create simultaneously - async with self._locks_lock: - if path not in self._refresh_locks: - self._refresh_locks[path] = asyncio.Lock() - return self._refresh_locks[path] - - def is_credential_available(self, path: str) -> bool: - """Check if a credential is available for rotation (not queued/refreshing).""" - return path not in self._unavailable_credentials - - async def _ensure_queue_processor_running(self): - """Lazily starts the queue processor if not already running.""" - if self._queue_processor_task is None or self._queue_processor_task.done(): - self._queue_processor_task = asyncio.create_task(self._process_refresh_queue()) - - async def _queue_refresh(self, path: str, force: bool = False, needs_reauth: bool = False): - """Add a credential to the refresh queue if not already queued. 
- - Args: - path: Credential file path - force: Force refresh even if not expired - needs_reauth: True if full re-authentication needed (bypasses backoff) - """ - # IMPORTANT: Only check backoff for simple automated refreshes - # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input - if not needs_reauth: - now = time.time() - if path in self._next_refresh_after: - backoff_until = self._next_refresh_after[path] - if now < backoff_until: - # Credential is in backoff for automated refresh, do not queue - remaining = int(backoff_until - now) - lib_logger.debug(f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)") - return - - async with self._queue_tracking_lock: - if path not in self._queued_credentials: - self._queued_credentials.add(path) - self._unavailable_credentials.add(path) # Mark as unavailable - await self._refresh_queue.put((path, force, needs_reauth)) - await self._ensure_queue_processor_running() - - async def _process_refresh_queue(self): - """Background worker that processes refresh requests sequentially.""" - while True: - path = None - try: - # Wait for an item with timeout to allow graceful shutdown - try: - path, force, needs_reauth = await asyncio.wait_for( - self._refresh_queue.get(), - timeout=60.0 - ) - except asyncio.TimeoutError: - # No items for 60s, exit to save resources - self._queue_processor_task = None - return - - try: - # Perform the actual refresh (still using per-credential lock) - async with await self._get_lock(path): - # Re-check if still expired (may have changed since queueing) - creds = self._credentials_cache.get(path) - if creds and not self._is_token_expired(creds): - # No longer expired, mark as available - async with self._queue_tracking_lock: - self._unavailable_credentials.discard(path) - continue - - # Perform refresh - if not creds: - creds = await self._load_credentials(path) - await self._refresh_token(path, creds, force=force) - - # SUCCESS: Mark as available again - async with self._queue_tracking_lock: - self._unavailable_credentials.discard(path) - - finally: - # Remove from queued set - async with self._queue_tracking_lock: - self._queued_credentials.discard(path) - self._refresh_queue.task_done() - except asyncio.CancelledError: - break - except Exception as e: - lib_logger.error(f"Error in queue processor: {e}") - # Even on error, mark as available (backoff will prevent immediate retry) - if path: - async with self._queue_tracking_lock: - self._unavailable_credentials.discard(path) - - async def initialize_token(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]: - path = creds_or_path if isinstance(creds_or_path, str) else None - - # Get display name from metadata if available, otherwise derive from path - if isinstance(creds_or_path, dict): - display_name = creds_or_path.get("_proxy_metadata", {}).get("display_name", "in-memory object") - else: - display_name = Path(path).name if path else "in-memory object" - - lib_logger.debug(f"Initializing Gemini token for '{display_name}'...") - try: - creds = await self._load_credentials(creds_or_path) if path else creds_or_path - reason = "" - if not creds.get("refresh_token"): - reason = "refresh token is missing" - elif self._is_token_expired(creds): - reason = "token is expired" - - if reason: - if reason == "token is expired" and creds.get("refresh_token"): - try: - return await self._refresh_token(path, creds) - except Exception as e: - lib_logger.warning(f"Automatic token refresh for '{display_name}' 
failed: {e}. Proceeding to interactive login.") - - lib_logger.warning(f"Gemini OAuth token for '{display_name}' needs setup: {reason}.") - - # [HEADLESS DETECTION] Check if running in headless environment - is_headless = is_headless_environment() - - auth_code_future = asyncio.get_event_loop().create_future() - server = None - - async def handle_callback(reader, writer): - try: - request_line_bytes = await reader.readline() - if not request_line_bytes: return - path = request_line_bytes.decode('utf-8').strip().split(' ')[1] - while await reader.readline() != b'\r\n': pass - from urllib.parse import urlparse, parse_qs - query_params = parse_qs(urlparse(path).query) - writer.write(b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n") - if 'code' in query_params: - if not auth_code_future.done(): - auth_code_future.set_result(query_params['code'][0]) - writer.write(b"

Authentication successful! You can close this window.
") - else: - error = query_params.get('error', ['Unknown error'])[0] - if not auth_code_future.done(): - auth_code_future.set_exception(Exception(f"OAuth failed: {error}")) - writer.write(f"

Authentication Failed. Error: {error}. Please try again.
".encode()) - await writer.drain() - except Exception as e: - lib_logger.error(f"Error in OAuth callback handler: {e}") - finally: - writer.close() - - try: - server = await asyncio.start_server(handle_callback, '127.0.0.1', 8085) - from urllib.parse import urlencode - auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode({ - "client_id": CLIENT_ID, - "redirect_uri": "http://localhost:8085/oauth2callback", - "scope": " ".join(["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"]), - "access_type": "offline", "response_type": "code", "prompt": "consent" - }) - - # [HEADLESS SUPPORT] Display appropriate instructions - if is_headless: - auth_panel_text = Text.from_markup( - "Running in headless environment (no GUI detected).\n" - "Please open the URL below in a browser on another machine to authorize:\n" - ) - else: - auth_panel_text = Text.from_markup( - "1. Your browser will now open to log in and authorize the application.\n" - "2. If it doesn't open automatically, please open the URL below manually." - ) - - console.print(Panel(auth_panel_text, title=f"Gemini OAuth Setup for [bold yellow]{display_name}[/bold yellow]", style="bold blue")) - console.print(f"[bold]URL:[/bold] [link={auth_url}]{auth_url}[/link]\n") - - # [HEADLESS SUPPORT] Only attempt browser open if NOT headless - if not is_headless: - try: - webbrowser.open(auth_url) - lib_logger.info("Browser opened successfully for OAuth flow") - except Exception as e: - lib_logger.warning(f"Failed to open browser automatically: {e}. Please open the URL manually.") - - with console.status("[bold green]Waiting for you to complete authentication in the browser...[/bold green]", spinner="dots"): - auth_code = await asyncio.wait_for(auth_code_future, timeout=300) - except asyncio.TimeoutError: - raise Exception("OAuth flow timed out. 
Please try again.") - finally: - if server: - server.close() - await server.wait_closed() - - lib_logger.info(f"Attempting to exchange authorization code for tokens...") - async with httpx.AsyncClient() as client: - response = await client.post(TOKEN_URI, data={ - "code": auth_code.strip(), "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET, - "redirect_uri": "http://localhost:8085/oauth2callback", "grant_type": "authorization_code" - }) - response.raise_for_status() - token_data = response.json() - # Start with the full token data from the exchange - creds = token_data.copy() - - # Convert 'expires_in' to 'expiry_date' in milliseconds - creds["expiry_date"] = (time.time() + creds.pop("expires_in")) * 1000 - - # Ensure client_id and client_secret are present - creds["client_id"] = CLIENT_ID - creds["client_secret"] = CLIENT_SECRET - - creds["token_uri"] = TOKEN_URI - creds["universe_domain"] = "googleapis.com" - - # Fetch user info and add metadata - user_info_response = await client.get(USER_INFO_URI, headers={"Authorization": f"Bearer {creds['access_token']}"}) - user_info_response.raise_for_status() - user_info = user_info_response.json() - creds["_proxy_metadata"] = { - "email": user_info.get("email"), - "last_check_timestamp": time.time() - } - - if path: - await self._save_credentials(path, creds) - lib_logger.info(f"Gemini OAuth initialized successfully for '{display_name}'.") - return creds - - lib_logger.info(f"Gemini OAuth token at '{display_name}' is valid.") - return creds - except Exception as e: - raise ValueError(f"Failed to initialize Gemini OAuth for '{path}': {e}") - - async def get_auth_header(self, credential_path: str) -> Dict[str, str]: - creds = await self._load_credentials(credential_path) - if self._is_token_expired(creds): - creds = await self._refresh_token(credential_path, creds) - return {"Authorization": f"Bearer {creds['access_token']}"} - - async def get_user_info(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]: - path = creds_or_path if isinstance(creds_or_path, str) else None - creds = await self._load_credentials(creds_or_path) if path else creds_or_path - - if path and self._is_token_expired(creds): - creds = await self._refresh_token(path, creds) - - # Prefer locally stored metadata - if creds.get("_proxy_metadata", {}).get("email"): - if path: - creds["_proxy_metadata"]["last_check_timestamp"] = time.time() - await self._save_credentials(path, creds) - return {"email": creds["_proxy_metadata"]["email"]} - - # Fallback to API call if metadata is missing - headers = {"Authorization": f"Bearer {creds['access_token']}"} - async with httpx.AsyncClient() as client: - response = await client.get(USER_INFO_URI, headers=headers) - response.raise_for_status() - user_info = response.json() - - # Save the retrieved info for future use - creds["_proxy_metadata"] = { - "email": user_info.get("email"), - "last_check_timestamp": time.time() - } - if path: - await self._save_credentials(path, creds) - return {"email": user_info.get("email")} \ No newline at end of file +from .google_oauth_base import GoogleOAuthBase + +class GeminiAuthBase(GoogleOAuthBase): + """ + Gemini CLI OAuth2 authentication implementation. + + Inherits all OAuth functionality from GoogleOAuthBase with Gemini-specific configuration. 
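    Assuming the shared GoogleOAuthBase keeps the environment-variable loading that was removed
    from this file, stateless deployments can still supply credentials through the ENV_PREFIX-derived
    variables, e.g. GEMINI_CLI_ACCESS_TOKEN and GEMINI_CLI_REFRESH_TOKEN.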
+ """ + + CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + CLIENT_SECRET = "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" + OAUTH_SCOPES = [ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + ] + ENV_PREFIX = "GEMINI_CLI" + CALLBACK_PORT = 8085 + CALLBACK_PATH = "/oauth2callback" \ No newline at end of file diff --git a/src/rotator_library/providers/gemini_cli_provider.py b/src/rotator_library/providers/gemini_cli_provider.py index fe3980f..259fb83 100644 --- a/src/rotator_library/providers/gemini_cli_provider.py +++ b/src/rotator_library/providers/gemini_cli_provider.py @@ -1,5 +1,6 @@ # src/rotator_library/providers/gemini_cli_provider.py +import copy import json import httpx import logging @@ -8,10 +9,11 @@ from typing import List, Dict, Any, AsyncGenerator, Union, Optional, Tuple from .provider_interface import ProviderInterface from .gemini_auth_base import GeminiAuthBase +from .provider_cache import ProviderCache from ..model_definitions import ModelDefinitions import litellm from litellm.exceptions import RateLimitError -from litellm.llms.vertex_ai.common_utils import _build_vertex_schema +from ..error_handler import extract_retry_after_from_body import os from pathlib import Path import uuid @@ -81,9 +83,89 @@ def log_final_response(self, response_data: Dict[str, Any]): HARDCODED_MODELS = [ "gemini-2.5-pro", "gemini-2.5-flash", - "gemini-2.5-flash-lite" + "gemini-2.5-flash-lite", + "gemini-3-pro-preview" ] +# Cache directory for Gemini CLI +CACHE_DIR = Path(__file__).resolve().parent.parent.parent.parent / "cache" / "gemini_cli" +GEMINI3_SIGNATURE_CACHE_FILE = CACHE_DIR / "gemini3_signatures.json" + +# Gemini 3 tool fix system instruction (prevents hallucination) +DEFAULT_GEMINI3_SYSTEM_INSTRUCTION = """ +You are operating in a CUSTOM ENVIRONMENT where tool definitions COMPLETELY DIFFER from your training data. +VIOLATION OF THESE RULES WILL CAUSE IMMEDIATE SYSTEM FAILURE. + +## ABSOLUTE RULES - NO EXCEPTIONS + +1. **SCHEMA IS LAW**: The JSON schema in each tool definition is the ONLY source of truth. + - Your pre-trained knowledge about tools like 'read_file', 'apply_diff', 'write_to_file', 'bash', etc. is INVALID here. + - Every tool has been REDEFINED with different parameters than what you learned during training. + +2. **PARAMETER NAMES ARE EXACT**: Use ONLY the parameter names from the schema. + - WRONG: 'suggested_answers', 'file_path', 'files_to_read', 'command_to_run' + - RIGHT: Check the 'properties' field in the schema for the exact names + - The schema's 'required' array tells you which parameters are mandatory + +3. **ARRAY PARAMETERS**: When a parameter has "type": "array", check the 'items' field: + - If items.type is "object", you MUST provide an array of objects with the EXACT properties listed + - If items.type is "string", you MUST provide an array of strings + - NEVER provide a single object when an array is expected + - NEVER provide an array when a single value is expected + +4. **NESTED OBJECTS**: When items.type is "object": + - Check items.properties for the EXACT field names required + - Check items.required for which nested fields are mandatory + - Include ALL required nested fields in EVERY array element + +5. **STRICT PARAMETERS HINT**: Tool descriptions contain "STRICT PARAMETERS: ..." 
which lists: + - Parameter name, type, and whether REQUIRED + - For arrays of objects: the nested structure in brackets like [field: type REQUIRED, ...] + - USE THIS as your quick reference, but the JSON schema is authoritative + +6. **BEFORE EVERY TOOL CALL**: + a. Read the tool's 'parametersJsonSchema' or 'parameters' field completely + b. Identify ALL required parameters + c. Verify your parameter names match EXACTLY (case-sensitive) + d. For arrays, verify you're providing the correct item structure + e. Do NOT add parameters that don't exist in the schema + +## COMMON FAILURE PATTERNS TO AVOID + +- Using 'path' when schema says 'filePath' (or vice versa) +- Using 'content' when schema says 'text' (or vice versa) +- Providing {"file": "..."} when schema wants [{"path": "...", "line_ranges": [...]}] +- Omitting required nested fields in array items +- Adding 'additionalProperties' that the schema doesn't define +- Guessing parameter names from similar tools you know from training + +## REMEMBER +Your training data about function calling is OUTDATED for this environment. +The tool names may look familiar, but the schemas are DIFFERENT. +When in doubt, RE-READ THE SCHEMA before making the call. + +""" + +# Gemini finish reason mapping +FINISH_REASON_MAP = { + "STOP": "stop", + "MAX_TOKENS": "length", + "SAFETY": "content_filter", + "RECITATION": "content_filter", + "OTHER": "stop", +} + + +def _env_bool(key: str, default: bool = False) -> bool: + """Get boolean from environment variable.""" + return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes") + + +def _env_int(key: str, default: int) -> int: + """Get integer from environment variable.""" + return int(os.getenv(key, str(default))) + + class GeminiCliProvider(GeminiAuthBase, ProviderInterface): skip_cost_calculation = True @@ -92,9 +174,125 @@ def __init__(self): self.model_definitions = ModelDefinitions() self.project_id_cache: Dict[str, str] = {} # Cache project ID per credential path self.project_tier_cache: Dict[str, str] = {} # Cache project tier per credential path + + # Gemini 3 configuration from environment + memory_ttl = _env_int("GEMINI_CLI_SIGNATURE_CACHE_TTL", 3600) + disk_ttl = _env_int("GEMINI_CLI_SIGNATURE_DISK_TTL", 86400) + + # Initialize signature cache for Gemini 3 thoughtSignatures + self._signature_cache = ProviderCache( + GEMINI3_SIGNATURE_CACHE_FILE, memory_ttl, disk_ttl, + env_prefix="GEMINI_CLI_SIGNATURE" + ) + + # Gemini 3 feature flags + self._preserve_signatures_in_client = _env_bool("GEMINI_CLI_PRESERVE_THOUGHT_SIGNATURES", True) + self._enable_signature_cache = _env_bool("GEMINI_CLI_ENABLE_SIGNATURE_CACHE", True) + self._enable_gemini3_tool_fix = _env_bool("GEMINI_CLI_GEMINI3_TOOL_FIX", True) + self._gemini3_enforce_strict_schema = _env_bool("GEMINI_CLI_GEMINI3_STRICT_SCHEMA", True) + + # Gemini 3 tool fix configuration + self._gemini3_tool_prefix = os.getenv("GEMINI_CLI_GEMINI3_TOOL_PREFIX", "gemini3_") + self._gemini3_description_prompt = os.getenv( + "GEMINI_CLI_GEMINI3_DESCRIPTION_PROMPT", + "\n\n⚠️ STRICT PARAMETERS (use EXACTLY as shown): {params}. Do NOT use parameters from your training data - use ONLY these parameter names." 
+ ) + self._gemini3_system_instruction = os.getenv( + "GEMINI_CLI_GEMINI3_SYSTEM_INSTRUCTION", + DEFAULT_GEMINI3_SYSTEM_INSTRUCTION + ) + + lib_logger.debug( + f"GeminiCli config: signatures_in_client={self._preserve_signatures_in_client}, " + f"cache={self._enable_signature_cache}, gemini3_fix={self._enable_gemini3_tool_fix}, " + f"gemini3_strict_schema={self._gemini3_enforce_strict_schema}" + ) + + # ========================================================================= + # CREDENTIAL PRIORITIZATION + # ========================================================================= + + def get_credential_priority(self, credential: str) -> Optional[int]: + """ + Returns priority based on Gemini tier. + Paid tiers: priority 1 (highest) + Free/Legacy tiers: priority 2 + Unknown: priority 10 (lowest) + + Args: + credential: The credential path + + Returns: + Priority level (1-10) or None if tier not yet discovered + """ + tier = self.project_tier_cache.get(credential) + if not tier: + return None # Not yet discovered + + # Paid tiers get highest priority + if tier not in ['free-tier', 'legacy-tier', 'unknown']: + return 1 + + # Free tier gets lower priority + if tier == 'free-tier': + return 2 + + # Legacy and unknown get even lower + return 10 + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """ + Returns the minimum priority tier required for a model. + Gemini 3 requires paid tier (priority 1). + + Args: + model: The model name (with or without provider prefix) + + Returns: + Minimum required priority level or None if no restrictions + """ + model_name = model.split('/')[-1].replace(':thinking', '') + + # Gemini 3 requires paid tier + if model_name.startswith("gemini-3-"): + return 1 # Only priority 1 (paid) credentials + + return None # All other models have no restrictions + + + + # ========================================================================= + # MODEL UTILITIES + # ========================================================================= + + def _is_gemini_3(self, model: str) -> bool: + """Check if model is Gemini 3 (requires special handling).""" + model_name = model.split('/')[-1].replace(':thinking', '') + return model_name.startswith("gemini-3-") + + def _strip_gemini3_prefix(self, name: str) -> str: + """Strip the Gemini 3 namespace prefix from a tool name.""" + if name and name.startswith(self._gemini3_tool_prefix): + return name[len(self._gemini3_tool_prefix):] + return name async def _discover_project_id(self, credential_path: str, access_token: str, litellm_params: Dict[str, Any]) -> str: - """Discovers the Google Cloud Project ID, with caching and onboarding for new accounts.""" + """ + Discovers the Google Cloud Project ID, with caching and onboarding for new accounts. + + This follows the official Gemini CLI discovery flow: + 1. Check in-memory cache + 2. Check configured project_id override (litellm_params or env var) + 3. Check persisted project_id in credential file + 4. Call loadCodeAssist to check if user is already known (has currentTier) + - If currentTier exists AND cloudaicompanionProject returned: use server's project + - If currentTier exists but NO cloudaicompanionProject: use configured project_id (paid tier requires this) + - If no currentTier: user needs onboarding + 5. Onboard user based on tier: + - FREE tier: pass cloudaicompanionProject=None (server-managed) + - PAID tier: pass cloudaicompanionProject=configured_project_id + 6. 
Fallback to GCP Resource Manager project listing + """ lib_logger.debug(f"Starting project discovery for credential: {credential_path}") # Check in-memory cache first @@ -103,34 +301,37 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li lib_logger.debug(f"Using cached project ID: {cached_project}") return cached_project - # Check for configured project ID override - if litellm_params.get("project_id"): - project_id = litellm_params["project_id"] - lib_logger.info(f"Using configured Gemini CLI project ID: {project_id}") - self.project_id_cache[credential_path] = project_id - return project_id - - # [NEW] Load credentials from file to check for persisted project_id and tier - try: - with open(credential_path, 'r') as f: - creds = json.load(f) - - metadata = creds.get("_proxy_metadata", {}) - persisted_project_id = metadata.get("project_id") - persisted_tier = metadata.get("tier") - - if persisted_project_id: - lib_logger.info(f"Loaded persisted project ID from credential file: {persisted_project_id}") - self.project_id_cache[credential_path] = persisted_project_id + # Check for configured project ID override (from litellm_params or env var) + # This is REQUIRED for paid tier users per the official CLI behavior + configured_project_id = litellm_params.get("project_id") + if configured_project_id: + lib_logger.debug(f"Found configured project_id override: {configured_project_id}") + + # Load credentials from file to check for persisted project_id and tier + # Skip for env:// paths (environment-based credentials don't persist to files) + credential_index = self._parse_env_credential_path(credential_path) + if credential_index is None: + # Only try to load from file if it's not an env:// path + try: + with open(credential_path, 'r') as f: + creds = json.load(f) - # Also load tier if available - if persisted_tier: - self.project_tier_cache[credential_path] = persisted_tier - lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") + metadata = creds.get("_proxy_metadata", {}) + persisted_project_id = metadata.get("project_id") + persisted_tier = metadata.get("tier") - return persisted_project_id - except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: - lib_logger.debug(f"Could not load persisted project ID from file: {e}") + if persisted_project_id: + lib_logger.info(f"Loaded persisted project ID from credential file: {persisted_project_id}") + self.project_id_cache[credential_path] = persisted_project_id + + # Also load tier if available + if persisted_tier: + self.project_tier_cache[credential_path] = persisted_tier + lib_logger.debug(f"Loaded persisted tier: {persisted_tier}") + + return persisted_project_id + except (FileNotFoundError, json.JSONDecodeError, KeyError) as e: + lib_logger.debug(f"Could not load persisted project ID from file: {e}") lib_logger.debug("No cached or configured project ID found, initiating discovery...") headers = {'Authorization': f'Bearer {access_token}', 'Content-Type': 'application/json'} @@ -139,64 +340,168 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li discovered_tier = None async with httpx.AsyncClient() as client: - # 1. Try discovery endpoint with onboarding logic + # 1. 
Try discovery endpoint with loadCodeAssist lib_logger.debug("Attempting project discovery via Code Assist loadCodeAssist endpoint...") try: - initial_project_id = "default" - client_metadata = { - "ideType": "IDE_UNSPECIFIED", "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", "duetProject": initial_project_id, + # Build metadata - include duetProject only if we have a configured project + core_client_metadata = { + "ideType": "IDE_UNSPECIFIED", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + } + if configured_project_id: + core_client_metadata["duetProject"] = configured_project_id + + # Build load request - pass configured_project_id if available, otherwise None + load_request = { + "cloudaicompanionProject": configured_project_id, # Can be None + "metadata": core_client_metadata, } - load_request = {"cloudaicompanionProject": initial_project_id, "metadata": client_metadata} + lib_logger.debug(f"Sending loadCodeAssist request with cloudaicompanionProject={configured_project_id}") response = await client.post(f"{CODE_ASSIST_ENDPOINT}:loadCodeAssist", headers=headers, json=load_request, timeout=20) response.raise_for_status() data = response.json() - # Extract tier information for paid project detection - selected_tier_id = None - allowed_tiers = data.get('allowedTiers', []) - lib_logger.debug(f"Available tiers from loadCodeAssist response: {[t.get('id') for t in allowed_tiers]}") + # Log full response for debugging + lib_logger.debug(f"loadCodeAssist full response keys: {list(data.keys())}") + # Extract and log ALL tier information for debugging + allowed_tiers = data.get('allowedTiers', []) + current_tier = data.get('currentTier') + + lib_logger.debug(f"=== Tier Information ===") + lib_logger.debug(f"currentTier: {current_tier}") + lib_logger.debug(f"allowedTiers count: {len(allowed_tiers)}") + for i, tier in enumerate(allowed_tiers): + tier_id = tier.get('id', 'unknown') + is_default = tier.get('isDefault', False) + user_defined = tier.get('userDefinedCloudaicompanionProject', False) + lib_logger.debug(f" Tier {i+1}: id={tier_id}, isDefault={is_default}, userDefinedProject={user_defined}") + lib_logger.debug(f"========================") + + # Determine the current tier ID + current_tier_id = None + if current_tier: + current_tier_id = current_tier.get('id') + lib_logger.debug(f"User has currentTier: {current_tier_id}") + + # Check if user is already known to server (has currentTier) + if current_tier_id: + # User is already onboarded - check for project from server + server_project = data.get('cloudaicompanionProject') + + # Check if this tier requires user-defined project (paid tiers) + requires_user_project = any( + t.get('id') == current_tier_id and t.get('userDefinedCloudaicompanionProject', False) + for t in allowed_tiers + ) + is_free_tier = current_tier_id == 'free-tier' + + if server_project: + # Server returned a project - use it (server wins) + # This is the normal case for FREE tier users + project_id = server_project + lib_logger.debug(f"Server returned project: {project_id}") + elif configured_project_id: + # No server project but we have configured one - use it + # This is the PAID TIER case where server doesn't return a project + project_id = configured_project_id + lib_logger.debug(f"No server project, using configured: {project_id}") + elif is_free_tier: + # Free tier user without server project - this shouldn't happen normally + # but let's not fail, just proceed to onboarding + lib_logger.debug("Free tier user with currentTier but no 
project - will try onboarding") + project_id = None + elif requires_user_project: + # Paid tier requires a project ID to be set + raise ValueError( + f"Paid tier '{current_tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. " + "See https://goo.gle/gemini-cli-auth-docs#workspace-gca" + ) + else: + # Unknown tier without project - proceed carefully + lib_logger.warning(f"Tier '{current_tier_id}' has no project and none configured - will try onboarding") + project_id = None + + if project_id: + # Cache tier info + self.project_tier_cache[credential_path] = current_tier_id + discovered_tier = current_tier_id + + # Log appropriately based on tier + is_paid = current_tier_id and current_tier_id not in ['free-tier', 'legacy-tier', 'unknown'] + if is_paid: + lib_logger.info(f"Using Gemini paid tier '{current_tier_id}' with project: {project_id}") + else: + lib_logger.info(f"Discovered Gemini project ID via loadCodeAssist: {project_id}") + + self.project_id_cache[credential_path] = project_id + discovered_project_id = project_id + + # Persist to credential file + await self._persist_project_metadata(credential_path, project_id, discovered_tier) + + return project_id + + # 2. User needs onboarding - no currentTier + lib_logger.info("No existing Gemini session found (no currentTier), attempting to onboard user...") + + # Determine which tier to onboard with + onboard_tier = None for tier in allowed_tiers: if tier.get('isDefault'): - selected_tier_id = tier.get('id', 'unknown') - lib_logger.debug(f"Selected default tier: {selected_tier_id}") + onboard_tier = tier break - if not selected_tier_id and allowed_tiers: - selected_tier_id = allowed_tiers[0].get('id', 'unknown') - lib_logger.debug(f"No default tier found, using first available: {selected_tier_id}") - - if data.get('cloudaicompanionProject'): - project_id = data['cloudaicompanionProject'] - lib_logger.debug(f"Existing project found in loadCodeAssist response: {project_id}") - - # Cache tier info - if selected_tier_id: - self.project_tier_cache[credential_path] = selected_tier_id - discovered_tier = selected_tier_id - lib_logger.debug(f"Cached tier information: {selected_tier_id}") - - # Log concise message for paid projects - is_paid = selected_tier_id and selected_tier_id not in ['free-tier', 'legacy-tier', 'unknown'] - if is_paid: - lib_logger.info(f"Using Gemini paid project: {project_id}") - else: - lib_logger.info(f"Discovered Gemini project ID via loadCodeAssist: {project_id}") - - self.project_id_cache[credential_path] = project_id - discovered_project_id = project_id - - # [NEW] Persist to credential file - await self._persist_project_metadata(credential_path, project_id, discovered_tier) - - return project_id - # 2. 
If no project ID, trigger onboarding - lib_logger.info("No existing Gemini project found, attempting to onboard user...") - tier_id = next((t.get('id', 'free-tier') for t in data.get('allowedTiers', []) if t.get('isDefault')), 'free-tier') - lib_logger.debug(f"Onboarding with tier: {tier_id}") - onboard_request = {"tierId": tier_id, "cloudaicompanionProject": initial_project_id, "metadata": client_metadata} + # Fallback to LEGACY tier if no default (requires user project) + if not onboard_tier and allowed_tiers: + # Look for legacy-tier as fallback + for tier in allowed_tiers: + if tier.get('id') == 'legacy-tier': + onboard_tier = tier + break + # If still no tier, use first available + if not onboard_tier: + onboard_tier = allowed_tiers[0] + + if not onboard_tier: + raise ValueError("No onboarding tiers available from server") + + tier_id = onboard_tier.get('id', 'free-tier') + requires_user_project = onboard_tier.get('userDefinedCloudaicompanionProject', False) + + lib_logger.debug(f"Onboarding with tier: {tier_id}, requiresUserProject: {requires_user_project}") + + # Build onboard request based on tier type (following official CLI logic) + # FREE tier: cloudaicompanionProject = None (server-managed) + # PAID tier: cloudaicompanionProject = configured_project_id (user must provide) + is_free_tier = tier_id == 'free-tier' + + if is_free_tier: + # Free tier uses server-managed project + onboard_request = { + "tierId": tier_id, + "cloudaicompanionProject": None, # Server will create/manage + "metadata": core_client_metadata, + } + lib_logger.debug("Free tier onboarding: using server-managed project") + else: + # Paid/legacy tier requires user-provided project + if not configured_project_id and requires_user_project: + raise ValueError( + f"Tier '{tier_id}' requires setting GEMINI_CLI_PROJECT_ID environment variable. " + "See https://goo.gle/gemini-cli-auth-docs#workspace-gca" + ) + onboard_request = { + "tierId": tier_id, + "cloudaicompanionProject": configured_project_id, + "metadata": { + **core_client_metadata, + "duetProject": configured_project_id, + } if configured_project_id else core_client_metadata, + } + lib_logger.debug(f"Paid tier onboarding: using project {configured_project_id}") lib_logger.debug("Initiating onboardUser request...") lro_response = await client.post(f"{CODE_ASSIST_ENDPOINT}:onboardUser", headers=headers, json=onboard_request, timeout=30) @@ -204,7 +509,7 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li lro_data = lro_response.json() lib_logger.debug(f"Initial onboarding response: done={lro_data.get('done')}") - for i in range(150): # Poll for up to 5 minutes (150 × 2s) + for i in range(150): # Poll for up to 5 minutes (150 × 2s) if lro_data.get('done'): lib_logger.debug(f"Onboarding completed after {i} polling attempts") break @@ -220,41 +525,62 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li lib_logger.error("Onboarding process timed out after 5 minutes") raise ValueError("Onboarding process timed out after 5 minutes. 
Please try again or contact support.") - project_id = lro_data.get('response', {}).get('cloudaicompanionProject', {}).get('id') + # Extract project ID from LRO response + # Note: onboardUser returns response.cloudaicompanionProject as an object with .id + lro_response_data = lro_data.get('response', {}) + lro_project_obj = lro_response_data.get('cloudaicompanionProject', {}) + project_id = lro_project_obj.get('id') if isinstance(lro_project_obj, dict) else None + + # Fallback to configured project if LRO didn't return one + if not project_id and configured_project_id: + project_id = configured_project_id + lib_logger.debug(f"LRO didn't return project, using configured: {project_id}") + if not project_id: - lib_logger.error("Onboarding completed but no project ID in response") - raise ValueError("Onboarding completed, but no project ID was returned.") + lib_logger.error("Onboarding completed but no project ID in response and none configured") + raise ValueError( + "Onboarding completed, but no project ID was returned. " + "For paid tiers, set GEMINI_CLI_PROJECT_ID environment variable." + ) lib_logger.debug(f"Successfully extracted project ID from onboarding response: {project_id}") # Cache tier info - if tier_id: - self.project_tier_cache[credential_path] = tier_id - discovered_tier = tier_id - lib_logger.debug(f"Cached tier information: {tier_id}") + self.project_tier_cache[credential_path] = tier_id + discovered_tier = tier_id + lib_logger.debug(f"Cached tier information: {tier_id}") # Log concise message for paid projects is_paid = tier_id and tier_id not in ['free-tier', 'legacy-tier'] if is_paid: - lib_logger.info(f"Using Gemini paid project: {project_id}") + lib_logger.info(f"Using Gemini paid tier '{tier_id}' with project: {project_id}") else: lib_logger.info(f"Successfully onboarded user and discovered project ID: {project_id}") self.project_id_cache[credential_path] = project_id discovered_project_id = project_id - # [NEW] Persist to credential file + # Persist to credential file await self._persist_project_metadata(credential_path, project_id, discovered_tier) return project_id except httpx.HTTPStatusError as e: + error_body = "" + try: + error_body = e.response.text + except Exception: + pass if e.response.status_code == 403: - lib_logger.error(f"Gemini Code Assist API access denied (403). The cloudaicompanion.googleapis.com API may not be enabled for your account. Please enable it in Google Cloud Console.") + lib_logger.error(f"Gemini Code Assist API access denied (403). Response: {error_body}") + lib_logger.error("Possible causes: 1) cloudaicompanion.googleapis.com API not enabled, 2) Wrong project ID for paid tier, 3) Account lacks permissions") elif e.response.status_code == 404: lib_logger.warning(f"Gemini Code Assist endpoint not found (404). Falling back to project listing.") + elif e.response.status_code == 412: + # Precondition Failed - often means wrong project for free tier onboarding + lib_logger.error(f"Precondition failed (412): {error_body}. This may mean the project ID is incompatible with the selected tier.") else: - lib_logger.warning(f"Gemini onboarding/discovery failed with status {e.response.status_code}: {e}. Falling back to project listing.") + lib_logger.warning(f"Gemini onboarding/discovery failed with status {e.response.status_code}: {error_body}. Falling back to project listing.") except httpx.RequestError as e: lib_logger.warning(f"Gemini onboarding/discovery network error: {e}. 
Falling back to project listing.") @@ -303,6 +629,12 @@ async def _discover_project_id(self, credential_path: str, access_token: str, li async def _persist_project_metadata(self, credential_path: str, project_id: str, tier: Optional[str]): """Persists project ID and tier to the credential file for faster future startups.""" + # Skip persistence for env:// paths (environment-based credentials) + credential_index = self._parse_env_credential_path(credential_path) + if credential_index is not None: + lib_logger.debug(f"Skipping project metadata persistence for env:// credential path: {credential_path}") + return + try: # Load current credentials with open(credential_path, 'r') as f: @@ -311,13 +643,13 @@ async def _persist_project_metadata(self, credential_path: str, project_id: str, # Update metadata if "_proxy_metadata" not in creds: creds["_proxy_metadata"] = {} - + creds["_proxy_metadata"]["project_id"] = project_id if tier: creds["_proxy_metadata"]["tier"] = tier # Save back using the existing save method (handles atomic writes and permissions) - self._save_credentials(credential_path, creds) + await self._save_credentials(credential_path, creds) lib_logger.debug(f"Persisted project_id and tier to credential file: {credential_path}") except Exception as e: @@ -374,9 +706,20 @@ def _cli_preview_fallback_order(self, model: str) -> List[str]: # Return fallback chain if available, otherwise just return the original model return fallback_chains.get(model_name, [model_name]) - def _transform_messages(self, messages: List[Dict[str, Any]]) -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + def _transform_messages(self, messages: List[Dict[str, Any]], model: str = "") -> Tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Transform OpenAI messages to Gemini CLI format. 
+ + Handles: + - System instruction extraction + - Multi-part content (text, images) + - Tool calls and responses + - Gemini 3 thoughtSignature preservation + """ + messages = copy.deepcopy(messages) # Don't mutate original system_instruction = None gemini_contents = [] + is_gemini_3 = self._is_gemini_3(model) # Separate system prompt from other messages if messages and messages[0].get('role') == 'system': @@ -394,11 +737,21 @@ def _transform_messages(self, messages: List[Dict[str, Any]]) -> Tuple[Optional[ if tool_call.get("type") == "function": tool_call_id_to_name[tool_call["id"]] = tool_call["function"]["name"] + # Process messages and consolidate consecutive tool responses + # Per Gemini docs: parallel function responses must be in a single user message, + # not interleaved as separate messages + pending_tool_parts = [] # Accumulate tool responses + for msg in messages: role = msg.get("role") content = msg.get("content") parts = [] - gemini_role = "model" if role == "assistant" else "tool" if role == "tool" else "user" + gemini_role = "model" if role == "assistant" else "user" # tool -> user in Gemini + + # If we have pending tool parts and hit a non-tool message, flush them first + if pending_tool_parts and role != "tool": + gemini_contents.append({"role": "user", "parts": pending_tool_parts}) + pending_tool_parts = [] if role == "user": if isinstance(content, str): @@ -435,44 +788,121 @@ def _transform_messages(self, messages: List[Dict[str, Any]]) -> Tuple[Optional[ if isinstance(content, str): parts.append({"text": content}) if msg.get("tool_calls"): + # Track if we've seen the first function call in this message + # Per Gemini docs: Only the FIRST parallel function call gets a signature + first_func_in_msg = True for tool_call in msg["tool_calls"]: if tool_call.get("type") == "function": try: args_dict = json.loads(tool_call["function"]["arguments"]) except (json.JSONDecodeError, TypeError): args_dict = {} - parts.append({"functionCall": {"name": tool_call["function"]["name"], "args": args_dict}}) + + tool_id = tool_call.get("id", "") + func_name = tool_call["function"]["name"] + + # Add prefix for Gemini 3 + if is_gemini_3 and self._enable_gemini3_tool_fix: + func_name = f"{self._gemini3_tool_prefix}{func_name}" + + func_part = { + "functionCall": { + "name": func_name, + "args": args_dict, + "id": tool_id + } + } + + # Add thoughtSignature for Gemini 3 + # Per Gemini docs: Only the FIRST parallel function call gets a signature. + # Subsequent parallel calls should NOT have a thoughtSignature field. 
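# Illustrative sketch (assumption, not taken from this change): with two parallel tool calls,
# the emitted parts would look roughly like
#   {"functionCall": {...}, "thoughtSignature": "<signature or bypass token>"}
#   {"functionCall": {...}}   # later parallel calls carry no signature field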
+ if is_gemini_3: + sig = tool_call.get("thought_signature") + if not sig and tool_id and self._enable_signature_cache: + sig = self._signature_cache.retrieve(tool_id) + + if sig: + func_part["thoughtSignature"] = sig + elif first_func_in_msg: + # Only add bypass to the first function call if no sig available + func_part["thoughtSignature"] = "skip_thought_signature_validator" + lib_logger.warning(f"Missing thoughtSignature for first func call {tool_id}, using bypass") + # Subsequent parallel calls: no signature field at all + + first_func_in_msg = False + + parts.append(func_part) elif role == "tool": tool_call_id = msg.get("tool_call_id") function_name = tool_call_id_to_name.get(tool_call_id) if function_name: + # Add prefix for Gemini 3 + if is_gemini_3 and self._enable_gemini3_tool_fix: + function_name = f"{self._gemini3_tool_prefix}{function_name}" + # Wrap the tool response in a 'result' object response_content = {"result": content} - parts.append({"functionResponse": {"name": function_name, "response": response_content}}) + # Accumulate tool responses - they'll be combined into one user message + pending_tool_parts.append({ + "functionResponse": { + "name": function_name, + "response": response_content, + "id": tool_call_id + } + }) + # Don't add parts here - tool responses are handled via pending_tool_parts + continue if parts: gemini_contents.append({"role": gemini_role, "parts": parts}) + # Flush any remaining tool parts at end of messages + if pending_tool_parts: + gemini_contents.append({"role": "user", "parts": pending_tool_parts}) + if not gemini_contents or gemini_contents[0]['role'] != 'user': gemini_contents.insert(0, {"role": "user", "parts": [{"text": ""}]}) return system_instruction, gemini_contents def _handle_reasoning_parameters(self, payload: Dict[str, Any], model: str) -> Optional[Dict[str, Any]]: + """ + Map reasoning_effort to thinking configuration. 
+ + - Gemini 2.5: thinkingBudget (integer tokens) + - Gemini 3: thinkingLevel (string: "low"/"high") + """ custom_reasoning_budget = payload.get("custom_reasoning_budget", False) reasoning_effort = payload.get("reasoning_effort") if "thinkingConfig" in payload.get("generationConfig", {}): return None - # Only apply reasoning logic to the gemini-2.5 model family - if "gemini-2.5" not in model: + is_gemini_25 = "gemini-2.5" in model + is_gemini_3 = self._is_gemini_3(model) + + # Only apply reasoning logic to supported models + if not (is_gemini_25 or is_gemini_3): payload.pop("reasoning_effort", None) payload.pop("custom_reasoning_budget", None) return None + + # Gemini 3: String-based thinkingLevel + if is_gemini_3: + # Clean up the original payload + payload.pop("reasoning_effort", None) + payload.pop("custom_reasoning_budget", None) + + if reasoning_effort == "low": + return {"thinkingLevel": "low", "include_thoughts": True} + return {"thinkingLevel": "high", "include_thoughts": True} + # Gemini 2.5: Integer thinkingBudget if not reasoning_effort: + # Clean up the original payload + payload.pop("reasoning_effort", None) + payload.pop("custom_reasoning_budget", None) return {"thinkingBudget": -1, "include_thoughts": True} # If reasoning_effort is provided, calculate the budget @@ -498,8 +928,15 @@ def _handle_reasoning_parameters(self, payload: Dict[str, Any], model: str) -> O return {"thinkingBudget": budget, "include_thoughts": True} - def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): - lib_logger.debug(f"Converting Gemini chunk: {json.dumps(chunk)}") + def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str, accumulator: Optional[Dict[str, Any]] = None): + """ + Convert Gemini response chunk to OpenAI streaming format. + + Args: + chunk: Gemini API response chunk + model_id: Model name + accumulator: Optional dict to accumulate data for post-processing (signatures, etc.) 
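        Example (illustrative, minimal assumed chunk): a Gemini chunk such as
            {"candidates": [{"content": {"parts": [{"text": "Hi"}]}}]}
        converts to an OpenAI-style chunk whose choices[0].delta is {"content": "Hi"};
        a part flagged with "thought": true populates delta["reasoning_content"] instead.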
+ """ response_data = chunk.get('response', chunk) candidates = response_data.get('candidates', []) if not candidates: @@ -507,29 +944,65 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): candidate = candidates[0] parts = candidate.get('content', {}).get('parts', []) + is_gemini_3 = self._is_gemini_3(model_id) for part in parts: delta = {} - finish_reason = None + + has_func = 'functionCall' in part + has_text = 'text' in part + has_sig = bool(part.get('thoughtSignature')) + is_thought = part.get('thought') is True or (isinstance(part.get('thought'), str) and str(part.get('thought')).lower() == 'true') + + # Skip standalone signature parts (no function, no meaningful text) + if has_sig and not has_func and (not has_text or not part.get('text')): + continue - if 'functionCall' in part: + if has_func: function_call = part['functionCall'] function_name = function_call.get('name', 'unknown') - # Generate unique ID with nanosecond precision - tool_call_id = f"call_{function_name}_{int(time.time() * 1_000_000_000)}" - delta['tool_calls'] = [{ - "index": 0, + + # Strip Gemini 3 prefix from tool name + if is_gemini_3 and self._enable_gemini3_tool_fix: + function_name = self._strip_gemini3_prefix(function_name) + + # Use provided ID or generate unique one with nanosecond precision + tool_call_id = function_call.get('id') or f"call_{function_name}_{int(time.time() * 1_000_000_000)}" + + # Get current tool index from accumulator (default 0) and increment + current_tool_idx = accumulator.get('tool_idx', 0) if accumulator else 0 + + tool_call = { + "index": current_tool_idx, "id": tool_call_id, "type": "function", "function": { "name": function_name, "arguments": json.dumps(function_call.get('args', {})) } - }] - elif 'text' in part: + } + + # Handle thoughtSignature for Gemini 3 + # Store signature for each tool call (needed for parallel tool calls) + if is_gemini_3 and has_sig: + sig = part['thoughtSignature'] + + if self._enable_signature_cache: + self._signature_cache.store(tool_call_id, sig) + lib_logger.debug(f"Stored signature for {tool_call_id}") + + if self._preserve_signatures_in_client: + tool_call["thought_signature"] = sig + + delta['tool_calls'] = [tool_call] + # Mark that we've sent tool calls and increment tool_idx + if accumulator is not None: + accumulator['has_tool_calls'] = True + accumulator['tool_idx'] = current_tool_idx + 1 + + elif has_text: # Use an explicit check for the 'thought' flag, as its type can be inconsistent - thought = part.get('thought') - if thought is True or (isinstance(thought, str) and thought.lower() == 'true'): + if is_thought: delta['reasoning_content'] = part['text'] else: delta['content'] = part['text'] @@ -537,16 +1010,20 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): if not delta: continue - raw_finish_reason = candidate.get('finishReason') - if raw_finish_reason: - mapping = {'STOP': 'stop', 'MAX_TOKENS': 'length', 'SAFETY': 'content_filter'} - finish_reason = mapping.get(raw_finish_reason, 'stop') + # Mark that we have tool calls for accumulator tracking + # finish_reason determination is handled by the client + + # Mark stream complete if we have usageMetadata + is_final_chunk = 'usageMetadata' in response_data + if is_final_chunk and accumulator is not None: + accumulator['is_complete'] = True - choice = {"index": 0, "delta": delta, "finish_reason": finish_reason} + # Build choice - don't include finish_reason, let client handle it + choice = {"index": 0, "delta": delta} openai_chunk = { 
"choices": [choice], "model": model_id, "object": "chat.completion.chunk", - "id": f"chatcmpl-geminicli-{time.time()}", "created": int(time.time()) + "id": chunk.get("responseId", f"chatcmpl-geminicli-{time.time()}"), "created": int(time.time()) } if 'usageMetadata' in response_data: @@ -572,7 +1049,11 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse: """ Manually reassembles streaming chunks into a complete response. - This replaces the non-existent litellm.utils.stream_to_completion_response function. + + Key improvements: + - Determines finish_reason based on accumulated state + - Priority: tool_calls > chunk's finish_reason (length, content_filter, etc.) > stop + - Properly initializes tool_calls with type field """ if not chunks: raise ValueError("No chunks provided for reassembly") @@ -581,7 +1062,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> final_message = {"role": "assistant"} aggregated_tool_calls = {} usage_data = None - finish_reason = None + chunk_finish_reason = None # Track finish_reason from chunks # Get the first chunk for basic response metadata first_chunk = chunks[0] @@ -609,11 +1090,13 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> # Aggregate tool calls if "tool_calls" in delta and delta["tool_calls"]: for tc_chunk in delta["tool_calls"]: - index = tc_chunk["index"] + index = tc_chunk.get("index", 0) if index not in aggregated_tool_calls: aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}} if "id" in tc_chunk: aggregated_tool_calls[index]["id"] = tc_chunk["id"] + if "type" in tc_chunk: + aggregated_tool_calls[index]["type"] = tc_chunk["type"] if "function" in tc_chunk: if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None: aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"] @@ -629,9 +1112,9 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None: final_message["function_call"]["arguments"] += delta["function_call"]["arguments"] - # Get finish reason from the last chunk that has it + # Track finish_reason from chunks (respects length, content_filter, etc.) if choice.get("finish_reason"): - finish_reason = choice["finish_reason"] + chunk_finish_reason = choice["finish_reason"] # Handle usage data from the last chunk that has it for chunk in reversed(chunks): @@ -648,6 +1131,15 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if field not in final_message: final_message[field] = None + # Determine finish_reason based on accumulated state + # Priority: tool_calls wins if present, then chunk's finish_reason (length, content_filter, etc.), then default to "stop" + if aggregated_tool_calls: + finish_reason = "tool_calls" + elif chunk_finish_reason: + finish_reason = chunk_finish_reason + else: + finish_reason = "stop" + # Construct the final response final_choice = { "index": 0, @@ -704,12 +1196,44 @@ def _gemini_cli_transform_schema(self, schema: Dict[str, Any]) -> Dict[str, Any] return schema - def _transform_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _enforce_strict_schema(self, schema: Any) -> Any: + """ + Enforce strict JSON schema for Gemini 3 to prevent hallucinated parameters. 
+ + Adds 'additionalProperties: false' recursively to all object schemas, + which tells the model it CANNOT add properties not in the schema. + """ + if not isinstance(schema, dict): + return schema + + result = {} + for key, value in schema.items(): + if isinstance(value, dict): + result[key] = self._enforce_strict_schema(value) + elif isinstance(value, list): + result[key] = [self._enforce_strict_schema(item) if isinstance(item, dict) else item for item in value] + else: + result[key] = value + + # Add additionalProperties: false to object schemas + if result.get("type") == "object" and "properties" in result: + result["additionalProperties"] = False + + return result + + def _transform_tool_schemas(self, tools: List[Dict[str, Any]], model: str = "") -> List[Dict[str, Any]]: """ Transforms a list of OpenAI-style tool schemas into the format required by the Gemini CLI API. This uses a custom schema transformer instead of litellm's generic one. + + For Gemini 3 models, also applies: + - Namespace prefix to tool names + - Parameter signature injection into descriptions + - Strict schema enforcement (additionalProperties: false) """ transformed_declarations = [] + is_gemini_3 = self._is_gemini_3(model) + for tool in tools: if tool.get("type") == "function" and "function" in tool: new_function = json.loads(json.dumps(tool["function"])) @@ -726,19 +1250,139 @@ def _transform_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, # Set default empty schema if neither exists new_function["parametersJsonSchema"] = {"type": "object", "properties": {}} + # Gemini 3 specific transformations + if is_gemini_3 and self._enable_gemini3_tool_fix: + # Add namespace prefix to tool names + name = new_function.get("name", "") + if name: + new_function["name"] = f"{self._gemini3_tool_prefix}{name}" + + # Enforce strict schema (additionalProperties: false) + if self._gemini3_enforce_strict_schema and "parametersJsonSchema" in new_function: + new_function["parametersJsonSchema"] = self._enforce_strict_schema(new_function["parametersJsonSchema"]) + + # Inject parameter signature into description + new_function = self._inject_signature_into_description(new_function) + transformed_declarations.append(new_function) return transformed_declarations - def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]: + def _inject_signature_into_description(self, func_decl: Dict[str, Any]) -> Dict[str, Any]: + """Inject parameter signatures into tool description for Gemini 3.""" + schema = func_decl.get("parametersJsonSchema", {}) + if not schema: + return func_decl + + required = schema.get("required", []) + properties = schema.get("properties", {}) + + if not properties: + return func_decl + + param_list = [] + for prop_name, prop_data in properties.items(): + if not isinstance(prop_data, dict): + continue + + type_hint = self._format_type_hint(prop_data) + is_required = prop_name in required + param_list.append( + f"{prop_name} ({type_hint}{', REQUIRED' if is_required else ''})" + ) + + if param_list: + sig_str = self._gemini3_description_prompt.replace( + "{params}", ", ".join(param_list) + ) + func_decl["description"] = func_decl.get("description", "") + sig_str + + return func_decl + + def _format_type_hint(self, prop_data: Dict[str, Any], depth: int = 0) -> str: + """Format a detailed type hint for a property schema.""" + type_hint = prop_data.get("type", "unknown") + + # Handle enum values - show allowed options + if "enum" in prop_data: + enum_vals = 
prop_data["enum"] + if len(enum_vals) <= 5: + return f"string ENUM[{', '.join(repr(v) for v in enum_vals)}]" + return f"string ENUM[{len(enum_vals)} options]" + + # Handle const values + if "const" in prop_data: + return f"string CONST={repr(prop_data['const'])}" + + if type_hint == "array": + items = prop_data.get("items", {}) + if isinstance(items, dict): + item_type = items.get("type", "unknown") + if item_type == "object": + nested_props = items.get("properties", {}) + nested_req = items.get("required", []) + if nested_props: + nested_list = [] + for n, d in nested_props.items(): + if isinstance(d, dict): + # Recursively format nested types (limit depth) + if depth < 1: + t = self._format_type_hint(d, depth + 1) + else: + t = d.get("type", "unknown") + req = " REQUIRED" if n in nested_req else "" + nested_list.append(f"{n}: {t}{req}") + return f"ARRAY_OF_OBJECTS[{', '.join(nested_list)}]" + return "ARRAY_OF_OBJECTS" + return f"ARRAY_OF_{item_type.upper()}" + return "ARRAY" + + if type_hint == "object": + nested_props = prop_data.get("properties", {}) + nested_req = prop_data.get("required", []) + if nested_props and depth < 1: + nested_list = [] + for n, d in nested_props.items(): + if isinstance(d, dict): + t = d.get("type", "unknown") + req = " REQUIRED" if n in nested_req else "" + nested_list.append(f"{n}: {t}{req}") + return f"object{{{', '.join(nested_list)}}}" + + return type_hint + + def _inject_gemini3_system_instruction(self, request_payload: Dict[str, Any]) -> None: + """Inject Gemini 3 tool fix system instruction if tools are present.""" + if not request_payload.get("request", {}).get("tools"): + return + + existing_system = request_payload.get("request", {}).get("systemInstruction") + + if existing_system: + # Prepend to existing system instruction + existing_parts = existing_system.get("parts", []) + if existing_parts and existing_parts[0].get("text"): + existing_parts[0]["text"] = self._gemini3_system_instruction + "\n\n" + existing_parts[0]["text"] + else: + existing_parts.insert(0, {"text": self._gemini3_system_instruction}) + else: + # Create new system instruction + request_payload["request"]["systemInstruction"] = { + "role": "user", + "parts": [{"text": self._gemini3_system_instruction}] + } + + def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]], model: str = "") -> Optional[Dict[str, Any]]: """ Translates OpenAI's `tool_choice` to Gemini's `toolConfig`. + Handles Gemini 3 namespace prefixes for specific tool selection. 
""" if not tool_choice: return None config = {} mode = "AUTO" # Default to auto + is_gemini_3 = self._is_gemini_3(model) if isinstance(tool_choice, str): if tool_choice == "auto": @@ -750,6 +1394,10 @@ def _translate_tool_choice(self, tool_choice: Union[str, Dict[str, Any]]) -> Opt elif isinstance(tool_choice, dict) and tool_choice.get("type") == "function": function_name = tool_choice.get("function", {}).get("name") if function_name: + # Add Gemini 3 prefix if needed + if is_gemini_3 and self._enable_gemini3_tool_fix: + function_name = f"{self._gemini3_tool_prefix}{function_name}" + mode = "ANY" # Force a call, but only to this function config["functionCallingConfig"] = { "mode": mode, @@ -778,6 +1426,11 @@ async def do_call(attempt_model: str, is_fallback: bool = False): access_token = auth_header['Authorization'].split(' ')[1] project_id = await self._discover_project_id(credential_path, access_token, kwargs.get("litellm_params", {})) + # Log paid tier usage visibly on each request + credential_tier = self.project_tier_cache.get(credential_path) + if credential_tier and credential_tier not in ['free-tier', 'legacy-tier', 'unknown']: + lib_logger.info(f"[PAID TIER] Using Gemini '{credential_tier}' subscription for this request") + # Handle :thinking suffix model_name = attempt_model.split('/')[-1].replace(':thinking', '') @@ -786,6 +1439,8 @@ async def do_call(attempt_model: str, is_fallback: bool = False): model_name=model_name, enabled=enable_request_logging ) + + is_gemini_3 = self._is_gemini_3(model_name) gen_config = { "maxOutputTokens": kwargs.get("max_tokens", 64000), # Increased default @@ -801,7 +1456,7 @@ async def do_call(attempt_model: str, is_fallback: bool = False): if thinking_config: gen_config["thinkingConfig"] = thinking_config - system_instruction, contents = self._transform_messages(kwargs.get("messages", [])) + system_instruction, contents = self._transform_messages(kwargs.get("messages", []), model_name) request_payload = { "model": model_name, "project": project_id, @@ -815,15 +1470,19 @@ async def do_call(attempt_model: str, is_fallback: bool = False): request_payload["request"]["systemInstruction"] = system_instruction if "tools" in kwargs and kwargs["tools"]: - function_declarations = self._transform_tool_schemas(kwargs["tools"]) + function_declarations = self._transform_tool_schemas(kwargs["tools"], model_name) if function_declarations: request_payload["request"]["tools"] = [{"functionDeclarations": function_declarations}] # [NEW] Handle tool_choice translation if "tool_choice" in kwargs and kwargs["tool_choice"]: - tool_config = self._translate_tool_choice(kwargs["tool_choice"]) + tool_config = self._translate_tool_choice(kwargs["tool_choice"], model_name) if tool_config: request_payload["request"]["toolConfig"] = tool_config + + # Inject Gemini 3 system instruction if using tools + if is_gemini_3 and self._enable_gemini3_tool_fix: + self._inject_gemini3_system_instruction(request_payload) # Add default safety settings to prevent content filtering if "safetySettings" not in request_payload["request"]: @@ -842,6 +1501,9 @@ async def do_call(attempt_model: str, is_fallback: bool = False): url = f"{CODE_ASSIST_ENDPOINT}:streamGenerateContent" async def stream_handler(): + # Track state across chunks for tool indexing + accumulator = {"has_tool_calls": False, "tool_idx": 0, "is_complete": False} + final_headers = auth_header.copy() final_headers.update({ "User-Agent": "google-api-nodejs-client/9.15.1", @@ -851,6 +1513,15 @@ async def stream_handler(): }) try: 
async with client.stream("POST", url, headers=final_headers, json=request_payload, params={"alt": "sse"}, timeout=600) as response: + # Read and log error body before raise_for_status for better debugging + if response.status_code >= 400: + try: + error_body = await response.aread() + lib_logger.error(f"Gemini CLI API error {response.status_code}: {error_body.decode()}") + file_logger.log_error(f"API error {response.status_code}: {error_body.decode()}") + except Exception: + pass + # This will raise an HTTPStatusError for 4xx/5xx responses response.raise_for_status() @@ -861,10 +1532,24 @@ async def stream_handler(): if data_str == "[DONE]": break try: chunk = json.loads(data_str) - for openai_chunk in self._convert_chunk_to_openai(chunk, model): + for openai_chunk in self._convert_chunk_to_openai(chunk, model, accumulator): yield litellm.ModelResponse(**openai_chunk) except json.JSONDecodeError: lib_logger.warning(f"Could not decode JSON from Gemini CLI: {line}") + + # Emit final chunk if stream ended without usageMetadata + # Client will determine the correct finish_reason + if not accumulator.get("is_complete"): + final_chunk = { + "id": f"chatcmpl-geminicli-{time.time()}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [{"index": 0, "delta": {}, "finish_reason": None}], + # Include minimal usage to signal this is the final chunk + "usage": {"prompt_tokens": 0, "completion_tokens": 1, "total_tokens": 1} + } + yield litellm.ModelResponse(**final_chunk) except httpx.HTTPStatusError as e: error_body = None @@ -873,16 +1558,24 @@ async def stream_handler(): error_body = e.response.text except Exception: pass - log_line = f"Stream handler HTTPStatusError: {str(e)}" + + # Only log to file logger (for detailed logging) if error_body: - log_line = f"{log_line} | response_body={error_body}" - file_logger.log_error(log_line) + file_logger.log_error(f"HTTPStatusError {e.response.status_code}: {error_body}") + else: + file_logger.log_error(f"HTTPStatusError {e.response.status_code}: {str(e)}") + if e.response.status_code == 429: - # Pass the raw response object to the exception. Do not read the - # response body here as it will close the stream and cause a - # 'StreamClosed' error in the client's stream reader. 
+ # Extract retry-after time from the error body + retry_after = extract_retry_after_from_body(error_body) + retry_info = f" (retry after {retry_after}s)" if retry_after else "" + error_msg = f"Gemini CLI rate limit exceeded{retry_info}" + if error_body: + error_msg = f"{error_msg} | {error_body}" + # Only log at debug level - rotation happens silently + lib_logger.debug(f"Gemini CLI 429 rate limit: retry_after={retry_after}s") raise RateLimitError( - message=f"Gemini CLI rate limit exceeded: {e.request.url}", + message=error_msg, llm_provider="gemini_cli", model=model, response=e.response @@ -919,7 +1612,8 @@ async def logging_stream_wrapper(): for idx, attempt_model in enumerate(fallback_models): is_fallback = idx > 0 if is_fallback: - lib_logger.info(f"Gemini CLI rate limited, retrying with fallback model: {attempt_model}") + # Silent rotation - only log at debug level + lib_logger.debug(f"Rate limited on previous model, trying fallback: {attempt_model}") elif has_fallbacks: lib_logger.debug(f"Attempting primary model: {attempt_model} (with {len(fallback_models)-1} fallback(s) available)") else: @@ -941,8 +1635,8 @@ async def logging_stream_wrapper(): if idx + 1 < len(fallback_models): lib_logger.debug(f"Rate limit hit on {attempt_model}, trying next fallback...") continue - # If this was the last fallback option, raise the error - lib_logger.error(f"Rate limit hit on all fallback models (tried {len(fallback_models)} models)") + # If this was the last fallback option, log error and raise + lib_logger.warning(f"Rate limit exhausted on all fallback models (tried {len(fallback_models)} models)") raise # Should not reach here, but raise last error if we do @@ -990,8 +1684,6 @@ async def count_tokens( # Build request payload request_payload = { - "model": model_name, - "project": project_id, "request": { "contents": contents, }, diff --git a/src/rotator_library/providers/google_oauth_base.py b/src/rotator_library/providers/google_oauth_base.py new file mode 100644 index 0000000..3f1ed9d --- /dev/null +++ b/src/rotator_library/providers/google_oauth_base.py @@ -0,0 +1,706 @@ +# src/rotator_library/providers/google_oauth_base.py + +import os +import webbrowser +from typing import Union, Optional +import json +import time +import asyncio +import logging +from pathlib import Path +from typing import Dict, Any +import tempfile +import shutil + +import httpx +from rich.console import Console +from rich.panel import Panel +from rich.text import Text + +from ..utils.headless_detection import is_headless_environment + +lib_logger = logging.getLogger('rotator_library') + +console = Console() + +class GoogleOAuthBase: + """ + Base class for Google OAuth2 authentication providers. 
+ + Subclasses must override: + - CLIENT_ID: OAuth client ID + - CLIENT_SECRET: OAuth client secret + - OAUTH_SCOPES: List of OAuth scopes + - ENV_PREFIX: Prefix for environment variables (e.g., "GEMINI_CLI", "ANTIGRAVITY") + + Subclasses may optionally override: + - CALLBACK_PORT: Local OAuth callback server port (default: 8085) + - CALLBACK_PATH: OAuth callback path (default: "/oauth2callback") + - REFRESH_EXPIRY_BUFFER_SECONDS: Time buffer before token expiry (default: 30 minutes) + """ + + # Subclasses MUST override these + CLIENT_ID: str = None + CLIENT_SECRET: str = None + OAUTH_SCOPES: list = None + ENV_PREFIX: str = None + + # Subclasses MAY override these + TOKEN_URI: str = "https://oauth2.googleapis.com/token" + USER_INFO_URI: str = "https://www.googleapis.com/oauth2/v1/userinfo" + CALLBACK_PORT: int = 8085 + CALLBACK_PATH: str = "/oauth2callback" + REFRESH_EXPIRY_BUFFER_SECONDS: int = 30 * 60 # 30 minutes + + def __init__(self): + # Validate that subclass has set required attributes + if self.CLIENT_ID is None: + raise NotImplementedError(f"{self.__class__.__name__} must set CLIENT_ID") + if self.CLIENT_SECRET is None: + raise NotImplementedError(f"{self.__class__.__name__} must set CLIENT_SECRET") + if self.OAUTH_SCOPES is None: + raise NotImplementedError(f"{self.__class__.__name__} must set OAUTH_SCOPES") + if self.ENV_PREFIX is None: + raise NotImplementedError(f"{self.__class__.__name__} must set ENV_PREFIX") + + self._credentials_cache: Dict[str, Dict[str, Any]] = {} + self._refresh_locks: Dict[str, asyncio.Lock] = {} + self._locks_lock = asyncio.Lock() # Protects the locks dict from race conditions + # [BACKOFF TRACKING] Track consecutive failures per credential + self._refresh_failures: Dict[str, int] = {} # Track consecutive failures per credential + self._next_refresh_after: Dict[str, float] = {} # Track backoff timers (Unix timestamp) + + # [QUEUE SYSTEM] Sequential refresh processing + self._refresh_queue: asyncio.Queue = asyncio.Queue() + self._queued_credentials: set = set() # Track credentials already in queue + self._unavailable_credentials: set = set() # Mark credentials unavailable during re-auth + self._queue_tracking_lock = asyncio.Lock() # Protects queue sets + self._queue_processor_task: Optional[asyncio.Task] = None # Background worker task + + def _parse_env_credential_path(self, path: str) -> Optional[str]: + """ + Parse a virtual env:// path and return the credential index. + + Supported formats: + - "env://provider/0" - Legacy single credential (no index in env var names) + - "env://provider/1" - First numbered credential (PROVIDER_1_ACCESS_TOKEN) + - "env://provider/2" - Second numbered credential, etc. + + Returns: + The credential index as string ("0" for legacy, "1", "2", etc. for numbered) + or None if path is not an env:// path + """ + if not path.startswith("env://"): + return None + + # Parse: env://provider/index + parts = path[6:].split("/") # Remove "env://" prefix + if len(parts) >= 2: + return parts[1] # Return the index + return "0" # Default to legacy format + + def _load_from_env(self, credential_index: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Load OAuth credentials from environment variables for stateless deployments. + + Supports two formats: + 1. Legacy (credential_index="0" or None): PROVIDER_ACCESS_TOKEN + 2. 
Numbered (credential_index="1", "2", etc.): PROVIDER_1_ACCESS_TOKEN, PROVIDER_2_ACCESS_TOKEN + + Expected environment variables (for numbered format with index N): + - {ENV_PREFIX}_{N}_ACCESS_TOKEN (required) + - {ENV_PREFIX}_{N}_REFRESH_TOKEN (required) + - {ENV_PREFIX}_{N}_EXPIRY_DATE (optional, defaults to 0) + - {ENV_PREFIX}_{N}_CLIENT_ID (optional, uses default) + - {ENV_PREFIX}_{N}_CLIENT_SECRET (optional, uses default) + - {ENV_PREFIX}_{N}_TOKEN_URI (optional, uses default) + - {ENV_PREFIX}_{N}_UNIVERSE_DOMAIN (optional, defaults to googleapis.com) + - {ENV_PREFIX}_{N}_EMAIL (optional, defaults to "env-user-{N}") + - {ENV_PREFIX}_{N}_PROJECT_ID (optional) + - {ENV_PREFIX}_{N}_TIER (optional) + + For legacy format (index="0" or None), omit the _{N}_ part. + + Returns: + Dict with credential structure if env vars present, None otherwise + """ + # Determine the env var prefix based on credential index + if credential_index and credential_index != "0": + # Numbered format: PROVIDER_N_ACCESS_TOKEN + prefix = f"{self.ENV_PREFIX}_{credential_index}" + default_email = f"env-user-{credential_index}" + else: + # Legacy format: PROVIDER_ACCESS_TOKEN + prefix = self.ENV_PREFIX + default_email = "env-user" + + access_token = os.getenv(f"{prefix}_ACCESS_TOKEN") + refresh_token = os.getenv(f"{prefix}_REFRESH_TOKEN") + + # Both access and refresh tokens are required + if not (access_token and refresh_token): + return None + + lib_logger.debug(f"Loading {prefix} credentials from environment variables") + + # Parse expiry_date as float, default to 0 if not present + expiry_str = os.getenv(f"{prefix}_EXPIRY_DATE", "0") + try: + expiry_date = float(expiry_str) + except ValueError: + lib_logger.warning(f"Invalid {prefix}_EXPIRY_DATE value: {expiry_str}, using 0") + expiry_date = 0 + + creds = { + "access_token": access_token, + "refresh_token": refresh_token, + "expiry_date": expiry_date, + "client_id": os.getenv(f"{prefix}_CLIENT_ID", self.CLIENT_ID), + "client_secret": os.getenv(f"{prefix}_CLIENT_SECRET", self.CLIENT_SECRET), + "token_uri": os.getenv(f"{prefix}_TOKEN_URI", self.TOKEN_URI), + "universe_domain": os.getenv(f"{prefix}_UNIVERSE_DOMAIN", "googleapis.com"), + "_proxy_metadata": { + "email": os.getenv(f"{prefix}_EMAIL", default_email), + "last_check_timestamp": time.time(), + "loaded_from_env": True, # Flag to indicate env-based credentials + "env_credential_index": credential_index or "0" # Track which env credential this is + } + } + + # Add project_id if provided + project_id = os.getenv(f"{prefix}_PROJECT_ID") + if project_id: + creds["_proxy_metadata"]["project_id"] = project_id + + # Add tier if provided + tier = os.getenv(f"{prefix}_TIER") + if tier: + creds["_proxy_metadata"]["tier"] = tier + + return creds + + async def _load_credentials(self, path: str) -> Dict[str, Any]: + if path in self._credentials_cache: + return self._credentials_cache[path] + + async with await self._get_lock(path): + if path in self._credentials_cache: + return self._credentials_cache[path] + + # Check if this is a virtual env:// path + credential_index = self._parse_env_credential_path(path) + if credential_index is not None: + # Load from environment variables with specific index + env_creds = self._load_from_env(credential_index) + if env_creds: + lib_logger.info(f"Using {self.ENV_PREFIX} credentials from environment variables (index: {credential_index})") + self._credentials_cache[path] = env_creds + return env_creds + else: + raise IOError(f"Environment variables for {self.ENV_PREFIX} credential index 
{credential_index} not found") + + # For file paths, first try loading from legacy env vars (for backwards compatibility) + env_creds = self._load_from_env() + if env_creds: + lib_logger.info(f"Using {self.ENV_PREFIX} credentials from environment variables") + # Cache env-based credentials using the path as key + self._credentials_cache[path] = env_creds + return env_creds + + # Fall back to file-based loading + try: + lib_logger.debug(f"Loading {self.ENV_PREFIX} credentials from file: {path}") + with open(path, 'r') as f: + creds = json.load(f) + # Handle gcloud-style creds file which nest tokens under "credential" + if "credential" in creds: + creds = creds["credential"] + self._credentials_cache[path] = creds + return creds + except FileNotFoundError: + raise IOError(f"{self.ENV_PREFIX} OAuth credential file not found at '{path}'") + except Exception as e: + raise IOError(f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}") + except Exception as e: + raise IOError(f"Failed to load {self.ENV_PREFIX} OAuth credentials from '{path}': {e}") + + async def _save_credentials(self, path: str, creds: Dict[str, Any]): + # Don't save to file if credentials were loaded from environment + if creds.get("_proxy_metadata", {}).get("loaded_from_env"): + lib_logger.debug("Credentials loaded from env, skipping file save") + # Still update cache for in-memory consistency + self._credentials_cache[path] = creds + return + + # [ATOMIC WRITE] Use tempfile + move pattern to ensure atomic writes + # This prevents credential corruption if the process is interrupted during write + parent_dir = os.path.dirname(os.path.abspath(path)) + os.makedirs(parent_dir, exist_ok=True) + + tmp_fd = None + tmp_path = None + try: + # Create temp file in same directory as target (ensures same filesystem) + tmp_fd, tmp_path = tempfile.mkstemp(dir=parent_dir, prefix='.tmp_', suffix='.json', text=True) + + # Write JSON to temp file + with os.fdopen(tmp_fd, 'w') as f: + json.dump(creds, f, indent=2) + tmp_fd = None # fdopen closes the fd + + # Set secure permissions (0600 = owner read/write only) + try: + os.chmod(tmp_path, 0o600) + except (OSError, AttributeError): + # Windows may not support chmod, ignore + pass + + # Atomic move (overwrites target if it exists) + shutil.move(tmp_path, path) + tmp_path = None # Successfully moved + + # Update cache AFTER successful file write (prevents cache/file inconsistency) + self._credentials_cache[path] = creds + lib_logger.debug(f"Saved updated {self.ENV_PREFIX} OAuth credentials to '{path}' (atomic write).") + + except Exception as e: + lib_logger.error(f"Failed to save updated {self.ENV_PREFIX} OAuth credentials to '{path}': {e}") + # Clean up temp file if it still exists + if tmp_fd is not None: + try: + os.close(tmp_fd) + except: + pass + if tmp_path and os.path.exists(tmp_path): + try: + os.unlink(tmp_path) + except: + pass + raise + + def _is_token_expired(self, creds: Dict[str, Any]) -> bool: + expiry = creds.get("token_expiry") # gcloud format + if not expiry: # gemini-cli format + expiry_timestamp = creds.get("expiry_date", 0) / 1000 + else: + expiry_timestamp = time.mktime(time.strptime(expiry, "%Y-%m-%dT%H:%M:%SZ")) + return expiry_timestamp < time.time() + self.REFRESH_EXPIRY_BUFFER_SECONDS + + async def _refresh_token(self, path: str, creds: Dict[str, Any], force: bool = False) -> Dict[str, Any]: + async with await self._get_lock(path): + # Skip the expiry check if a refresh is being forced + if not force and not 
self._is_token_expired(self._credentials_cache.get(path, creds)): + return self._credentials_cache.get(path, creds) + + lib_logger.debug(f"Refreshing {self.ENV_PREFIX} OAuth token for '{Path(path).name}' (forced: {force})...") + refresh_token = creds.get("refresh_token") + if not refresh_token: + raise ValueError("No refresh_token found in credentials file.") + + # [RETRY LOGIC] Implement exponential backoff for transient errors + max_retries = 3 + new_token_data = None + last_error = None + needs_reauth = False + + async with httpx.AsyncClient() as client: + for attempt in range(max_retries): + try: + response = await client.post(self.TOKEN_URI, data={ + "client_id": creds.get("client_id", self.CLIENT_ID), + "client_secret": creds.get("client_secret", self.CLIENT_SECRET), + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, timeout=30.0) + response.raise_for_status() + new_token_data = response.json() + break # Success, exit retry loop + + except httpx.HTTPStatusError as e: + last_error = e + status_code = e.response.status_code + + # [INVALID GRANT HANDLING] Handle 401/403 by triggering re-authentication + if status_code == 401 or status_code == 403: + lib_logger.warning( + f"Refresh token invalid for '{Path(path).name}' (HTTP {status_code}). " + f"Token may have been revoked or expired. Starting re-authentication..." + ) + needs_reauth = True + break # Exit retry loop to trigger re-auth + + elif status_code == 429: + # Rate limit - honor Retry-After header if present + retry_after = int(e.response.headers.get("Retry-After", 60)) + lib_logger.warning(f"Rate limited (HTTP 429), retry after {retry_after}s") + if attempt < max_retries - 1: + await asyncio.sleep(retry_after) + continue + raise + + elif status_code >= 500 and status_code < 600: + # Server error - retry with exponential backoff + if attempt < max_retries - 1: + wait_time = 2 ** attempt # 1s, 2s, 4s + lib_logger.warning(f"Server error (HTTP {status_code}), retry {attempt + 1}/{max_retries} in {wait_time}s") + await asyncio.sleep(wait_time) + continue + raise # Final attempt failed + + else: + # Other errors - don't retry + raise + + except (httpx.RequestError, httpx.TimeoutException) as e: + # Network errors - retry with backoff + last_error = e + if attempt < max_retries - 1: + wait_time = 2 ** attempt + lib_logger.warning(f"Network error during refresh: {e}, retry {attempt + 1}/{max_retries} in {wait_time}s") + await asyncio.sleep(wait_time) + continue + raise + + # [INVALID GRANT RE-AUTH] Trigger OAuth flow if refresh token is invalid + if needs_reauth: + lib_logger.info(f"Starting re-authentication for '{Path(path).name}'...") + try: + # Call initialize_token to trigger OAuth flow + new_creds = await self.initialize_token(path) + return new_creds + except Exception as reauth_error: + lib_logger.error(f"Re-authentication failed for '{Path(path).name}': {reauth_error}") + raise ValueError(f"Refresh token invalid and re-authentication failed: {reauth_error}") + + # If we exhausted retries without success + if new_token_data is None: + raise last_error or Exception("Token refresh failed after all retries") + + # [FIX 1] Update OAuth token fields from response + creds["access_token"] = new_token_data["access_token"] + expiry_timestamp = time.time() + new_token_data["expires_in"] + creds["expiry_date"] = expiry_timestamp * 1000 # gemini-cli format + + # [FIX 2] Update refresh_token if server provided a new one (rare but possible with Google OAuth) + if "refresh_token" in new_token_data: + creds["refresh_token"] 
= new_token_data["refresh_token"] + + # [FIX 3] Ensure all required OAuth client fields are present (restore if missing) + if "client_id" not in creds or not creds["client_id"]: + creds["client_id"] = self.CLIENT_ID + if "client_secret" not in creds or not creds["client_secret"]: + creds["client_secret"] = self.CLIENT_SECRET + if "token_uri" not in creds or not creds["token_uri"]: + creds["token_uri"] = self.TOKEN_URI + if "universe_domain" not in creds or not creds["universe_domain"]: + creds["universe_domain"] = "googleapis.com" + + # [FIX 4] Add scopes array if missing + if "scopes" not in creds: + creds["scopes"] = self.OAUTH_SCOPES + + # [FIX 5] Ensure _proxy_metadata exists and update timestamp + if "_proxy_metadata" not in creds: + creds["_proxy_metadata"] = {} + creds["_proxy_metadata"]["last_check_timestamp"] = time.time() + + # [VALIDATION] Verify refreshed credentials have all required fields + required_fields = ["access_token", "refresh_token", "client_id", "client_secret", "token_uri"] + missing_fields = [field for field in required_fields if not creds.get(field)] + if missing_fields: + raise ValueError(f"Refreshed credentials missing required fields: {missing_fields}") + + # [VALIDATION] Optional: Test that the refreshed token is actually usable + try: + async with httpx.AsyncClient() as client: + test_response = await client.get( + self.USER_INFO_URI, + headers={"Authorization": f"Bearer {creds['access_token']}"}, + timeout=5.0 + ) + test_response.raise_for_status() + lib_logger.debug(f"Token validation successful for '{Path(path).name}'") + except Exception as e: + lib_logger.warning(f"Refreshed token validation failed for '{Path(path).name}': {e}") + # Don't fail the refresh - the token might still work for other endpoints + # But log it for debugging purposes + + await self._save_credentials(path, creds) + lib_logger.debug(f"Successfully refreshed {self.ENV_PREFIX} OAuth token for '{Path(path).name}'.") + return creds + + async def proactively_refresh(self, credential_path: str): + """Proactively refresh a credential by queueing it for refresh.""" + creds = await self._load_credentials(credential_path) + if self._is_token_expired(creds): + # Queue for refresh with needs_reauth=False (automated refresh) + await self._queue_refresh(credential_path, force=False, needs_reauth=False) + + async def _get_lock(self, path: str) -> asyncio.Lock: + # [FIX RACE CONDITION] Protect lock creation with a master lock + # This prevents TOCTOU bug where multiple coroutines check and create simultaneously + async with self._locks_lock: + if path not in self._refresh_locks: + self._refresh_locks[path] = asyncio.Lock() + return self._refresh_locks[path] + + def is_credential_available(self, path: str) -> bool: + """Check if a credential is available for rotation (not queued/refreshing).""" + return path not in self._unavailable_credentials + + async def _ensure_queue_processor_running(self): + """Lazily starts the queue processor if not already running.""" + if self._queue_processor_task is None or self._queue_processor_task.done(): + self._queue_processor_task = asyncio.create_task(self._process_refresh_queue()) + + async def _queue_refresh(self, path: str, force: bool = False, needs_reauth: bool = False): + """Add a credential to the refresh queue if not already queued. 
+ + Args: + path: Credential file path + force: Force refresh even if not expired + needs_reauth: True if full re-authentication needed (bypasses backoff) + """ + # IMPORTANT: Only check backoff for simple automated refreshes + # Re-authentication (interactive OAuth) should BYPASS backoff since it needs user input + if not needs_reauth: + now = time.time() + if path in self._next_refresh_after: + backoff_until = self._next_refresh_after[path] + if now < backoff_until: + # Credential is in backoff for automated refresh, do not queue + remaining = int(backoff_until - now) + lib_logger.debug(f"Skipping automated refresh for '{Path(path).name}' (in backoff for {remaining}s)") + return + + async with self._queue_tracking_lock: + if path not in self._queued_credentials: + self._queued_credentials.add(path) + self._unavailable_credentials.add(path) # Mark as unavailable + await self._refresh_queue.put((path, force, needs_reauth)) + await self._ensure_queue_processor_running() + + async def _process_refresh_queue(self): + """Background worker that processes refresh requests sequentially.""" + while True: + path = None + try: + # Wait for an item with timeout to allow graceful shutdown + try: + path, force, needs_reauth = await asyncio.wait_for( + self._refresh_queue.get(), + timeout=60.0 + ) + except asyncio.TimeoutError: + # No items for 60s, exit to save resources + self._queue_processor_task = None + return + + try: + # Perform the actual refresh (still using per-credential lock) + async with await self._get_lock(path): + # Re-check if still expired (may have changed since queueing) + creds = self._credentials_cache.get(path) + if creds and not self._is_token_expired(creds): + # No longer expired, mark as available + async with self._queue_tracking_lock: + self._unavailable_credentials.discard(path) + continue + + # Perform refresh + if not creds: + creds = await self._load_credentials(path) + await self._refresh_token(path, creds, force=force) + + # SUCCESS: Mark as available again + async with self._queue_tracking_lock: + self._unavailable_credentials.discard(path) + + finally: + # Remove from queued set + async with self._queue_tracking_lock: + self._queued_credentials.discard(path) + self._refresh_queue.task_done() + except asyncio.CancelledError: + break + except Exception as e: + lib_logger.error(f"Error in queue processor: {e}") + # Even on error, mark as available (backoff will prevent immediate retry) + if path: + async with self._queue_tracking_lock: + self._unavailable_credentials.discard(path) + + async def initialize_token(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]: + path = creds_or_path if isinstance(creds_or_path, str) else None + + # Get display name from metadata if available, otherwise derive from path + if isinstance(creds_or_path, dict): + display_name = creds_or_path.get("_proxy_metadata", {}).get("display_name", "in-memory object") + else: + display_name = Path(path).name if path else "in-memory object" + + lib_logger.debug(f"Initializing {self.ENV_PREFIX} token for '{display_name}'...") + try: + creds = await self._load_credentials(creds_or_path) if path else creds_or_path + reason = "" + if not creds.get("refresh_token"): + reason = "refresh token is missing" + elif self._is_token_expired(creds): + reason = "token is expired" + + if reason: + if reason == "token is expired" and creds.get("refresh_token"): + try: + return await self._refresh_token(path, creds) + except Exception as e: + lib_logger.warning(f"Automatic token refresh for 
'{display_name}' failed: {e}. Proceeding to interactive login.") + + lib_logger.warning(f"{self.ENV_PREFIX} OAuth token for '{display_name}' needs setup: {reason}.") + + # [HEADLESS DETECTION] Check if running in headless environment + is_headless = is_headless_environment() + + auth_code_future = asyncio.get_event_loop().create_future() + server = None + + async def handle_callback(reader, writer): + try: + request_line_bytes = await reader.readline() + if not request_line_bytes: return + path_str = request_line_bytes.decode('utf-8').strip().split(' ')[1] + while await reader.readline() != b'\r\n': pass + from urllib.parse import urlparse, parse_qs + query_params = parse_qs(urlparse(path_str).query) + writer.write(b"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n") + if 'code' in query_params: + if not auth_code_future.done(): + auth_code_future.set_result(query_params['code'][0]) + writer.write(b"
Authentication successful! You can close this window.")
+                    else:
+                        error = query_params.get('error', ['Unknown error'])[0]
+                        if not auth_code_future.done():
+                            auth_code_future.set_exception(Exception(f"OAuth failed: {error}"))
+                        writer.write(f"Authentication Failed. Error: {error}. Please try again.
".encode()) + await writer.drain() + except Exception as e: + lib_logger.error(f"Error in OAuth callback handler: {e}") + finally: + writer.close() + + try: + server = await asyncio.start_server(handle_callback, '127.0.0.1', self.CALLBACK_PORT) + from urllib.parse import urlencode + auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode({ + "client_id": self.CLIENT_ID, + "redirect_uri": f"http://localhost:{self.CALLBACK_PORT}{self.CALLBACK_PATH}", + "scope": " ".join(self.OAUTH_SCOPES), + "access_type": "offline", "response_type": "code", "prompt": "consent" + }) + + # [HEADLESS SUPPORT] Display appropriate instructions + if is_headless: + auth_panel_text = Text.from_markup( + "Running in headless environment (no GUI detected).\n" + "Please open the URL below in a browser on another machine to authorize:\n" + ) + else: + auth_panel_text = Text.from_markup( + "1. Your browser will now open to log in and authorize the application.\n" + "2. If it doesn't open automatically, please open the URL below manually." + ) + + console.print(Panel(auth_panel_text, title=f"{self.ENV_PREFIX} OAuth Setup for [bold yellow]{display_name}[/bold yellow]", style="bold blue")) + console.print(f"[bold]URL:[/bold] [link={auth_url}]{auth_url}[/link]\n") + + # [HEADLESS SUPPORT] Only attempt browser open if NOT headless + if not is_headless: + try: + webbrowser.open(auth_url) + lib_logger.info("Browser opened successfully for OAuth flow") + except Exception as e: + lib_logger.warning(f"Failed to open browser automatically: {e}. Please open the URL manually.") + + with console.status(f"[bold green]Waiting for you to complete authentication in the browser...[/bold green]", spinner="dots"): + auth_code = await asyncio.wait_for(auth_code_future, timeout=300) + except asyncio.TimeoutError: + raise Exception("OAuth flow timed out. 
Please try again.") + finally: + if server: + server.close() + await server.wait_closed() + + lib_logger.info(f"Attempting to exchange authorization code for tokens...") + async with httpx.AsyncClient() as client: + response = await client.post(self.TOKEN_URI, data={ + "code": auth_code.strip(), "client_id": self.CLIENT_ID, "client_secret": self.CLIENT_SECRET, + "redirect_uri": f"http://localhost:{self.CALLBACK_PORT}{self.CALLBACK_PATH}", "grant_type": "authorization_code" + }) + response.raise_for_status() + token_data = response.json() + # Start with the full token data from the exchange + creds = token_data.copy() + + # Convert 'expires_in' to 'expiry_date' in milliseconds + creds["expiry_date"] = (time.time() + creds.pop("expires_in")) * 1000 + + # Ensure client_id and client_secret are present + creds["client_id"] = self.CLIENT_ID + creds["client_secret"] = self.CLIENT_SECRET + + creds["token_uri"] = self.TOKEN_URI + creds["universe_domain"] = "googleapis.com" + + # Fetch user info and add metadata + user_info_response = await client.get(self.USER_INFO_URI, headers={"Authorization": f"Bearer {creds['access_token']}"}) + user_info_response.raise_for_status() + user_info = user_info_response.json() + creds["_proxy_metadata"] = { + "email": user_info.get("email"), + "last_check_timestamp": time.time() + } + + if path: + await self._save_credentials(path, creds) + lib_logger.info(f"{self.ENV_PREFIX} OAuth initialized successfully for '{display_name}'.") + return creds + + lib_logger.info(f"{self.ENV_PREFIX} OAuth token at '{display_name}' is valid.") + return creds + except Exception as e: + raise ValueError(f"Failed to initialize {self.ENV_PREFIX} OAuth for '{path}': {e}") + + async def get_auth_header(self, credential_path: str) -> Dict[str, str]: + creds = await self._load_credentials(credential_path) + if self._is_token_expired(creds): + creds = await self._refresh_token(credential_path, creds) + return {"Authorization": f"Bearer {creds['access_token']}"} + + async def get_user_info(self, creds_or_path: Union[Dict[str, Any], str]) -> Dict[str, Any]: + path = creds_or_path if isinstance(creds_or_path, str) else None + creds = await self._load_credentials(creds_or_path) if path else creds_or_path + + if path and self._is_token_expired(creds): + creds = await self._refresh_token(path, creds) + + # Prefer locally stored metadata + if creds.get("_proxy_metadata", {}).get("email"): + if path: + creds["_proxy_metadata"]["last_check_timestamp"] = time.time() + await self._save_credentials(path, creds) + return {"email": creds["_proxy_metadata"]["email"]} + + # Fallback to API call if metadata is missing + headers = {"Authorization": f"Bearer {creds['access_token']}"} + async with httpx.AsyncClient() as client: + response = await client.get(self.USER_INFO_URI, headers=headers) + response.raise_for_status() + user_info = response.json() + + # Save the retrieved info for future use + creds["_proxy_metadata"] = { + "email": user_info.get("email"), + "last_check_timestamp": time.time() + } + if path: + await self._save_credentials(path, creds) + return {"email": user_info.get("email")} diff --git a/src/rotator_library/providers/iflow_auth_base.py b/src/rotator_library/providers/iflow_auth_base.py index 4d77b79..cae8592 100644 --- a/src/rotator_library/providers/iflow_auth_base.py +++ b/src/rotator_library/providers/iflow_auth_base.py @@ -158,47 +158,79 @@ def __init__(self): self._queue_tracking_lock = asyncio.Lock() # Protects queue sets self._queue_processor_task: Optional[asyncio.Task] = None # 
Background worker task - def _load_from_env(self) -> Optional[Dict[str, Any]]: + def _parse_env_credential_path(self, path: str) -> Optional[str]: + """ + Parse a virtual env:// path and return the credential index. + + Supported formats: + - "env://provider/0" - Legacy single credential (no index in env var names) + - "env://provider/1" - First numbered credential (IFLOW_1_ACCESS_TOKEN) + + Returns: + The credential index as string, or None if path is not an env:// path + """ + if not path.startswith("env://"): + return None + + parts = path[6:].split("/") + if len(parts) >= 2: + return parts[1] + return "0" + + def _load_from_env(self, credential_index: Optional[str] = None) -> Optional[Dict[str, Any]]: """ Load OAuth credentials from environment variables for stateless deployments. - Expected environment variables: - - IFLOW_ACCESS_TOKEN (required) - - IFLOW_REFRESH_TOKEN (required) - - IFLOW_API_KEY (required - critical for iFlow!) - - IFLOW_EXPIRY_DATE (optional, defaults to empty string) - - IFLOW_EMAIL (optional, defaults to "env-user") - - IFLOW_TOKEN_TYPE (optional, defaults to "Bearer") - - IFLOW_SCOPE (optional, defaults to "read write") + Supports two formats: + 1. Legacy (credential_index="0" or None): IFLOW_ACCESS_TOKEN + 2. Numbered (credential_index="1", "2", etc.): IFLOW_1_ACCESS_TOKEN, etc. + + Expected environment variables (for numbered format with index N): + - IFLOW_{N}_ACCESS_TOKEN (required) + - IFLOW_{N}_REFRESH_TOKEN (required) + - IFLOW_{N}_API_KEY (required - critical for iFlow!) + - IFLOW_{N}_EXPIRY_DATE (optional, defaults to empty string) + - IFLOW_{N}_EMAIL (optional, defaults to "env-user-{N}") + - IFLOW_{N}_TOKEN_TYPE (optional, defaults to "Bearer") + - IFLOW_{N}_SCOPE (optional, defaults to "read write") Returns: Dict with credential structure if env vars present, None otherwise """ - access_token = os.getenv("IFLOW_ACCESS_TOKEN") - refresh_token = os.getenv("IFLOW_REFRESH_TOKEN") - api_key = os.getenv("IFLOW_API_KEY") + # Determine the env var prefix based on credential index + if credential_index and credential_index != "0": + prefix = f"IFLOW_{credential_index}" + default_email = f"env-user-{credential_index}" + else: + prefix = "IFLOW" + default_email = "env-user" + + access_token = os.getenv(f"{prefix}_ACCESS_TOKEN") + refresh_token = os.getenv(f"{prefix}_REFRESH_TOKEN") + api_key = os.getenv(f"{prefix}_API_KEY") # All three are required for iFlow if not (access_token and refresh_token and api_key): return None - lib_logger.debug("Loading iFlow credentials from environment variables") + lib_logger.debug(f"Loading iFlow credentials from environment variables (prefix: {prefix})") # Parse expiry_date as string (ISO 8601 format) - expiry_str = os.getenv("IFLOW_EXPIRY_DATE", "") + expiry_str = os.getenv(f"{prefix}_EXPIRY_DATE", "") creds = { "access_token": access_token, "refresh_token": refresh_token, "api_key": api_key, # Critical for iFlow! 
"expiry_date": expiry_str, - "email": os.getenv("IFLOW_EMAIL", "env-user"), - "token_type": os.getenv("IFLOW_TOKEN_TYPE", "Bearer"), - "scope": os.getenv("IFLOW_SCOPE", "read write"), + "email": os.getenv(f"{prefix}_EMAIL", default_email), + "token_type": os.getenv(f"{prefix}_TOKEN_TYPE", "Bearer"), + "scope": os.getenv(f"{prefix}_SCOPE", "read write"), "_proxy_metadata": { - "email": os.getenv("IFLOW_EMAIL", "env-user"), + "email": os.getenv(f"{prefix}_EMAIL", default_email), "last_check_timestamp": time.time(), - "loaded_from_env": True # Flag to indicate env-based credentials + "loaded_from_env": True, + "env_credential_index": credential_index or "0" } } @@ -227,11 +259,21 @@ async def _load_credentials(self, path: str) -> Dict[str, Any]: if path in self._credentials_cache: return self._credentials_cache[path] - # First, try loading from environment variables + # Check if this is a virtual env:// path + credential_index = self._parse_env_credential_path(path) + if credential_index is not None: + env_creds = self._load_from_env(credential_index) + if env_creds: + lib_logger.info(f"Using iFlow credentials from environment variables (index: {credential_index})") + self._credentials_cache[path] = env_creds + return env_creds + else: + raise IOError(f"Environment variables for iFlow credential index {credential_index} not found") + + # For file paths, try loading from legacy env vars first env_creds = self._load_from_env() if env_creds: lib_logger.info("Using iFlow credentials from environment variables") - # Cache env-based credentials using the path as key self._credentials_cache[path] = env_creds return env_creds @@ -509,12 +551,25 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] try: # Call initialize_token to trigger OAuth flow new_creds = await self.initialize_token(path) + # Clear backoff on successful re-auth + self._refresh_failures.pop(path, None) + self._next_refresh_after.pop(path, None) return new_creds except Exception as reauth_error: lib_logger.error(f"Re-authentication failed for '{Path(path).name}': {reauth_error}") + # [BACKOFF TRACKING] Increment failure count and set backoff timer + self._refresh_failures[path] = self._refresh_failures.get(path, 0) + 1 + backoff_seconds = min(300, 30 * (2 ** self._refresh_failures[path])) # Max 5 min backoff + self._next_refresh_after[path] = time.time() + backoff_seconds + lib_logger.debug(f"Setting backoff for '{Path(path).name}': {backoff_seconds}s") raise ValueError(f"Refresh token invalid and re-authentication failed: {reauth_error}") if new_token_data is None: + # [BACKOFF TRACKING] Increment failure count and set backoff timer + self._refresh_failures[path] = self._refresh_failures.get(path, 0) + 1 + backoff_seconds = min(300, 30 * (2 ** self._refresh_failures[path])) # Max 5 min backoff + self._next_refresh_after[path] = time.time() + backoff_seconds + lib_logger.debug(f"Setting backoff for '{Path(path).name}': {backoff_seconds}s") raise last_error or Exception("Token refresh failed after all retries") # Update tokens @@ -547,6 +602,16 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] creds_from_file["_proxy_metadata"] = {} creds_from_file["_proxy_metadata"]["last_check_timestamp"] = time.time() + # [VALIDATION] Verify required fields exist after refresh + required_fields = ["access_token", "refresh_token", "api_key"] + missing_fields = [field for field in required_fields if not creds_from_file.get(field)] + if missing_fields: + raise ValueError(f"Refreshed 
credentials missing required fields: {missing_fields}") + + # [BACKOFF TRACKING] Clear failure count on successful refresh + self._refresh_failures.pop(path, None) + self._next_refresh_after.pop(path, None) + await self._save_credentials(path, creds_from_file) lib_logger.debug(f"Successfully refreshed iFlow OAuth token for '{Path(path).name}'.") return creds_from_file @@ -584,10 +649,13 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]: async def proactively_refresh(self, credential_identifier: str): """ Proactively refreshes tokens if they're close to expiry. - Only applies to OAuth credentials (file paths). Direct API keys are skipped. + Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped. """ - # Only refresh if it's an OAuth credential (file path) - if not os.path.isfile(credential_identifier): + # Check if it's an env:// virtual path (OAuth credentials from environment) + is_env_path = credential_identifier.startswith("env://") + + # Only refresh if it's an OAuth credential (file path or env:// path) + if not is_env_path and not os.path.isfile(credential_identifier): return # Direct API key, no refresh needed creds = await self._load_credentials(credential_identifier) diff --git a/src/rotator_library/providers/iflow_provider.py b/src/rotator_library/providers/iflow_provider.py index b602112..28d84f6 100644 --- a/src/rotator_library/providers/iflow_provider.py +++ b/src/rotator_library/providers/iflow_provider.py @@ -1,5 +1,6 @@ # src/rotator_library/providers/iflow_provider.py +import copy import json import time import os @@ -203,7 +204,6 @@ def _clean_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any Removes unsupported properties from tool schemas to prevent API errors. Similar to Qwen Code implementation. """ - import copy cleaned_tools = [] for tool in tools: @@ -345,6 +345,11 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse: """ Manually reassembles streaming chunks into a complete response. 
+ + Key improvements: + - Determines finish_reason based on accumulated state (tool_calls vs stop) + - Properly initializes tool_calls with type field + - Handles usage data extraction from chunks """ if not chunks: raise ValueError("No chunks provided for reassembly") @@ -353,7 +358,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> final_message = {"role": "assistant"} aggregated_tool_calls = {} usage_data = None - finish_reason = None + chunk_finish_reason = None # Track finish_reason from chunks (but we'll override) # Get the first chunk for basic response metadata first_chunk = chunks[0] @@ -378,12 +383,13 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> final_message["reasoning_content"] = "" final_message["reasoning_content"] += delta["reasoning_content"] - # Aggregate tool calls + # Aggregate tool calls with proper initialization if "tool_calls" in delta and delta["tool_calls"]: for tc_chunk in delta["tool_calls"]: - index = tc_chunk["index"] + index = tc_chunk.get("index", 0) if index not in aggregated_tool_calls: - aggregated_tool_calls[index] = {"function": {"name": "", "arguments": ""}} + # Initialize with type field for OpenAI compatibility + aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}} if "id" in tc_chunk: aggregated_tool_calls[index]["id"] = tc_chunk["id"] if "type" in tc_chunk: @@ -403,9 +409,9 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None: final_message["function_call"]["arguments"] += delta["function_call"]["arguments"] - # Get finish reason from the last chunk that has it + # Track finish_reason from chunks (for reference only) if choice.get("finish_reason"): - finish_reason = choice["finish_reason"] + chunk_finish_reason = choice["finish_reason"] # Handle usage data from the last chunk that has it for chunk in reversed(chunks): @@ -422,6 +428,15 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if field not in final_message: final_message[field] = None + # Determine finish_reason based on accumulated state + # Priority: tool_calls wins if present, then chunk's finish_reason, then default to "stop" + if aggregated_tool_calls: + finish_reason = "tool_calls" + elif chunk_finish_reason: + finish_reason = chunk_finish_reason + else: + finish_reason = "stop" + # Construct the final response final_choice = { "index": 0, diff --git a/src/rotator_library/providers/provider_cache.py b/src/rotator_library/providers/provider_cache.py new file mode 100644 index 0000000..b6bb2db --- /dev/null +++ b/src/rotator_library/providers/provider_cache.py @@ -0,0 +1,498 @@ +# src/rotator_library/providers/provider_cache.py +""" +Shared cache utility for providers. 
+ +A modular, async-capable cache system supporting: +- Dual-TTL: short-lived memory cache, longer-lived disk persistence +- Background persistence with batched writes +- Automatic cleanup of expired entries +- Generic key-value storage for any provider-specific needs + +Usage examples: +- Gemini 3: thoughtSignatures (tool_call_id → encrypted signature) +- Claude: Thinking content (composite_key → thinking text + signature) +- General: Any transient data that benefits from persistence across requests +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import shutil +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +lib_logger = logging.getLogger('rotator_library') + + +# ============================================================================= +# UTILITY FUNCTIONS +# ============================================================================= + +def _env_bool(key: str, default: bool = False) -> bool: + """Get boolean from environment variable.""" + return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes") + + +def _env_int(key: str, default: int) -> int: + """Get integer from environment variable.""" + return int(os.getenv(key, str(default))) + + +# ============================================================================= +# PROVIDER CACHE CLASS +# ============================================================================= + +class ProviderCache: + """ + Server-side cache for provider conversation state preservation. + + A generic, modular cache supporting any key-value data that providers need + to persist across requests. Features: + + - Dual-TTL system: configurable memory TTL, longer disk TTL + - Async disk persistence with batched writes + - Background cleanup task for expired entries + - Statistics tracking (hits, misses, writes) + + Args: + cache_file: Path to disk cache file + memory_ttl_seconds: In-memory entry lifetime (default: 1 hour) + disk_ttl_seconds: Disk entry lifetime (default: 24 hours) + enable_disk: Whether to enable disk persistence (default: from env or True) + write_interval: Seconds between background disk writes (default: 60) + cleanup_interval: Seconds between expired entry cleanup (default: 30 min) + env_prefix: Environment variable prefix for configuration overrides + + Environment Variables (with default prefix "PROVIDER_CACHE"): + {PREFIX}_ENABLE: Enable/disable disk persistence + {PREFIX}_WRITE_INTERVAL: Background write interval in seconds + {PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds + """ + + def __init__( + self, + cache_file: Path, + memory_ttl_seconds: int = 3600, + disk_ttl_seconds: int = 86400, + enable_disk: Optional[bool] = None, + write_interval: Optional[int] = None, + cleanup_interval: Optional[int] = None, + env_prefix: str = "PROVIDER_CACHE" + ): + # In-memory cache: {cache_key: (data, timestamp)} + self._cache: Dict[str, Tuple[str, float]] = {} + self._memory_ttl = memory_ttl_seconds + self._disk_ttl = disk_ttl_seconds + self._lock = asyncio.Lock() + self._disk_lock = asyncio.Lock() + + # Disk persistence configuration + self._cache_file = cache_file + self._enable_disk = enable_disk if enable_disk is not None else _env_bool(f"{env_prefix}_ENABLE", True) + self._dirty = False + self._write_interval = write_interval or _env_int(f"{env_prefix}_WRITE_INTERVAL", 60) + self._cleanup_interval = cleanup_interval or _env_int(f"{env_prefix}_CLEANUP_INTERVAL", 1800) + + # Background tasks + self._writer_task: 
Optional[asyncio.Task] = None + self._cleanup_task: Optional[asyncio.Task] = None + self._running = False + + # Statistics + self._stats = {"memory_hits": 0, "disk_hits": 0, "misses": 0, "writes": 0} + + # Metadata about this cache instance + self._cache_name = cache_file.stem if cache_file else "unnamed" + + if self._enable_disk: + lib_logger.debug( + f"ProviderCache[{self._cache_name}]: Disk enabled " + f"(memory_ttl={memory_ttl_seconds}s, disk_ttl={disk_ttl_seconds}s)" + ) + asyncio.create_task(self._async_init()) + else: + lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode") + + # ========================================================================= + # INITIALIZATION + # ========================================================================= + + async def _async_init(self) -> None: + """Async initialization: load from disk and start background tasks.""" + try: + await self._load_from_disk() + await self._start_background_tasks() + except Exception as e: + lib_logger.error(f"ProviderCache[{self._cache_name}] async init failed: {e}") + + async def _load_from_disk(self) -> None: + """Load cache from disk file with TTL validation.""" + if not self._enable_disk or not self._cache_file.exists(): + return + + try: + async with self._disk_lock: + with open(self._cache_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + if data.get("version") != "1.0": + lib_logger.warning(f"ProviderCache[{self._cache_name}]: Version mismatch, starting fresh") + return + + now = time.time() + entries = data.get("entries", {}) + loaded = expired = 0 + + for cache_key, entry in entries.items(): + age = now - entry.get("timestamp", 0) + if age <= self._disk_ttl: + value = entry.get("value", entry.get("signature", "")) # Support both formats + if value: + self._cache[cache_key] = (value, entry["timestamp"]) + loaded += 1 + else: + expired += 1 + + lib_logger.debug( + f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)" + ) + except json.JSONDecodeError as e: + lib_logger.warning(f"ProviderCache[{self._cache_name}]: File corrupted: {e}") + except Exception as e: + lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}") + + # ========================================================================= + # DISK PERSISTENCE + # ========================================================================= + + async def _save_to_disk(self) -> None: + """Persist cache to disk using atomic write.""" + if not self._enable_disk: + return + + try: + async with self._disk_lock: + self._cache_file.parent.mkdir(parents=True, exist_ok=True) + + cache_data = { + "version": "1.0", + "memory_ttl_seconds": self._memory_ttl, + "disk_ttl_seconds": self._disk_ttl, + "entries": { + key: {"value": val, "timestamp": ts} + for key, (val, ts) in self._cache.items() + }, + "statistics": { + "total_entries": len(self._cache), + "last_write": time.time(), + **self._stats + } + } + + # Atomic write using temp file + parent_dir = self._cache_file.parent + tmp_fd, tmp_path = tempfile.mkstemp(dir=parent_dir, prefix='.tmp_', suffix='.json') + + try: + with os.fdopen(tmp_fd, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, indent=2) + + # Set restrictive permissions (if supported) + try: + os.chmod(tmp_path, 0o600) + except (OSError, AttributeError): + pass + + shutil.move(tmp_path, self._cache_file) + self._stats["writes"] += 1 + lib_logger.debug( + f"ProviderCache[{self._cache_name}]: Saved {len(self._cache)} entries" + ) + except Exception: + if tmp_path and 
os.path.exists(tmp_path): + os.unlink(tmp_path) + raise + except Exception as e: + lib_logger.error(f"ProviderCache[{self._cache_name}]: Disk save failed: {e}") + + # ========================================================================= + # BACKGROUND TASKS + # ========================================================================= + + async def _start_background_tasks(self) -> None: + """Start background writer and cleanup tasks.""" + if not self._enable_disk or self._running: + return + + self._running = True + self._writer_task = asyncio.create_task(self._writer_loop()) + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks") + + async def _writer_loop(self) -> None: + """Background task: periodically flush dirty cache to disk.""" + try: + while self._running: + await asyncio.sleep(self._write_interval) + if self._dirty: + try: + await self._save_to_disk() + self._dirty = False + except Exception as e: + lib_logger.error(f"ProviderCache[{self._cache_name}]: Writer error: {e}") + except asyncio.CancelledError: + pass + + async def _cleanup_loop(self) -> None: + """Background task: periodically clean up expired entries.""" + try: + while self._running: + await asyncio.sleep(self._cleanup_interval) + await self._cleanup_expired() + except asyncio.CancelledError: + pass + + async def _cleanup_expired(self) -> None: + """Remove expired entries from memory cache.""" + async with self._lock: + now = time.time() + expired = [k for k, (_, ts) in self._cache.items() if now - ts > self._memory_ttl] + for k in expired: + del self._cache[k] + if expired: + self._dirty = True + lib_logger.debug( + f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries" + ) + + # ========================================================================= + # CORE OPERATIONS + # ========================================================================= + + def store(self, key: str, value: str) -> None: + """ + Store a value synchronously (schedules async storage). + + Args: + key: Cache key + value: Value to store (typically JSON-serialized data) + """ + asyncio.create_task(self._async_store(key, value)) + + async def _async_store(self, key: str, value: str) -> None: + """Async implementation of store.""" + async with self._lock: + self._cache[key] = (value, time.time()) + self._dirty = True + + async def store_async(self, key: str, value: str) -> None: + """ + Store a value asynchronously (awaitable). + + Use this when you need to ensure the value is stored before continuing. + """ + await self._async_store(key, value) + + def retrieve(self, key: str) -> Optional[str]: + """ + Retrieve a value by key (synchronous, with optional async disk fallback). + + Args: + key: Cache key + + Returns: + Cached value if found and not expired, None otherwise + """ + if key in self._cache: + value, timestamp = self._cache[key] + if time.time() - timestamp <= self._memory_ttl: + self._stats["memory_hits"] += 1 + return value + else: + del self._cache[key] + self._dirty = True + + self._stats["misses"] += 1 + if self._enable_disk: + # Schedule async disk lookup for next time + asyncio.create_task(self._check_disk_fallback(key)) + return None + + async def retrieve_async(self, key: str) -> Optional[str]: + """ + Retrieve a value asynchronously (checks disk if not in memory). + + Use this when you can await and need guaranteed disk fallback. 
+ """ + # Check memory first + if key in self._cache: + value, timestamp = self._cache[key] + if time.time() - timestamp <= self._memory_ttl: + self._stats["memory_hits"] += 1 + return value + else: + async with self._lock: + if key in self._cache: + del self._cache[key] + self._dirty = True + + # Check disk + if self._enable_disk: + return await self._disk_retrieve(key) + + self._stats["misses"] += 1 + return None + + async def _check_disk_fallback(self, key: str) -> None: + """Check disk for key and load into memory if found (background).""" + try: + if not self._cache_file.exists(): + return + + async with self._disk_lock: + with open(self._cache_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + entries = data.get("entries", {}) + if key in entries: + entry = entries[key] + ts = entry.get("timestamp", 0) + if time.time() - ts <= self._disk_ttl: + value = entry.get("value", entry.get("signature", "")) + if value: + async with self._lock: + self._cache[key] = (value, ts) + self._stats["disk_hits"] += 1 + lib_logger.debug( + f"ProviderCache[{self._cache_name}]: Loaded {key} from disk" + ) + except Exception as e: + lib_logger.debug(f"ProviderCache[{self._cache_name}]: Disk fallback failed: {e}") + + async def _disk_retrieve(self, key: str) -> Optional[str]: + """Direct disk retrieval with loading into memory.""" + try: + if not self._cache_file.exists(): + self._stats["misses"] += 1 + return None + + async with self._disk_lock: + with open(self._cache_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + entries = data.get("entries", {}) + if key in entries: + entry = entries[key] + ts = entry.get("timestamp", 0) + if time.time() - ts <= self._disk_ttl: + value = entry.get("value", entry.get("signature", "")) + if value: + async with self._lock: + self._cache[key] = (value, ts) + self._stats["disk_hits"] += 1 + return value + + self._stats["misses"] += 1 + return None + except Exception as e: + lib_logger.debug(f"ProviderCache[{self._cache_name}]: Disk retrieve failed: {e}") + self._stats["misses"] += 1 + return None + + # ========================================================================= + # UTILITY METHODS + # ========================================================================= + + def contains(self, key: str) -> bool: + """Check if key exists in memory cache (without updating stats).""" + if key in self._cache: + _, timestamp = self._cache[key] + return time.time() - timestamp <= self._memory_ttl + return False + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics.""" + return { + **self._stats, + "memory_entries": len(self._cache), + "dirty": self._dirty, + "disk_enabled": self._enable_disk + } + + async def clear(self) -> None: + """Clear all cached data.""" + async with self._lock: + self._cache.clear() + self._dirty = True + if self._enable_disk: + await self._save_to_disk() + + async def shutdown(self) -> None: + """Graceful shutdown: flush pending writes and stop background tasks.""" + lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...") + self._running = False + + # Cancel background tasks + for task in (self._writer_task, self._cleanup_task): + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Final save + if self._dirty and self._enable_disk: + await self._save_to_disk() + + lib_logger.info( + f"ProviderCache[{self._cache_name}]: Shutdown complete " + f"(stats: mem_hits={self._stats['memory_hits']}, " + f"disk_hits={self._stats['disk_hits']}, misses={self._stats['misses']})" 
+ ) + + +# ============================================================================= +# CONVENIENCE FACTORY +# ============================================================================= + +def create_provider_cache( + name: str, + cache_dir: Optional[Path] = None, + memory_ttl_seconds: int = 3600, + disk_ttl_seconds: int = 86400, + env_prefix: Optional[str] = None +) -> ProviderCache: + """ + Factory function to create a provider cache with sensible defaults. + + Args: + name: Cache name (used as filename and for logging) + cache_dir: Directory for cache file (default: project_root/cache/provider_name) + memory_ttl_seconds: In-memory TTL + disk_ttl_seconds: Disk TTL + env_prefix: Environment variable prefix (default: derived from name) + + Returns: + Configured ProviderCache instance + """ + if cache_dir is None: + cache_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache" + + cache_file = cache_dir / f"{name}.json" + + if env_prefix is None: + # Convert name to env prefix: "gemini3_signatures" -> "GEMINI3_SIGNATURES_CACHE" + env_prefix = f"{name.upper().replace('-', '_')}_CACHE" + + return ProviderCache( + cache_file=cache_file, + memory_ttl_seconds=memory_ttl_seconds, + disk_ttl_seconds=disk_ttl_seconds, + env_prefix=env_prefix + ) diff --git a/src/rotator_library/providers/provider_interface.py b/src/rotator_library/providers/provider_interface.py index 9ca39ec..8a20a64 100644 --- a/src/rotator_library/providers/provider_interface.py +++ b/src/rotator_library/providers/provider_interface.py @@ -66,4 +66,49 @@ async def proactively_refresh(self, credential_path: str): """ Proactively refreshes a token if it's nearing expiry. """ - pass \ No newline at end of file + pass + + # [NEW] Credential Prioritization System + def get_credential_priority(self, credential: str) -> Optional[int]: + """ + Returns the priority level for a credential. + Lower numbers = higher priority (1 is highest). + Returns None if provider doesn't use priorities. + + This allows providers to auto-detect credential tiers (e.g., paid vs free) + and ensure higher-tier credentials are always tried first. + + Args: + credential: The credential identifier (API key or path) + + Returns: + Priority level (1-10) or None if no priority system + + Example: + For Gemini CLI: + - Paid tier credentials: priority 1 (highest) + - Free tier credentials: priority 2 + - Unknown tier: priority 10 (lowest) + """ + return None + + def get_model_tier_requirement(self, model: str) -> Optional[int]: + """ + Returns the minimum priority tier required for a model. + If a model requires priority 1, only credentials with priority <= 1 can use it. + + This allows providers to restrict certain models to specific credential tiers. + For example, Gemini 3 models require paid-tier credentials. 
+ + Args: + model: The model name (with or without provider prefix) + + Returns: + Minimum required priority level or None if no restrictions + + Example: + For Gemini CLI: + - gemini-3-*: requires priority 1 (paid tier only) + - gemini-2.5-*: no restriction (None) + """ + return None \ No newline at end of file diff --git a/src/rotator_library/providers/qwen_auth_base.py b/src/rotator_library/providers/qwen_auth_base.py index 9d028c7..589e6be 100644 --- a/src/rotator_library/providers/qwen_auth_base.py +++ b/src/rotator_library/providers/qwen_auth_base.py @@ -47,46 +47,78 @@ def __init__(self): self._queue_tracking_lock = asyncio.Lock() # Protects queue sets self._queue_processor_task: Optional[asyncio.Task] = None # Background worker task - def _load_from_env(self) -> Optional[Dict[str, Any]]: + def _parse_env_credential_path(self, path: str) -> Optional[str]: + """ + Parse a virtual env:// path and return the credential index. + + Supported formats: + - "env://provider/0" - Legacy single credential (no index in env var names) + - "env://provider/1" - First numbered credential (QWEN_CODE_1_ACCESS_TOKEN) + + Returns: + The credential index as string, or None if path is not an env:// path + """ + if not path.startswith("env://"): + return None + + parts = path[6:].split("/") + if len(parts) >= 2: + return parts[1] + return "0" + + def _load_from_env(self, credential_index: Optional[str] = None) -> Optional[Dict[str, Any]]: """ Load OAuth credentials from environment variables for stateless deployments. - Expected environment variables: - - QWEN_CODE_ACCESS_TOKEN (required) - - QWEN_CODE_REFRESH_TOKEN (required) - - QWEN_CODE_EXPIRY_DATE (optional, defaults to 0) - - QWEN_CODE_RESOURCE_URL (optional, defaults to https://portal.qwen.ai/v1) - - QWEN_CODE_EMAIL (optional, defaults to "env-user") + Supports two formats: + 1. Legacy (credential_index="0" or None): QWEN_CODE_ACCESS_TOKEN + 2. Numbered (credential_index="1", "2", etc.): QWEN_CODE_1_ACCESS_TOKEN, etc. 
+ + Expected environment variables (for numbered format with index N): + - QWEN_CODE_{N}_ACCESS_TOKEN (required) + - QWEN_CODE_{N}_REFRESH_TOKEN (required) + - QWEN_CODE_{N}_EXPIRY_DATE (optional, defaults to 0) + - QWEN_CODE_{N}_RESOURCE_URL (optional, defaults to https://portal.qwen.ai/v1) + - QWEN_CODE_{N}_EMAIL (optional, defaults to "env-user-{N}") Returns: Dict with credential structure if env vars present, None otherwise """ - access_token = os.getenv("QWEN_CODE_ACCESS_TOKEN") - refresh_token = os.getenv("QWEN_CODE_REFRESH_TOKEN") + # Determine the env var prefix based on credential index + if credential_index and credential_index != "0": + prefix = f"QWEN_CODE_{credential_index}" + default_email = f"env-user-{credential_index}" + else: + prefix = "QWEN_CODE" + default_email = "env-user" + + access_token = os.getenv(f"{prefix}_ACCESS_TOKEN") + refresh_token = os.getenv(f"{prefix}_REFRESH_TOKEN") # Both access and refresh tokens are required if not (access_token and refresh_token): return None - lib_logger.debug("Loading Qwen Code credentials from environment variables") + lib_logger.debug(f"Loading Qwen Code credentials from environment variables (prefix: {prefix})") # Parse expiry_date as float, default to 0 if not present - expiry_str = os.getenv("QWEN_CODE_EXPIRY_DATE", "0") + expiry_str = os.getenv(f"{prefix}_EXPIRY_DATE", "0") try: expiry_date = float(expiry_str) except ValueError: - lib_logger.warning(f"Invalid QWEN_CODE_EXPIRY_DATE value: {expiry_str}, using 0") + lib_logger.warning(f"Invalid {prefix}_EXPIRY_DATE value: {expiry_str}, using 0") expiry_date = 0 creds = { "access_token": access_token, "refresh_token": refresh_token, "expiry_date": expiry_date, - "resource_url": os.getenv("QWEN_CODE_RESOURCE_URL", "https://portal.qwen.ai/v1"), + "resource_url": os.getenv(f"{prefix}_RESOURCE_URL", "https://portal.qwen.ai/v1"), "_proxy_metadata": { - "email": os.getenv("QWEN_CODE_EMAIL", "env-user"), + "email": os.getenv(f"{prefix}_EMAIL", default_email), "last_check_timestamp": time.time(), - "loaded_from_env": True # Flag to indicate env-based credentials + "loaded_from_env": True, + "env_credential_index": credential_index or "0" } } @@ -115,11 +147,21 @@ async def _load_credentials(self, path: str) -> Dict[str, Any]: if path in self._credentials_cache: return self._credentials_cache[path] - # First, try loading from environment variables + # Check if this is a virtual env:// path + credential_index = self._parse_env_credential_path(path) + if credential_index is not None: + env_creds = self._load_from_env(credential_index) + if env_creds: + lib_logger.info(f"Using Qwen Code credentials from environment variables (index: {credential_index})") + self._credentials_cache[path] = env_creds + return env_creds + else: + raise IOError(f"Environment variables for Qwen Code credential index {credential_index} not found") + + # For file paths, try loading from legacy env vars first env_creds = self._load_from_env() if env_creds: lib_logger.info("Using Qwen Code credentials from environment variables") - # Cache env-based credentials using the path as key self._credentials_cache[path] = env_creds return env_creds @@ -274,12 +316,25 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] try: # Call initialize_token to trigger OAuth flow new_creds = await self.initialize_token(path) + # Clear backoff on successful re-auth + self._refresh_failures.pop(path, None) + self._next_refresh_after.pop(path, None) return new_creds except Exception as reauth_error: 
lib_logger.error(f"Re-authentication failed for '{Path(path).name}': {reauth_error}") + # [BACKOFF TRACKING] Increment failure count and set backoff timer + self._refresh_failures[path] = self._refresh_failures.get(path, 0) + 1 + backoff_seconds = min(300, 30 * (2 ** self._refresh_failures[path])) # Max 5 min backoff + self._next_refresh_after[path] = time.time() + backoff_seconds + lib_logger.debug(f"Setting backoff for '{Path(path).name}': {backoff_seconds}s") raise ValueError(f"Refresh token invalid and re-authentication failed: {reauth_error}") if new_token_data is None: + # [BACKOFF TRACKING] Increment failure count and set backoff timer + self._refresh_failures[path] = self._refresh_failures.get(path, 0) + 1 + backoff_seconds = min(300, 30 * (2 ** self._refresh_failures[path])) # Max 5 min backoff + self._next_refresh_after[path] = time.time() + backoff_seconds + lib_logger.debug(f"Setting backoff for '{Path(path).name}': {backoff_seconds}s") raise last_error or Exception("Token refresh failed after all retries") creds_from_file["access_token"] = new_token_data["access_token"] @@ -292,6 +347,16 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any] creds_from_file["_proxy_metadata"] = {} creds_from_file["_proxy_metadata"]["last_check_timestamp"] = time.time() + # [VALIDATION] Verify required fields exist after refresh + required_fields = ["access_token", "refresh_token"] + missing_fields = [field for field in required_fields if not creds_from_file.get(field)] + if missing_fields: + raise ValueError(f"Refreshed credentials missing required fields: {missing_fields}") + + # [BACKOFF TRACKING] Clear failure count on successful refresh + self._refresh_failures.pop(path, None) + self._next_refresh_after.pop(path, None) + await self._save_credentials(path, creds_from_file) lib_logger.debug(f"Successfully refreshed Qwen OAuth token for '{Path(path).name}'.") return creds_from_file @@ -328,10 +393,13 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]: async def proactively_refresh(self, credential_identifier: str): """ Proactively refreshes tokens if they're close to expiry. - Only applies to OAuth credentials (file paths). Direct API keys are skipped. + Only applies to OAuth credentials (file paths or env:// paths). Direct API keys are skipped. """ - # Only refresh if it's an OAuth credential (file path) - if not os.path.isfile(credential_identifier): + # Check if it's an env:// virtual path (OAuth credentials from environment) + is_env_path = credential_identifier.startswith("env://") + + # Only refresh if it's an OAuth credential (file path or env:// path) + if not is_env_path and not os.path.isfile(credential_identifier): return # Direct API key, no refresh needed creds = await self._load_credentials(credential_identifier) diff --git a/src/rotator_library/providers/qwen_code_provider.py b/src/rotator_library/providers/qwen_code_provider.py index d57c88d..334e314 100644 --- a/src/rotator_library/providers/qwen_code_provider.py +++ b/src/rotator_library/providers/qwen_code_provider.py @@ -1,5 +1,6 @@ # src/rotator_library/providers/qwen_code_provider.py +import copy import json import time import os @@ -186,7 +187,6 @@ def _clean_tool_schemas(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any Removes unsupported properties from tool schemas to prevent API errors. Adapted for Qwen's API requirements. 
""" - import copy cleaned_tools = [] for tool in tools: @@ -263,15 +263,38 @@ def _build_request_payload(self, **kwargs) -> Dict[str, Any]: return payload def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): - """Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk.""" + """ + Converts a raw Qwen SSE chunk to an OpenAI-compatible chunk. + + CRITICAL FIX: Handle chunks with BOTH usage and choices (final chunk) + without early return to ensure finish_reason is properly processed. + """ if not isinstance(chunk, dict): return - # Handle usage data - if usage_data := chunk.get("usage"): + # Get choices and usage data + choices = chunk.get("choices", []) + usage_data = chunk.get("usage") + chunk_id = chunk.get("id", f"chatcmpl-qwen-{time.time()}") + chunk_created = chunk.get("created", int(time.time())) + + # Handle chunks with BOTH choices and usage (typical for final chunk) + # CRITICAL: Process choices FIRST to capture finish_reason, then yield usage + if choices and usage_data: + choice = choices[0] + delta = choice.get("delta", {}) + finish_reason = choice.get("finish_reason") + + # Yield the choice chunk first (contains finish_reason) + yield { + "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], + "model": model_id, "object": "chat.completion.chunk", + "id": chunk_id, "created": chunk_created + } + # Then yield the usage chunk yield { "choices": [], "model": model_id, "object": "chat.completion.chunk", - "id": f"chatcmpl-qwen-{time.time()}", "created": int(time.time()), + "id": chunk_id, "created": chunk_created, "usage": { "prompt_tokens": usage_data.get("prompt_tokens", 0), "completion_tokens": usage_data.get("completion_tokens", 0), @@ -280,8 +303,20 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): } return - # Handle content data - choices = chunk.get("choices", []) + # Handle usage-only chunks + if usage_data: + yield { + "choices": [], "model": model_id, "object": "chat.completion.chunk", + "id": chunk_id, "created": chunk_created, + "usage": { + "prompt_tokens": usage_data.get("prompt_tokens", 0), + "completion_tokens": usage_data.get("completion_tokens", 0), + "total_tokens": usage_data.get("total_tokens", 0), + } + } + return + + # Handle content-only chunks if not choices: return @@ -307,20 +342,24 @@ def _convert_chunk_to_openai(self, chunk: Dict[str, Any], model_id: str): yield { "choices": [{"index": 0, "delta": new_delta, "finish_reason": None}], "model": model_id, "object": "chat.completion.chunk", - "id": f"chatcmpl-qwen-{time.time()}", "created": int(time.time()) + "id": chunk_id, "created": chunk_created } else: # Standard content chunk yield { "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], "model": model_id, "object": "chat.completion.chunk", - "id": f"chatcmpl-qwen-{time.time()}", "created": int(time.time()) + "id": chunk_id, "created": chunk_created } def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> litellm.ModelResponse: """ Manually reassembles streaming chunks into a complete response. - This replaces the non-existent litellm.utils.stream_to_completion_response function. 
+ + Key improvements: + - Determines finish_reason based on accumulated state (tool_calls vs stop) + - Properly initializes tool_calls with type field + - Handles usage data extraction from chunks """ if not chunks: raise ValueError("No chunks provided for reassembly") @@ -329,7 +368,7 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> final_message = {"role": "assistant"} aggregated_tool_calls = {} usage_data = None - finish_reason = None + chunk_finish_reason = None # Track finish_reason from chunks (but we'll override) # Get the first chunk for basic response metadata first_chunk = chunks[0] @@ -354,14 +393,17 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> final_message["reasoning_content"] = "" final_message["reasoning_content"] += delta["reasoning_content"] - # Aggregate tool calls + # Aggregate tool calls with proper initialization if "tool_calls" in delta and delta["tool_calls"]: for tc_chunk in delta["tool_calls"]: - index = tc_chunk["index"] + index = tc_chunk.get("index", 0) if index not in aggregated_tool_calls: - aggregated_tool_calls[index] = {"function": {"name": "", "arguments": ""}} + # Initialize with type field for OpenAI compatibility + aggregated_tool_calls[index] = {"type": "function", "function": {"name": "", "arguments": ""}} if "id" in tc_chunk: aggregated_tool_calls[index]["id"] = tc_chunk["id"] + if "type" in tc_chunk: + aggregated_tool_calls[index]["type"] = tc_chunk["type"] if "function" in tc_chunk: if "name" in tc_chunk["function"] and tc_chunk["function"]["name"] is not None: aggregated_tool_calls[index]["function"]["name"] += tc_chunk["function"]["name"] @@ -377,9 +419,9 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if "arguments" in delta["function_call"] and delta["function_call"]["arguments"] is not None: final_message["function_call"]["arguments"] += delta["function_call"]["arguments"] - # Get finish reason from the last chunk that has it + # Track finish_reason from chunks (for reference only) if choice.get("finish_reason"): - finish_reason = choice["finish_reason"] + chunk_finish_reason = choice["finish_reason"] # Handle usage data from the last chunk that has it for chunk in reversed(chunks): @@ -396,6 +438,15 @@ def _stream_to_completion_response(self, chunks: List[litellm.ModelResponse]) -> if field not in final_message: final_message[field] = None + # Determine finish_reason based on accumulated state + # Priority: tool_calls wins if present, then chunk's finish_reason, then default to "stop" + if aggregated_tool_calls: + finish_reason = "tool_calls" + elif chunk_finish_reason: + finish_reason = chunk_finish_reason + else: + finish_reason = "stop" + # Construct the final response final_choice = { "index": 0, diff --git a/src/rotator_library/pyproject.toml b/src/rotator_library/pyproject.toml index a8dacd3..4cfa41a 100644 --- a/src/rotator_library/pyproject.toml +++ b/src/rotator_library/pyproject.toml @@ -3,8 +3,8 @@ requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] -name = "rotating-api-key-client" -version = "0.9" +name = "rotator_library" +version = "0.95" authors = [ { name="Mirrowel", email="nuh@uh.com" }, ] diff --git a/src/rotator_library/usage_manager.py b/src/rotator_library/usage_manager.py index ec1f122..4ec2b82 100644 --- a/src/rotator_library/usage_manager.py +++ b/src/rotator_library/usage_manager.py @@ -3,6 +3,7 @@ import time import logging import asyncio +import random from datetime import 
date, datetime, timezone, time as dt_time from typing import Any, Dict, List, Optional, Set import aiofiles @@ -20,15 +21,48 @@ class UsageManager: """ Manages usage statistics and cooldowns for API keys with asyncio-safe locking, - asynchronous file I/O, and a lazy-loading mechanism for usage data. + asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. + + The credential rotation strategy can be configured via the `rotation_tolerance` parameter: + + - **tolerance = 0.0**: Deterministic least-used selection. The credential with + the lowest usage count is always selected. This provides predictable, perfectly balanced + load distribution but may be vulnerable to fingerprinting. + + - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected + randomly with weights biased toward less-used ones. Credentials within 2 uses of the + maximum can still be selected with reasonable probability. This provides security through + unpredictability while maintaining good load balance. + + - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant + selection probability. Useful for stress testing or maximum unpredictability, but may + result in less balanced load distribution. + + The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` + + This ensures lower-usage credentials are preferred while tolerance controls how much + randomness is introduced into the selection process. """ def __init__( self, file_path: str = "key_usage.json", daily_reset_time_utc: Optional[str] = "03:00", + rotation_tolerance: float = 0.0, ): + """ + Initialize the UsageManager. + + Args: + file_path: Path to the usage data JSON file + daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) + rotation_tolerance: Tolerance for weighted random credential rotation. + - 0.0: Deterministic, least-used credential always selected + - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max + - 5.0+: High randomness, more unpredictable selection patterns + """ self.file_path = file_path + self.rotation_tolerance = rotation_tolerance self.key_states: Dict[str, Dict[str, Any]] = {} self._data_lock = asyncio.Lock() @@ -160,13 +194,90 @@ def _initialize_key_states(self, keys: List[str]): "models_in_use": {}, # Dict[model_name, concurrent_count] } + def _select_weighted_random( + self, + candidates: List[tuple], + tolerance: float + ) -> str: + """ + Selects a credential using weighted random selection based on usage counts. 
+ + Args: + candidates: List of (credential_id, usage_count) tuples + tolerance: Tolerance value for weight calculation + + Returns: + Selected credential ID + + Formula: + weight = (max_usage - credential_usage) + tolerance + 1 + + This formula ensures: + - Lower usage = higher weight = higher selection probability + - Tolerance adds variability: higher tolerance means more randomness + - The +1 ensures all credentials have at least some chance of selection + """ + if not candidates: + raise ValueError("Cannot select from empty candidate list") + + if len(candidates) == 1: + return candidates[0][0] + + # Extract usage counts + usage_counts = [usage for _, usage in candidates] + max_usage = max(usage_counts) + + # Calculate weights using the formula: (max - current) + tolerance + 1 + weights = [] + for credential, usage in candidates: + weight = (max_usage - usage) + tolerance + 1 + weights.append(weight) + + # Log weight distribution for debugging + if lib_logger.isEnabledFor(logging.DEBUG): + total_weight = sum(weights) + weight_info = ", ".join( + f"...{cred[-6:]}: w={w:.1f} ({w/total_weight*100:.1f}%)" + for (cred, _), w in zip(candidates, weights) + ) + #lib_logger.debug(f"Weighted selection candidates: {weight_info}") + + # Random selection with weights + selected_credential = random.choices( + [cred for cred, _ in candidates], + weights=weights, + k=1 + )[0] + + return selected_credential + async def acquire_key( self, available_keys: List[str], model: str, deadline: float, - max_concurrent: int = 1 + max_concurrent: int = 1, + credential_priorities: Optional[Dict[str, int]] = None ) -> str: """ Acquires the best available key using a tiered, model-aware locking strategy, - respecting a global deadline. + respecting a global deadline and credential priorities. + + Priority Logic: + - Groups credentials by priority level (1=highest, 2=lower, etc.) + - Always tries highest priority (lowest number) first + - Within same priority, sorts by usage count (load balancing) + - Only moves to next priority if all higher-priority keys exhausted/busy + + Args: + available_keys: List of credential identifiers to choose from + model: Model name being requested + deadline: Timestamp after which to stop trying + max_concurrent: Maximum concurrent requests allowed per credential + credential_priorities: Optional dict mapping credentials to priority levels (1=highest) + + Returns: + Selected credential identifier + + Raises: + NoAvailableKeysError: If no key could be acquired within the deadline """ await self._lazy_init() await self._reset_daily_stats_if_needed() @@ -174,78 +285,207 @@ async def acquire_key( # This loop continues as long as the global deadline has not been met. while time.time() < deadline: - tier1_keys, tier2_keys = [], [] now = time.time() - # First, filter the list of available keys to exclude any on cooldown. - async with self._data_lock: - for key in available_keys: - key_data = self._usage_data.get(key, {}) - - if (key_data.get("key_cooldown_until") or 0) > now or ( - key_data.get("model_cooldowns", {}).get(model) or 0 - ) > now: - continue - - # Prioritize keys based on their current usage to ensure load balancing. - usage_count = ( - key_data.get("daily", {}) - .get("models", {}) - .get(model, {}) - .get("success_count", 0) - ) - key_state = self.key_states[key] - - # Tier 1: Completely idle keys (preferred). - if not key_state["models_in_use"]: - tier1_keys.append((key, usage_count)) - # Tier 2: Keys that can accept more concurrent requests for this model. 
- elif key_state["models_in_use"].get(model, 0) < max_concurrent: - tier2_keys.append((key, usage_count)) - - tier1_keys.sort(key=lambda x: x[1]) - tier2_keys.sort(key=lambda x: x[1]) - - # Attempt to acquire a key from Tier 1 first. - for key, _ in tier1_keys: - state = self.key_states[key] - async with state["lock"]: - if not state["models_in_use"]: - state["models_in_use"][model] = 1 - lib_logger.info( - f"Acquired Tier 1 key ...{key[-6:]} for model {model}" + # Group credentials by priority level (if priorities provided) + if credential_priorities: + # Group keys by priority level + priority_groups = {} + async with self._data_lock: + for key in available_keys: + key_data = self._usage_data.get(key, {}) + + # Skip keys on cooldown + if (key_data.get("key_cooldown_until") or 0) > now or ( + key_data.get("model_cooldowns", {}).get(model) or 0 + ) > now: + continue + + # Get priority for this key (default to 999 if not specified) + priority = credential_priorities.get(key, 999) + + # Get usage count for load balancing within priority groups + usage_count = ( + key_data.get("daily", {}) + .get("models", {}) + .get(model, {}) + .get("success_count", 0) ) - return key - - # If no Tier 1 keys are available, try Tier 2. - for key, _ in tier2_keys: - state = self.key_states[key] - async with state["lock"]: - current_count = state["models_in_use"].get(model, 0) - if current_count < max_concurrent: - state["models_in_use"][model] = current_count + 1 - lib_logger.info( - f"Acquired Tier 2 key ...{key[-6:]} for model {model} " - f"(concurrent: {state['models_in_use'][model]}/{max_concurrent})" + + # Group by priority + if priority not in priority_groups: + priority_groups[priority] = [] + priority_groups[priority].append((key, usage_count)) + + # Try priority groups in order (1, 2, 3, ...) 
+ sorted_priorities = sorted(priority_groups.keys()) + + for priority_level in sorted_priorities: + keys_in_priority = priority_groups[priority_level] + + # Within each priority group, use existing tier1/tier2 logic + tier1_keys, tier2_keys = [], [] + for key, usage_count in keys_in_priority: + key_state = self.key_states[key] + + # Tier 1: Completely idle keys (preferred) + if not key_state["models_in_use"]: + tier1_keys.append((key, usage_count)) + # Tier 2: Keys that can accept more concurrent requests + elif key_state["models_in_use"].get(model, 0) < max_concurrent: + tier2_keys.append((key, usage_count)) + + # Apply weighted random selection or deterministic sorting + selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used" + + if self.rotation_tolerance > 0: + # Weighted random selection within each tier + if tier1_keys: + selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance) + tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key] + if tier2_keys: + selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance) + tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key] + else: + # Deterministic: sort by usage within each tier + tier1_keys.sort(key=lambda x: x[1]) + tier2_keys.sort(key=lambda x: x[1]) + + # Try to acquire from Tier 1 first + for key, usage in tier1_keys: + state = self.key_states[key] + async with state["lock"]: + if not state["models_in_use"]: + state["models_in_use"][model] = 1 + lib_logger.info( + f"Acquired Priority-{priority_level} Tier-1 key ...{key[-6:]} for model {model} " + f"(selection: {selection_method}, usage: {usage})" + ) + return key + + # Then try Tier 2 + for key, usage in tier2_keys: + state = self.key_states[key] + async with state["lock"]: + current_count = state["models_in_use"].get(model, 0) + if current_count < max_concurrent: + state["models_in_use"][model] = current_count + 1 + lib_logger.info( + f"Acquired Priority-{priority_level} Tier-2 key ...{key[-6:]} for model {model} " + f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" + ) + return key + + # If we get here, all priority groups were exhausted but keys might become available + # Collect all keys across all priorities for waiting + all_potential_keys = [] + for keys_list in priority_groups.values(): + all_potential_keys.extend(keys_list) + + if not all_potential_keys: + lib_logger.warning( + "No keys are eligible (all on cooldown or filtered out). Waiting before re-evaluating." + ) + await asyncio.sleep(1) + continue + + # Wait for the highest priority key with lowest usage + best_priority = min(priority_groups.keys()) + best_priority_keys = priority_groups[best_priority] + best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0] + wait_condition = self.key_states[best_wait_key]["condition"] + + lib_logger.info( + f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..." + ) + + else: + # Original logic when no priorities specified + tier1_keys, tier2_keys = [], [] + + # First, filter the list of available keys to exclude any on cooldown. + async with self._data_lock: + for key in available_keys: + key_data = self._usage_data.get(key, {}) + + if (key_data.get("key_cooldown_until") or 0) > now or ( + key_data.get("model_cooldowns", {}).get(model) or 0 + ) > now: + continue + + # Prioritize keys based on their current usage to ensure load balancing. 
+ usage_count = ( + key_data.get("daily", {}) + .get("models", {}) + .get(model, {}) + .get("success_count", 0) ) - return key - - # If all eligible keys are locked, wait for a key to be released. - lib_logger.info( - "All eligible keys are currently locked for this model. Waiting..." - ) + key_state = self.key_states[key] + + # Tier 1: Completely idle keys (preferred). + if not key_state["models_in_use"]: + tier1_keys.append((key, usage_count)) + # Tier 2: Keys that can accept more concurrent requests for this model. + elif key_state["models_in_use"].get(model, 0) < max_concurrent: + tier2_keys.append((key, usage_count)) + + # Apply weighted random selection or deterministic sorting + selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used" + + if self.rotation_tolerance > 0: + # Weighted random selection within each tier + if tier1_keys: + selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance) + tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key] + if tier2_keys: + selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance) + tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key] + else: + # Deterministic: sort by usage within each tier + tier1_keys.sort(key=lambda x: x[1]) + tier2_keys.sort(key=lambda x: x[1]) + + # Attempt to acquire a key from Tier 1 first. + for key, usage in tier1_keys: + state = self.key_states[key] + async with state["lock"]: + if not state["models_in_use"]: + state["models_in_use"][model] = 1 + lib_logger.info( + f"Acquired Tier 1 key ...{key[-6:]} for model {model} " + f"(selection: {selection_method}, usage: {usage})" + ) + return key + + # If no Tier 1 keys are available, try Tier 2. + for key, usage in tier2_keys: + state = self.key_states[key] + async with state["lock"]: + current_count = state["models_in_use"].get(model, 0) + if current_count < max_concurrent: + state["models_in_use"][model] = current_count + 1 + lib_logger.info( + f"Acquired Tier 2 key ...{key[-6:]} for model {model} " + f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})" + ) + return key - all_potential_keys = tier1_keys + tier2_keys - if not all_potential_keys: - lib_logger.warning( - "No keys are eligible (all on cooldown). Waiting before re-evaluating." + # If all eligible keys are locked, wait for a key to be released. + lib_logger.info( + "All eligible keys are currently locked for this model. Waiting..." ) - await asyncio.sleep(1) - continue - # Wait on the condition of the key with the lowest current usage. - best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] - wait_condition = self.key_states[best_wait_key]["condition"] + all_potential_keys = tier1_keys + tier2_keys + if not all_potential_keys: + lib_logger.warning( + "No keys are eligible (all on cooldown). Waiting before re-evaluating." + ) + await asyncio.sleep(1) + continue + + # Wait on the condition of the key with the lowest current usage. + best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0] + wait_condition = self.key_states[best_wait_key]["condition"] try: async with wait_condition: @@ -266,6 +506,8 @@ async def acquire_key( f"Could not acquire a key for model {model} within the global time budget." 
) + + async def release_key(self, key: str, model: str): """Releases a key's lock for a specific model and notifies waiting tasks.""" if key not in self.key_states: diff --git a/todo.md b/todo.md new file mode 100644 index 0000000..5966e4b --- /dev/null +++ b/todo.md @@ -0,0 +1,7 @@ +~~Refine Claude injection to inject even if we already have correct thinking, to force it to think when an ultrathink prompt is used. If the last message is a tool use and you prompt again, it never thinks again.~~ Maybe done + +Anthropic translation and an Anthropic-compatible endpoint. + +Refine for deployment. + +
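A few illustrative sketches of the behaviour introduced in the changes above; none of this code is part of the patch, and all names and values are placeholders. First, the token-refresh backoff added to both the Qwen and iFlow auth bases, `min(300, 30 * 2 ** failures)`, doubles per consecutive failure and caps at five minutes:

```python
# Consecutive refresh failures -> seconds to wait before the next refresh attempt
for failures in range(1, 6):
    print(failures, min(300, 30 * 2 ** failures))
# 1 -> 60, 2 -> 120, 3 -> 240, 4 -> 300 (capped), 5 -> 300
```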
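Next, a minimal usage sketch for the new `ProviderCache` via the `create_provider_cache` factory. The cache name, key, and payload are hypothetical, the import path assumes the package is importable as `rotator_library`, and the snippet has to run inside an event loop because the cache schedules its disk load and background tasks with `asyncio.create_task`:

```python
import asyncio
import json
from rotator_library.providers.provider_cache import create_provider_cache

async def main() -> None:
    # Hypothetical cache; by default the file lands under the project-root cache/ directory.
    cache = create_provider_cache("demo_signatures", memory_ttl_seconds=3600, disk_ttl_seconds=86400)

    # Values are stored as strings, so structured data is JSON-encoded first.
    await cache.store_async("tool_call_123", json.dumps({"signature": "abc", "model": "some-model"}))

    # retrieve() is synchronous and memory-only; retrieve_async() also falls back to disk.
    payload = await cache.retrieve_async("tool_call_123")
    if payload is not None:
        print(json.loads(payload)["signature"])

    print(cache.get_stats())  # memory/disk hits, misses, writes, entry count
    await cache.shutdown()    # flush pending writes and stop background tasks

asyncio.run(main())
```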
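Next, an illustrative mapping for credentials loaded through the new `env://` virtual paths (`_parse_env_credential_path` / `_load_from_env` in `qwen_auth_base.py`). Token values and the provider segment of the path are placeholders; index `0` keeps the legacy unnumbered variable names, while index `N` inserts the number into the prefix (the iFlow base follows the same pattern with an `IFLOW` prefix):

```python
import os

# Placeholder tokens for illustration only.

# Legacy single credential, resolved by a path like "env://provider/0":
os.environ["QWEN_CODE_ACCESS_TOKEN"] = "legacy-access-token"
os.environ["QWEN_CODE_REFRESH_TOKEN"] = "legacy-refresh-token"

# First numbered credential, resolved by a path like "env://provider/1":
os.environ["QWEN_CODE_1_ACCESS_TOKEN"] = "access-token-1"
os.environ["QWEN_CODE_1_REFRESH_TOKEN"] = "refresh-token-1"
os.environ["QWEN_CODE_1_EMAIL"] = "user1@example.com"  # optional; defaults to "env-user-1"
# Optional per-index extras: QWEN_CODE_1_EXPIRY_DATE, QWEN_CODE_1_RESOURCE_URL
```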
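Finally, a self-contained sketch of the weighted-random rotation added to `UsageManager`. The keys and usage counts are invented; only the weight formula `(max_usage - usage) + tolerance + 1` and the use of `random.choices` mirror `_select_weighted_random`. Note that in `acquire_key` a `rotation_tolerance` of `0.0` skips weighted selection entirely and sorts candidates by usage instead:

```python
import random
from typing import List, Tuple

def select_weighted_random(candidates: List[Tuple[str, int]], tolerance: float) -> str:
    """Pick one credential; weight = (max_usage - usage) + tolerance + 1."""
    max_usage = max(usage for _, usage in candidates)
    weights = [(max_usage - usage) + tolerance + 1 for _, usage in candidates]
    return random.choices([cred for cred, _ in candidates], weights=weights, k=1)[0]

# Invented usage counts: key-A is least used, key-C most used.
candidates = [("key-A", 10), ("key-B", 12), ("key-C", 14)]

# tolerance=0.0 -> weights 5/3/1 (key-A ~56%, key-C ~11%)
# tolerance=3.0 -> weights 8/6/4 (key-A ~44%, key-C ~22%)
for tolerance in (0.0, 3.0):
    picks = [select_weighted_random(candidates, tolerance) for _ in range(10_000)]
    print(tolerance, {cred: picks.count(cred) for cred, _ in candidates})
```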