Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/cozempic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ def build_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--version", action="version", version="%(prog)s 0.8.0")
parser.add_argument("--context-window", type=int, default=None, help="Override context window size in tokens (e.g. 1000000 for 1M beta)")
parser.add_argument("--system-overhead-tokens", type=int, default=None, help="Override system overhead estimate (default: 21000). Increase for heavy rules/MCP configs.")
sub = parser.add_subparsers(dest="command")

session_help = "Session ID, UUID prefix, path, or 'current' for auto-detect"
Expand Down Expand Up @@ -710,8 +711,8 @@ def build_parser() -> argparse.ArgumentParser:
p_guard.add_argument("--threshold", type=float, default=50.0, help="Hard threshold in MB — full prune + reload (default: 50)")
p_guard.add_argument("--soft-threshold", type=float, default=None, help="Soft threshold in MB — gentle prune, no reload (default: 60%% of --threshold)")
p_guard.add_argument("--interval", type=int, default=30, help="Check interval in seconds (default: 30)")
p_guard.add_argument("--threshold-tokens", type=int, default=None, help="Hard threshold in tokens (checked alongside --threshold)")
p_guard.add_argument("--soft-threshold-tokens", type=int, default=None, help="Soft threshold in tokens (checked alongside --soft-threshold)")
p_guard.add_argument("--threshold-tokens", type=int, default=None, help="Hard threshold in tokens (default: 75%% of context window)")
p_guard.add_argument("--soft-threshold-tokens", type=int, default=None, help="Soft threshold in tokens (default: 45%% of context window)")
p_guard.add_argument("--no-reload", action="store_true", help="Prune without auto-reload at hard threshold")
p_guard.add_argument("--no-reactive", action="store_true", help="Disable reactive overflow recovery (kqueue/polling watcher)")
p_guard.add_argument("--daemon", action="store_true", help="Run in background (PID file prevents double-starts)")
Expand All @@ -736,9 +737,11 @@ def main():
parser = build_parser()
args = parser.parse_args()

# Set context window override if provided
# Set overrides via env vars (used by tokens.py)
if args.context_window:
os.environ["COZEMPIC_CONTEXT_WINDOW"] = str(args.context_window)
if args.system_overhead_tokens:
os.environ["COZEMPIC_SYSTEM_OVERHEAD_TOKENS"] = str(args.system_overhead_tokens)

if not args.command:
parser.print_help()
Expand Down
11 changes: 10 additions & 1 deletion src/cozempic/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .registry import PRESCRIPTIONS
from .session import find_claude_pid, find_current_session, find_sessions, load_messages, save_messages
from .team import TeamState, extract_team_state, inject_team_recovery, write_team_checkpoint
from .tokens import quick_token_estimate
from .tokens import default_token_thresholds, quick_token_estimate


def _resolve_session_by_id(session_id: str) -> dict | None:
Expand Down Expand Up @@ -212,6 +212,15 @@ def start_guard(

session_path = sess["path"]

# Default to token-based thresholds when none specified
if threshold_tokens is None:
from .tokens import detect_context_window
messages_for_model = load_messages(session_path)
context_window = detect_context_window(messages_for_model)
threshold_tokens, soft_threshold_tokens = default_token_thresholds(context_window)
elif soft_threshold_tokens is None:
soft_threshold_tokens = int(threshold_tokens * 0.6)

print(f"\n COZEMPIC GUARD v3")
print(f" ═══════════════════════════════════════════════════════════════════")
print(f" Session: {session_path.name}")
Expand Down
39 changes: 36 additions & 3 deletions src/cozempic/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,38 @@
# Context window size in tokens; used as the default argument of
# default_token_thresholds().
DEFAULT_CONTEXT_WINDOW = 200_000
# Baseline system overhead estimate in tokens; get_system_overhead_tokens()
# returns this when no valid override is set.
SYSTEM_OVERHEAD_TOKENS = 21_000

# Default token thresholds as fractions of context window
DEFAULT_HARD_TOKEN_PCT = 0.75  # 75% — hard prune + reload
DEFAULT_SOFT_TOKEN_PCT = 0.45  # 45% — gentle prune, no reload


def get_system_overhead_tokens() -> int:
    """Return the system overhead token estimate, honoring the env override.

    Sessions with heavy rules files, MCP servers, and tool schemas can
    have 30K-40K+ tokens of system overhead. The default (21K) is
    conservative for lightweight sessions. Override with the
    COZEMPIC_SYSTEM_OVERHEAD_TOKENS env var or --system-overhead-tokens flag.

    Returns:
        The override when the env var holds a non-negative integer;
        otherwise SYSTEM_OVERHEAD_TOKENS.
    """
    import os

    raw_override = os.environ.get("COZEMPIC_SYSTEM_OVERHEAD_TOKENS")
    if not raw_override:
        # Unset or empty: nothing to override.
        return SYSTEM_OVERHEAD_TOKENS

    try:
        override = int(raw_override)
    except ValueError:
        # Best-effort: an unparsable override is ignored, not fatal.
        return SYSTEM_OVERHEAD_TOKENS

    # A negative overhead would deflate every token estimate built on top of
    # it, so treat it like any other invalid override.
    if override < 0:
        return SYSTEM_OVERHEAD_TOKENS
    return override


def default_token_thresholds(context_window: int = DEFAULT_CONTEXT_WINDOW) -> tuple[int, int]:
    """Derive the default token thresholds for a given context window.

    Each threshold is the context window scaled by its module-level
    fraction (DEFAULT_HARD_TOKEN_PCT / DEFAULT_SOFT_TOKEN_PCT) and
    truncated to an int.

    Returns a ``(hard_threshold, soft_threshold)`` pair, in tokens.
    """
    return (
        int(context_window * DEFAULT_HARD_TOKEN_PCT),
        int(context_window * DEFAULT_SOFT_TOKEN_PCT),
    )

# Model → context window mapping
# Note: claude-opus-4-6 has 200K by default. 1M is beta-only via API header.
# Use COZEMPIC_CONTEXT_WINDOW env var or --context-window flag to override.
Expand Down Expand Up @@ -224,7 +256,7 @@ def estimate_tokens_heuristic(
breakdown[mtype] = breakdown.get(mtype, 0) + msg_chars
total_chars += msg_chars

total_tokens = int(total_chars / chars_per_token) + SYSTEM_OVERHEAD_TOKENS
total_tokens = int(total_chars / chars_per_token) + get_system_overhead_tokens()

# Convert char breakdown to token breakdown
token_breakdown = {
Expand Down Expand Up @@ -339,7 +371,8 @@ def calibrate_ratio(messages: list[Message]) -> float | None:
return None

exact_tokens = usage["total"]
if exact_tokens <= SYSTEM_OVERHEAD_TOKENS:
overhead = get_system_overhead_tokens()
if exact_tokens <= overhead:
return None

# Count content chars (same way as heuristic)
Expand All @@ -357,7 +390,7 @@ def calibrate_ratio(messages: list[Message]) -> float | None:
if isinstance(content, str):
total_chars += len(content)

content_tokens = exact_tokens - SYSTEM_OVERHEAD_TOKENS
content_tokens = exact_tokens - overhead
if content_tokens <= 0:
return None

Expand Down