From 019e939bed2fbe1eaa8631539c28063654c8689d Mon Sep 17 00:00:00 2001
From: THUAUD Simon
Date: Mon, 2 Feb 2026 23:45:00 +0100
Subject: [PATCH 1/2] feat(providers): add Ollama and OpenAI provider support

---
 example.py                           |  24 +-
 pyproject.toml                       |   1 +
 src/harvestor/__init__.py            |  28 +++
 src/harvestor/cli/main.py            |  79 ++++++-
 src/harvestor/config.py              |  47 ++--
 src/harvestor/core/cost_tracker.py   |  33 ++-
 src/harvestor/core/harvestor.py      | 330 +++++----------------
 src/harvestor/parsers/llm_parser.py  | 222 ++++++++++--------
 src/harvestor/providers/__init__.py  |  89 ++++++++
 src/harvestor/providers/anthropic.py | 184 +++++++++++++++
 src/harvestor/providers/base.py      | 114 +++++++++
 src/harvestor/providers/ollama.py    | 218 ++++++++++++++++++
 src/harvestor/providers/openai.py    | 177 ++++++++++++++
 src/harvestor/schemas/base.py        |  19 +-
 tests/conftest.py                    |   1 +
 tests/test_cost_tracker.py           |  83 ++++---
 tests/test_harvestor.py              |  46 ++--
 tests/test_input_types.py            |  30 +--
 uv.lock                              |   2 +
 19 files changed, 1201 insertions(+), 526 deletions(-)
 create mode 100644 src/harvestor/providers/__init__.py
 create mode 100644 src/harvestor/providers/anthropic.py
 create mode 100644 src/harvestor/providers/base.py
 create mode 100644 src/harvestor/providers/ollama.py
 create mode 100644 src/harvestor/providers/openai.py

diff --git a/example.py b/example.py
index 657dc1b..069511d 100644
--- a/example.py
+++ b/example.py
@@ -1,16 +1,16 @@
 from typing import Optional
+
 from dotenv import load_dotenv
 from pydantic import BaseModel, Field

-from harvestor import Harvestor  # , harvest
-import os
+from harvestor import Harvestor, list_models

 load_dotenv()


-class SimpleInoviceModelSchema(BaseModel):
+class SimpleInvoiceSchema(BaseModel):
     """
-    Implement the schema you want as output. Customise for each document types.
+    Implement the schema you want as output. Customize for each document type.
""" vendor: Optional[str] = Field(None, description="The vendor name") @@ -20,16 +20,22 @@ class SimpleInoviceModelSchema(BaseModel): customer_lastname: Optional[str] = Field(None, description="The customer lastname") -ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY") +# List available models +print("Available models:", list(list_models().keys())) -h = Harvestor(api_key=ANTHROPIC_API_KEY, model="Claude Haiku 3") +# Use default model (claude-haiku) +h = Harvestor(model="claude-haiku") output = h.harvest_file( - source="data/uploads/keep_for_test.jpg", schema=SimpleInoviceModelSchema + source="data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema ) print(output.to_summary()) -# output_2 = harvest("data/uploads/keep_for_test.jpg", schema=SimpleInoviceModelSchema) +# Alternative: use OpenAI +# h_openai = Harvestor(model="gpt-4o-mini") +# output = h_openai.harvest_file("invoice.jpg", schema=SimpleInvoiceSchema) -# print(output_2.to_summary()) +# Alternative: use local Ollama (free) +# h_ollama = Harvestor(model="llava") +# output = h_ollama.harvest_file("invoice.jpg", schema=SimpleInvoiceSchema) diff --git a/pyproject.toml b/pyproject.toml index 93f8657..d72ebe4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "langchain-openai>=0.0.5", "anthropic>=0.18.0", "openai>=1.10.0", + "httpx>=0.27.0", # For Ollama provider # Document Processing "PyMuPDF>=1.23.0", diff --git a/src/harvestor/__init__.py b/src/harvestor/__init__.py index a019171..f6b5668 100644 --- a/src/harvestor/__init__.py +++ b/src/harvestor/__init__.py @@ -2,6 +2,7 @@ Harvestor - Harvest intelligence from any document Extract structured data from any document with AI-powered extraction. +Supports multiple LLM providers: Anthropic, OpenAI, and Ollama. 
""" __version__ = "0.1.0" @@ -14,6 +15,20 @@ from .config import SUPPORTED_MODELS from .core.cost_tracker import cost_tracker from .core.harvestor import Harvestor, harvest +from .providers import ( + DEFAULT_MODEL, + MODELS, + PROVIDERS, + AnthropicProvider, + BaseLLMProvider, + CompletionResult, + ModelInfo, + OllamaProvider, + OpenAIProvider, + get_provider, + list_models, + list_providers, +) from .schemas.base import ( ExtractionResult, ExtractionStrategy, @@ -39,4 +54,17 @@ "LineItem", # Config "SUPPORTED_MODELS", + "MODELS", + "DEFAULT_MODEL", + # Providers + "PROVIDERS", + "BaseLLMProvider", + "CompletionResult", + "ModelInfo", + "AnthropicProvider", + "OpenAIProvider", + "OllamaProvider", + "get_provider", + "list_models", + "list_providers", ] diff --git a/src/harvestor/cli/main.py b/src/harvestor/cli/main.py index 68a8f13..a4302f9 100644 --- a/src/harvestor/cli/main.py +++ b/src/harvestor/cli/main.py @@ -7,7 +7,7 @@ import sys from pathlib import Path -from harvestor import harvest +from harvestor import DEFAULT_MODEL, harvest, list_models from harvestor.schemas.defaults import InvoiceData, ReceiptData @@ -20,17 +20,19 @@ def build_parser(): parser.add_argument( "file_path", type=Path, + nargs="?", help="Path to the document to process", ) parser.add_argument( "schema", + nargs="?", help="Schema to use (e.g., InvoiceData, ReceiptData)", ) parser.add_argument( "-m", "--model", - default="Claude Haiku 3", - help="Model to use (default: Claude Haiku 3)", + default=DEFAULT_MODEL, + help=f"Model to use (default: {DEFAULT_MODEL})", ) parser.add_argument( "-o", @@ -43,13 +45,22 @@ def build_parser(): action="store_true", help="Pretty print JSON output", ) + parser.add_argument( + "--list-models", + action="store_true", + help="List available models and exit", + ) + parser.add_argument( + "--list-schemas", + action="store_true", + help="List available schemas and exit", + ) return parser def get_schema(schema_name: str): """Resolve schema name to actual schema class.""" - schemas = { "InvoiceData": InvoiceData, "ReceiptData": ReceiptData, @@ -62,10 +73,70 @@ def get_schema(schema_name: str): return schemas[schema_name] +def print_models(): + """Print available models grouped by provider.""" + models = list_models() + + providers = {} + for name, info in models.items(): + provider = info.get("provider", "unknown") + if provider not in providers: + providers[provider] = [] + providers[provider].append((name, info)) + + print("\nAvailable models:") + print("=" * 50) + + for provider, model_list in sorted(providers.items()): + print(f"\n{provider.upper()}:") + for name, info in sorted(model_list): + vision = " (vision)" if info.get("supports_vision") else "" + cost = info.get("input_cost", 0) + if cost == 0: + cost_str = "free" + else: + cost_str = f"${cost}/M tokens" + print(f" {name:<20} {cost_str}{vision}") + + print(f"\nDefault: {DEFAULT_MODEL}") + print() + + +def print_schemas(): + """Print available schemas.""" + schemas = { + "InvoiceData": InvoiceData, + "ReceiptData": ReceiptData, + } + + print("\nAvailable schemas:") + print("=" * 50) + + for name, schema in schemas.items(): + doc = schema.__doc__ or "No description" + print(f" {name}: {doc.strip().split(chr(10))[0]}") + + print() + + def main(): parser = build_parser() args = parser.parse_args() + if args.list_models: + print_models() + sys.exit(0) + + if args.list_schemas: + print_schemas() + sys.exit(0) + + if not args.file_path: + parser.error("file_path is required") + + if not args.schema: + parser.error("schema is 
required") + if not args.file_path.exists(): print(f"Error: File not found: {args.file_path}", file=sys.stderr) sys.exit(1) diff --git a/src/harvestor/config.py b/src/harvestor/config.py index 3d67ec2..d84cea4 100644 --- a/src/harvestor/config.py +++ b/src/harvestor/config.py @@ -1,28 +1,19 @@ -SUPPORTED_MODELS = { - # Anthropic Claude - "Claude Haiku 3": {"id": "claude-3-haiku-20240307", "input": 0.25, "output": 1.25}, - "Claude Haiku 4.5": { - "id": "claude-haiku-4-5-20251001", - "input": 1.0, - "output": 5.0, - }, - "Claude Sonnet 3.7": { - "id": "claude-3-7-sonnet-20250219", - "input": 3.0, - "output": 15.0, - }, - "Claude Sonnet 4.5": { - "id": "claude-sonnet-4-5-20250929", - "input": 3.0, - "output": 15.0, - }, - "Claude Opus 4,5": { - "id": "claude-opus-4-5-20251101", - "input": 5.0, - "output": 25.0, - }, # very good stuff - # OpenAI TODO: check OpenAI models - # "gpt-3.5-turbo": {"input": 0.50, "output": 1.50}, - "gpt-4": {"input": 30.0, "output": 60.0}, - # "gpt-4-turbo": {"input": 10.0, "output": 30.0}, -} +""" +Configuration for Harvestor. + +Model definitions are now managed in the providers module. +This file re-exports them for backwards compatibility. +""" + +from .providers import DEFAULT_MODEL, MODELS, list_models, list_providers + +# Backwards compatibility alias +SUPPORTED_MODELS = MODELS + +__all__ = [ + "MODELS", + "SUPPORTED_MODELS", + "DEFAULT_MODEL", + "list_models", + "list_providers", +] diff --git a/src/harvestor/core/cost_tracker.py b/src/harvestor/core/cost_tracker.py index f37b22c..c2555e9 100644 --- a/src/harvestor/core/cost_tracker.py +++ b/src/harvestor/core/cost_tracker.py @@ -12,7 +12,6 @@ from typing import Dict, List, Optional from ..schemas.base import CostReport, ExtractionStrategy -from ..config import SUPPORTED_MODELS @dataclass @@ -107,14 +106,15 @@ def calculate_cost( self, model: str, input_tokens: int, output_tokens: int ) -> float: """Calculate cost for a given API call.""" - if model not in SUPPORTED_MODELS: - # Unknown model, use conservative estimate (GPT-4 pricing) - raise ModelNotSupported(f"Model {model} is not supported.") - else: - pricing = SUPPORTED_MODELS[model] + from ..providers import MODELS - input_cost = (input_tokens / 1_000_000) * pricing["input"] - output_cost = (output_tokens / 1_000_000) * pricing["output"] + if model not in MODELS: + # Unknown model (possibly Ollama custom), assume free + return 0.0 + + pricing = MODELS[model] + input_cost = (input_tokens / 1_000_000) * pricing["input_cost"] + output_cost = (output_tokens / 1_000_000) * pricing["output_cost"] return input_cost + output_cost @@ -263,19 +263,12 @@ def generate_report(self, days: int = 7) -> CostReport: # Calculate costs total_cost = sum(c.cost for c in recent_calls) - free_successes = 0 - llm_calls = len( - [ - c - for c in recent_calls - if c.strategy - in { - ExtractionStrategy.LLM_HAIKU, - ExtractionStrategy.LLM_SONNET, - ExtractionStrategy.LLM_GPT35, - } - ] + free_successes = sum( + 1 + for c in recent_calls + if c.strategy == ExtractionStrategy.LLM_OLLAMA and c.success ) + llm_calls = len(recent_calls) # Cost by strategy cost_by_strategy: Dict[str, float] = {} diff --git a/src/harvestor/core/harvestor.py b/src/harvestor/core/harvestor.py index 7d482c5..7e8eda6 100644 --- a/src/harvestor/core/harvestor.py +++ b/src/harvestor/core/harvestor.py @@ -4,24 +4,19 @@ This is the primary public API for Harvestor. 
""" -import base64 import io -import json -import os import re import time from datetime import datetime from pathlib import Path from typing import BinaryIO, List, Optional, Type, Union -from anthropic import Anthropic from pydantic import BaseModel -from ..config import SUPPORTED_MODELS from ..core.cost_tracker import cost_tracker from ..parsers.llm_parser import LLMParser -from ..schemas.base import ExtractionResult, ExtractionStrategy, HarvestResult -from ..schemas.prompt_builder import PromptBuilder +from ..providers import DEFAULT_MODEL +from ..schemas.base import HarvestResult class Harvestor: @@ -30,8 +25,8 @@ class Harvestor: Features: - Extract structured data from documents - - Multiple extraction strategies (LLM) - - Cost optimization (LLM fallback for now) + - Multi-provider support (Anthropic, OpenAI, Ollama) + - Cost optimization - Batch processing support - Progress tracking and reporting """ @@ -39,41 +34,32 @@ class Harvestor: def __init__( self, api_key: Optional[str] = None, - model: str = "Claude Haiku 3", + model: str = DEFAULT_MODEL, cost_limit_per_doc: float = 0.10, daily_cost_limit: Optional[float] = None, + base_url: Optional[str] = None, ): """ Initialize Harvestor. Args: - api_key: Anthropic API key (uses ANTHROPIC_API_KEY env var if not provided) - model: LLM model to use (default: Claude Haiku for cost optimization) + api_key: API key (uses env var if not provided, not needed for Ollama) + model: Model to use (e.g., 'claude-haiku', 'gpt-4o-mini', 'llama3') cost_limit_per_doc: Maximum cost per document (default: $0.10) daily_cost_limit: Optional daily cost limit + base_url: Optional base URL override for the provider """ - # Get API key - self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY") - if not self.api_key: - raise ValueError( - "Anthropic API key required. Set ANTHROPIC_API_KEY env var or pass api_key parameter." - ) - - # Resolve model name to API model ID - if model not in SUPPORTED_MODELS: - raise ValueError( - f"Unsupported model: {model}. Supported models: {list(SUPPORTED_MODELS.keys())}" - ) - self.model_name = model # Friendly name for cost tracking - self.model = SUPPORTED_MODELS[model]["id"] # API model ID + self.model_name = model + self.api_key = api_key + self.base_url = base_url # Set cost limits cost_tracker.set_limits( daily_limit=daily_cost_limit, per_document_limit=cost_limit_per_doc ) - # Initialize LLM parser - self.llm_parser = LLMParser(model=model, api_key=self.api_key) + # Initialize LLM parser (handles provider selection) + self.llm_parser = LLMParser(model=model, api_key=api_key, base_url=base_url) @staticmethod def get_doc_type_from_schema(schema: Type[BaseModel]) -> str: @@ -93,7 +79,6 @@ def get_doc_type_from_schema(schema: Type[BaseModel]) -> str: break # Convert CamelCase to snake_case - # "IDDocument" -> "id_document" name = re.sub(r"(?>> # From file path (str or Path) - >>> result = harvestor.harvest_file("invoice.jpg", schema) - - >>> # From bytes - >>> with open("invoice.jpg", "rb") as f: - ... 
data = f.read() - >>> result = harvestor.harvest_file(data, schema, filename="invoice.jpg") - - >>> # From file-like object - >>> from io import BytesIO - >>> buffer = BytesIO(image_data) - >>> result = harvestor.harvest_file(buffer, schema, filename="invoice.jpg") """ start_time = time.time() @@ -208,7 +176,7 @@ def harvest_file( file_size: Optional[int] = None inferred_filename: Optional[str] = None - # Use provided doc_type or resolved it from gave schema + # Use provided doc_type or derive from schema doc_type = doc_type or self.get_doc_type_from_schema(schema) if isinstance(source, (str, Path)): @@ -231,12 +199,10 @@ def harvest_file( file_size = file_path.stat().st_size document_id = document_id or file_path.stem - # Read file content for image processing with open(file_path, "rb") as f: file_bytes = f.read() elif isinstance(source, bytes): - # Bytes input file_bytes = source file_size = len(source) inferred_filename = ( @@ -245,11 +211,9 @@ def harvest_file( document_id = document_id or Path(inferred_filename).stem elif hasattr(source, "read"): - # File-like object (BinaryIO) file_bytes = source.read() file_size = len(file_bytes) - # Try to get filename from file object if hasattr(source, "name"): inferred_filename = Path(source.name).name else: @@ -269,16 +233,11 @@ def harvest_file( total_time=time.time() - start_time, ) - # Use provided filename or inferred one final_filename = filename or inferred_filename - - # Determine file type from filename file_extension = Path(final_filename).suffix.lower() try: - # Route to appropriate extraction method based on file type if file_extension in [".jpg", ".jpeg", ".png", ".gif", ".webp"]: - # Image file - use vision API result = self._harvest_image( image_bytes=file_bytes, schema=schema, @@ -288,7 +247,6 @@ def harvest_file( filename=final_filename, ) elif file_extension in [".txt", ".pdf"]: - # Text-based file - extract text first text = self._extract_text_from_bytes(file_bytes, file_extension) result = self.harvest_text( text=text, @@ -309,7 +267,6 @@ def harvest_file( total_time=time.time() - start_time, ) - # Add file metadata to result if file_path_str: result.file_path = file_path_str result.file_size_bytes = file_size @@ -328,71 +285,12 @@ def harvest_file( total_time=time.time() - start_time, ) - def _extract_text_from_file(self, file_path: Path) -> str: - """ - Extract text from file based on type. - - Args: - file_path: Path to file - - Returns: - Extracted text - - Raises: - ValueError: If file type is not supported - """ - suffix = file_path.suffix.lower() - - if suffix == ".txt": - # Plain text file - with open(file_path, "r", encoding="utf-8") as f: - return f.read() - - elif suffix == ".pdf": - # PDF file - use pdfplumber for native text extraction - try: - import pdfplumber - - text_parts = [] - with pdfplumber.open(file_path) as pdf: - for page in pdf.pages: - page_text = page.extract_text() - if page_text: - text_parts.append(page_text) - - if not text_parts: - raise ValueError("No text found in PDF (might need OCR)") - - return "\n\n".join(text_parts) - - except ImportError: - raise ValueError( - "pdfplumber not installed. Install with: pip install pdfplumber" - ) - - else: - raise ValueError(f"Unsupported file type: {suffix}. Supported: .txt, .pdf") - def _extract_text_from_bytes(self, file_bytes: bytes, file_extension: str) -> str: - """ - Extract text from bytes based on file type. 
- - Args: - file_bytes: Raw file content - file_extension: File extension (e.g., '.txt', '.pdf') - - Returns: - Extracted text - - Raises: - ValueError: If file type is not supported - """ + """Extract text from bytes based on file type.""" if file_extension == ".txt": - # Plain text file return file_bytes.decode("utf-8") elif file_extension == ".pdf": - # PDF file - use pdfplumber try: import pdfplumber @@ -427,23 +325,10 @@ def _harvest_image( language: str = "en", filename: Optional[str] = None, ) -> HarvestResult: - """ - Extract structured data from an image using Claude's vision API. - - Args: - image_bytes: Raw image data - schema: Pydantic model defining the output structure - doc_type: Document type - document_id: Document identifier - language: Document language - filename: Original filename for determining image type - - Returns: - HarvestResult with extracted data - """ + """Extract structured data from an image using vision API.""" start_time = time.time() - # Determine image media type from filename + # Determine media type from filename if filename: extension = Path(filename).suffix.lower().replace(".", "") if extension == "jpg": @@ -451,137 +336,33 @@ def _harvest_image( else: media_type = f"image/{extension}" else: - # Default to jpeg media_type = "image/jpeg" - # Encode image to base64 - image_b64 = base64.standard_b64encode(image_bytes).decode("utf-8") - - # Create extraction prompt from schema - builder = PromptBuilder(schema) - prompt = builder.build_vision_prompt(doc_type) - - # Initialize Anthropic client - client = Anthropic(api_key=self.api_key) - - # Call vision API - response = client.messages.create( - model=self.model, - max_tokens=2048, - temperature=0.0, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image_b64, - }, - }, - {"type": "text", "text": prompt}, - ], - } - ], - ) - - # Extract token usage - input_tokens = response.usage.input_tokens - output_tokens = response.usage.output_tokens - - # Determine strategy based on model - if "haiku" in self.model.lower(): - strategy = ExtractionStrategy.LLM_HAIKU - elif "sonnet" in self.model.lower(): - strategy = ExtractionStrategy.LLM_SONNET - else: - strategy = ExtractionStrategy.LLM_HAIKU - - # Track cost - cost = cost_tracker.track_call( - model=self.model_name, - strategy=strategy, - input_tokens=input_tokens, - output_tokens=output_tokens, + # Use LLMParser's vision extraction + extraction_result = self.llm_parser.extract_vision( + image_data=image_bytes, + schema=schema, + doc_type=doc_type, document_id=document_id, - success=True, + media_type=media_type, ) - # Parse response - response_text = response.content[0].text processing_time = time.time() - start_time - try: - # Extract JSON from response - json_start = response_text.find("{") - json_end = response_text.rfind("}") + 1 - - if json_start >= 0 and json_end > json_start: - json_str = response_text[json_start:json_end] - data = json.loads(json_str) - else: - data = json.loads(response_text) - - # Validate against schema - validated_data = schema(**data) - data = validated_data.model_dump() - - # Create extraction result - extraction_result = ExtractionResult( - success=True, - data=data, - raw_text=response_text[:500], - strategy=strategy, - confidence=0.85, - processing_time=processing_time, - cost=cost, - tokens_used=input_tokens + output_tokens, - metadata={ - "model": self.model, - "media_type": media_type, - "vision_api": True, - }, - ) - - # Build 
harvest result - return HarvestResult( - success=True, - document_id=document_id, - document_type=doc_type, - data=data, - extraction_results=[extraction_result], - final_strategy=strategy, - final_confidence=0.85, - total_cost=cost, - cost_breakdown={strategy.value: cost}, - total_time=processing_time, - language=language, - ) - - except json.JSONDecodeError as e: - return HarvestResult( - success=False, - document_id=document_id, - document_type=doc_type, - data={}, - error=f"Failed to parse JSON response: {str(e)}", - total_cost=cost, - total_time=processing_time, - language=language, - ) - except Exception as e: - return HarvestResult( - success=False, - document_id=document_id, - document_type=doc_type, - data={}, - error=f"Vision API extraction failed: {str(e)}", - total_cost=cost, - total_time=processing_time, - language=language, - ) + return HarvestResult( + success=extraction_result.success, + document_id=document_id, + document_type=doc_type, + data=extraction_result.data, + extraction_results=[extraction_result], + final_strategy=extraction_result.strategy, + final_confidence=extraction_result.confidence, + total_cost=extraction_result.cost, + cost_breakdown={extraction_result.strategy.value: extraction_result.cost}, + total_time=processing_time, + error=extraction_result.error, + language=language, + ) def harvest_batch( self, @@ -596,7 +377,7 @@ def harvest_batch( Args: files: List of file paths to process schema: Pydantic model defining the output structure - doc_type: Document type for all files (derived from schema if not provided) + doc_type: Document type for all files show_progress: Show progress bar Returns: @@ -632,36 +413,30 @@ def harvest( schema: Type[BaseModel], doc_type: Optional[str] = None, language: str = "en", - model: str = "Claude Haiku 3", + model: str = DEFAULT_MODEL, api_key: Optional[str] = None, filename: Optional[str] = None, + base_url: Optional[str] = None, ) -> HarvestResult: """ One-liner function for quick extraction. - Accepts file paths, bytes, or file-like objects for maximum flexibility. + Accepts file paths, bytes, or file-like objects. 
Examples: ```python from harvestor import harvest from harvestor.schemas import InvoiceData - # From file path + # From file path with default model (claude-haiku) result = harvest("invoice.pdf", schema=InvoiceData) print(f"Invoice #: {result.data.get('invoice_number')}") - print(f"Total: ${result.data.get('total_amount')}") - print(f"Cost: ${result.total_cost:.4f}") - - # From bytes with custom schema - from pydantic import BaseModel, Field - class ContractData(BaseModel): - parties: list[str] = Field(description="Contract parties") - value: float | None = Field(None, description="Contract value") + # With OpenAI + result = harvest("invoice.jpg", schema=InvoiceData, model="gpt-4o-mini") - with open("contract.pdf", "rb") as f: - data = f.read() - result = harvest(data, schema=ContractData, filename="contract.pdf") + # With local Ollama + result = harvest("invoice.txt", schema=InvoiceData, model="llama3") ``` Args: @@ -669,14 +444,15 @@ class ContractData(BaseModel): schema: Pydantic model defining the output structure doc_type: Document type (derived from schema name if not provided) language: Document language - model: LLM model to use + model: Model to use (default: claude-haiku) api_key: API key (uses env var if not provided) filename: Original filename (required when source is bytes/file-like) + base_url: Optional base URL override Returns: HarvestResult with extracted data """ - harvestor = Harvestor(api_key=api_key, model=model) + harvestor = Harvestor(api_key=api_key, model=model, base_url=base_url) return harvestor.harvest_file( source=source, schema=schema, diff --git a/src/harvestor/parsers/llm_parser.py b/src/harvestor/parsers/llm_parser.py index b2180f6..2a87680 100644 --- a/src/harvestor/parsers/llm_parser.py +++ b/src/harvestor/parsers/llm_parser.py @@ -1,22 +1,18 @@ """ -LLM-based document parser using LangChain and Anthropic. +LLM-based document parser using provider abstraction. -Uses Claude Haiku (or other models) for extracting structured data from text. - -Haiku is the cheapest :) +Supports multiple LLM providers (Anthropic, OpenAI, Ollama) for extracting +structured data from text. """ import json import time from typing import Any, Dict, Optional, Type -from anthropic import Anthropic -from langchain_core.prompts import PromptTemplate -from langchain_anthropic import ChatAnthropic from pydantic import BaseModel, ValidationError -from ..config import SUPPORTED_MODELS from ..core.cost_tracker import cost_tracker +from ..providers import DEFAULT_MODEL, BaseLLMProvider, get_provider from ..schemas.base import ExtractionResult, ExtractionStrategy from ..schemas.prompt_builder import PromptBuilder @@ -26,8 +22,8 @@ class LLMParser: LLM-based parser for extracting structured data from text. Features: - - Uses Claude Haiku by default (cheapest) - - Structured output with Pydantic validation based on personnal wish + - Multi-provider support (Anthropic, OpenAI, Ollama) + - Structured output with Pydantic validation - Automatic retry on validation errors - Cost tracking integration - Smart truncation for long documents @@ -35,49 +31,47 @@ class LLMParser: def __init__( self, - model: str = "Claude Haiku 3", + model: str = DEFAULT_MODEL, api_key: Optional[str] = None, max_retries: int = 3, max_input_chars: int = 8000, + base_url: Optional[str] = None, ): """ Initialize LLM parser. 
Args: - model: Model to use (default: Claude Haiku) - api_key: Anthropic API key (uses env var if not provided) + model: Model to use (e.g., 'claude-haiku', 'gpt-4o-mini', 'llama3') + api_key: API key (uses env var if not provided, not needed for Ollama) max_retries: Maximum retry attempts for failed extractions max_input_chars: Maximum characters to send to LLM + base_url: Optional base URL override """ - # Resolve model name to API model ID - if model not in SUPPORTED_MODELS: - raise ValueError( - f"Unsupported model: {model}. Supported models: {list(SUPPORTED_MODELS.keys())}" - ) - self.model_name = model # Friendly name for cost tracking - self.model = SUPPORTED_MODELS[model]["id"] # API model ID + self.model_name = model self.max_retries = max_retries self.max_input_chars = max_input_chars - # Initialize LangChain LLM - self.llm = ChatAnthropic( - model=self.model, - anthropic_api_key=api_key, - temperature=0.0, # Deterministic for data extraction do not hallucanite + # Get provider for this model + self.provider: BaseLLMProvider = get_provider( + model=model, api_key=api_key, base_url=base_url ) - # Initialize Anthropic client for direct API access - self.anthropic_client = Anthropic(api_key=api_key) + # Get model info for cost tracking + self.model_info = self.provider.get_model_info() + + # Determine strategy based on provider + self.strategy = self._get_strategy() - # Determine strategy based on model - if "haiku" in model.lower(): - self.strategy = ExtractionStrategy.LLM_HAIKU - elif "sonnet" in model.lower(): - self.strategy = ExtractionStrategy.LLM_SONNET - elif "gpt" in model.lower(): - self.strategy = ExtractionStrategy.LLM_GPT35 - else: - self.strategy = ExtractionStrategy.LLM_HAIKU + def _get_strategy(self) -> ExtractionStrategy: + """Determine extraction strategy based on provider.""" + provider_name = self.model_info.provider + if provider_name == "anthropic": + return ExtractionStrategy.LLM_ANTHROPIC + elif provider_name == "openai": + return ExtractionStrategy.LLM_OPENAI + elif provider_name == "ollama": + return ExtractionStrategy.LLM_OLLAMA + return ExtractionStrategy.LLM_ANTHROPIC def truncate_text(self, text: str, max_chars: Optional[int] = None) -> str: """ @@ -148,9 +142,11 @@ def extract( ExtractionResult with extracted data """ start_time = time.time() + original_length = len(text) # Truncate if needed text = self.truncate_text(text) + was_truncated = len(text) < original_length # Create prompt from schema prompt = self.create_prompt(text, doc_type, schema) @@ -158,7 +154,7 @@ def extract( # Try extraction with retries for attempt in range(self.max_retries): try: - result = self._extract_with_anthropic( + result = self._extract_with_provider( prompt=prompt, schema=schema, document_id=document_id ) @@ -167,25 +163,24 @@ def extract( return ExtractionResult( success=True, data=result["data"], - raw_text=text[:500], # Store first 500 chars + raw_text=text[:500], strategy=self.strategy, confidence=result.get("confidence", 0.85), processing_time=processing_time, cost=result["cost"], tokens_used=result["tokens"], metadata={ - "model": self.model, + "model": self.model_info.model_id, + "provider": self.model_info.provider, "attempt": attempt + 1, - "truncated": len(text) < len(text), + "truncated": was_truncated, }, ) except ValidationError as e: if attempt < self.max_retries - 1: - # Retry with error feedback continue else: - # Final attempt failed processing_time = time.time() - start_time return ExtractionResult( success=False, @@ -207,11 +202,21 @@ def extract( 
error=f"Extraction failed: {str(e)}", ) - def _extract_with_anthropic( + # Should not reach here, but handle edge case + return ExtractionResult( + success=False, + data={}, + strategy=self.strategy, + confidence=0.0, + processing_time=time.time() - start_time, + error="Extraction failed: max retries exceeded", + ) + + def _extract_with_provider( self, prompt: str, schema: Type[BaseModel], document_id: Optional[str] = None ) -> Dict[str, Any]: """ - Extract using Anthropic API with structured output. + Extract using the configured provider. Args: prompt: Prompt text @@ -221,33 +226,31 @@ def _extract_with_anthropic( Returns: Dict with data, cost, and tokens """ - # Call Anthropic API - response = self.anthropic_client.messages.create( - model=self.model, + # Call provider + result = self.provider.complete( + prompt=prompt, max_tokens=2048, temperature=0.0, - messages=[{"role": "user", "content": prompt}], ) - # Extract tokens and calculate cost - input_tokens = response.usage.input_tokens - output_tokens = response.usage.output_tokens + if not result.success: + raise RuntimeError(result.error or "Provider returned unsuccessful result") + # Track cost cost = cost_tracker.track_call( model=self.model_name, strategy=self.strategy, - input_tokens=input_tokens, - output_tokens=output_tokens, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, document_id=document_id, success=True, ) # Parse response - response_text = response.content[0].text + response_text = result.content # Try to extract JSON from response try: - # Find JSON in response (might have explanatory text) json_start = response_text.find("{") json_end = response_text.rfind("}") + 1 @@ -255,7 +258,6 @@ def _extract_with_anthropic( json_str = response_text[json_start:json_end] data = json.loads(json_str) else: - # Try parsing entire response as JSON data = json.loads(response_text) # Validate against schema @@ -264,87 +266,119 @@ def _extract_with_anthropic( return { "data": validated_data.model_dump(), "cost": cost, - "tokens": input_tokens + output_tokens, - "confidence": 0.85, # Default confidence for LLM extraction + "tokens": result.total_tokens, + "confidence": 0.85, } except json.JSONDecodeError as e: raise ValidationError(f"Failed to parse JSON: {str(e)}") - def extract_with_langchain( + def extract_vision( self, - text: str, + image_data: bytes, schema: Type[BaseModel], doc_type: str = "document", document_id: Optional[str] = None, + media_type: str = "image/jpeg", ) -> ExtractionResult: """ - Alternative extraction using LangChain chains. - - This is simpler but doesn't use structured output. - Good for experimentation. + Extract structured data from an image using vision API. - Might use this for experimentation with agentic AI @Koweez. 
Args: - text: Text to extract from + image_data: Raw image bytes schema: Pydantic model for structured output doc_type: Document type - document_id: Optional document ID + document_id: Optional document ID for cost tracking + media_type: Image MIME type Returns: - ExtractionResult + ExtractionResult with extracted data """ start_time = time.time() - # Truncate if needed - text = self.truncate_text(text) + if not self.provider.supports_vision(): + return ExtractionResult( + success=False, + data={}, + strategy=self.strategy, + confidence=0.0, + processing_time=time.time() - start_time, + error=f"Model {self.model_name} does not support vision", + ) - # Create LangChain prompt template using schema + # Create vision prompt builder = PromptBuilder(schema) - prompt_template = PromptTemplate( - input_variables=["text"], - template=builder.build_text_prompt("{text}", doc_type), - ) + prompt = builder.build_vision_prompt(doc_type) try: - # Use LangChain LLM - result = self.llm.predict(prompt_template.format(text=text)) + result = self.provider.complete_vision( + prompt=prompt, + image_data=image_data, + media_type=media_type, + max_tokens=2048, + temperature=0.0, + ) + + if not result.success: + raise RuntimeError(result.error or "Vision API failed") + + # Track cost + cost = cost_tracker.track_call( + model=self.model_name, + strategy=self.strategy, + input_tokens=result.input_tokens, + output_tokens=result.output_tokens, + document_id=document_id, + success=True, + ) # Parse JSON response - json_start = result.find("{") - json_end = result.rfind("}") + 1 + response_text = result.content + json_start = response_text.find("{") + json_end = response_text.rfind("}") + 1 if json_start >= 0 and json_end > json_start: - json_str = result[json_start:json_end] + json_str = response_text[json_start:json_end] data = json.loads(json_str) else: - data = json.loads(result) + data = json.loads(response_text) - processing_time = time.time() - start_time + validated_data = schema(**data) - # Note: We don't have token counts with LangChain predict - # This is a limitation of using the simplified API - estimated_cost = 0.02 # Rough estimate + processing_time = time.time() - start_time return ExtractionResult( success=True, - data=data, - raw_text=text[:500], + data=validated_data.model_dump(), + raw_text=response_text[:500], strategy=self.strategy, - confidence=0.80, + confidence=0.85, processing_time=processing_time, - cost=estimated_cost, - tokens_used=0, # Unknown with LangChain - warnings=["Using LangChain predict - token counts unavailable"], + cost=cost, + tokens_used=result.total_tokens, + metadata={ + "model": self.model_info.model_id, + "provider": self.model_info.provider, + "vision": True, + "media_type": media_type, + }, ) + except json.JSONDecodeError as e: + return ExtractionResult( + success=False, + data={}, + strategy=self.strategy, + confidence=0.0, + processing_time=time.time() - start_time, + error=f"Failed to parse JSON response: {str(e)}", + ) except Exception as e: - processing_time = time.time() - start_time return ExtractionResult( success=False, data={}, strategy=self.strategy, confidence=0.0, - processing_time=processing_time, - error=f"LangChain extraction failed: {str(e)}", + processing_time=time.time() - start_time, + error=f"Vision extraction failed: {str(e)}", ) diff --git a/src/harvestor/providers/__init__.py b/src/harvestor/providers/__init__.py new file mode 100644 index 0000000..02114eb --- /dev/null +++ b/src/harvestor/providers/__init__.py @@ -0,0 +1,89 @@ +""" +LLM Provider 
implementations. + +Supported providers: +- Anthropic (Claude models) +- OpenAI (GPT models) +- Ollama (local models) +""" + +from typing import Optional, Type + +from .anthropic import ANTHROPIC_MODELS, AnthropicProvider +from .base import BaseLLMProvider, CompletionResult, ModelInfo +from .ollama import OLLAMA_MODELS, OllamaProvider +from .openai import OPENAI_MODELS, OpenAIProvider + +# Combine all models into a single registry +MODELS = { + **{k: {**v, "provider": "anthropic"} for k, v in ANTHROPIC_MODELS.items()}, + **{k: {**v, "provider": "openai"} for k, v in OPENAI_MODELS.items()}, + **{k: {**v, "provider": "ollama"} for k, v in OLLAMA_MODELS.items()}, +} + +PROVIDERS: dict[str, Type[BaseLLMProvider]] = { + "anthropic": AnthropicProvider, + "openai": OpenAIProvider, + "ollama": OllamaProvider, +} + +DEFAULT_MODEL = "claude-haiku" + + +def get_provider( + model: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, +) -> BaseLLMProvider: + """ + Get the appropriate provider for a model. + + Args: + model: Model name (e.g., 'claude-haiku', 'gpt-4o-mini', 'llama3') + api_key: API key (not needed for Ollama) + base_url: Optional base URL override + + Returns: + Initialized provider instance + + Raises: + ValueError: If model is not recognized + """ + if model not in MODELS: + # Check if it might be an Ollama model (allows custom local models) + if ":" in model or model.startswith("llama") or model.startswith("mistral"): + return OllamaProvider(model=model, base_url=base_url) + raise ValueError( + f"Unknown model: {model}. Available models: {list(MODELS.keys())}" + ) + + provider_name = MODELS[model]["provider"] + provider_class = PROVIDERS[provider_name] + + return provider_class(model=model, api_key=api_key, base_url=base_url) + + +def list_models() -> dict[str, dict]: + """List all available models with their info.""" + return MODELS.copy() + + +def list_providers() -> list[str]: + """List all available providers.""" + return list(PROVIDERS.keys()) + + +__all__ = [ + "BaseLLMProvider", + "CompletionResult", + "ModelInfo", + "AnthropicProvider", + "OpenAIProvider", + "OllamaProvider", + "MODELS", + "PROVIDERS", + "DEFAULT_MODEL", + "get_provider", + "list_models", + "list_providers", +] diff --git a/src/harvestor/providers/anthropic.py b/src/harvestor/providers/anthropic.py new file mode 100644 index 0000000..5f148fd --- /dev/null +++ b/src/harvestor/providers/anthropic.py @@ -0,0 +1,184 @@ +""" +Anthropic Claude provider implementation. 
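+
+Usage sketch (assumes ANTHROPIC_API_KEY is set; the prompt is illustrative):
+
+    provider = AnthropicProvider(model="claude-haiku")
+    result = provider.complete("Reply with the word OK.")
+    if result.success:
+        print(result.content, result.total_tokens)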
+""" + +import base64 +import os +from typing import Optional + +from anthropic import Anthropic + +from .base import BaseLLMProvider, CompletionResult, ModelInfo + +ANTHROPIC_MODELS = { + "claude-haiku": { + "id": "claude-3-haiku-20240307", + "input_cost": 0.25, + "output_cost": 1.25, + "supports_vision": True, + "context_window": 200000, + }, + "claude-haiku-4": { + "id": "claude-haiku-4-5-20251001", + "input_cost": 1.0, + "output_cost": 5.0, + "supports_vision": True, + "context_window": 200000, + }, + "claude-sonnet": { + "id": "claude-sonnet-4-5-20250929", + "input_cost": 3.0, + "output_cost": 15.0, + "supports_vision": True, + "context_window": 200000, + }, + "claude-sonnet-3.7": { + "id": "claude-3-7-sonnet-20250219", + "input_cost": 3.0, + "output_cost": 15.0, + "supports_vision": True, + "context_window": 200000, + }, + "claude-opus": { + "id": "claude-opus-4-5-20251101", + "input_cost": 15.0, + "output_cost": 75.0, + "supports_vision": True, + "context_window": 200000, + }, +} + + +class AnthropicProvider(BaseLLMProvider): + """Anthropic Claude provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "claude-haiku", + base_url: Optional[str] = None, + ): + api_key = api_key or os.getenv("ANTHROPIC_API_KEY") + if not api_key: + raise ValueError( + "Anthropic API key required. Set ANTHROPIC_API_KEY env var or pass api_key." + ) + + super().__init__(api_key=api_key, model=model, base_url=base_url) + + if model not in ANTHROPIC_MODELS: + raise ValueError( + f"Unknown Anthropic model: {model}. " + f"Available: {list(ANTHROPIC_MODELS.keys())}" + ) + + self.model_config = ANTHROPIC_MODELS[model] + self.model_id = self.model_config["id"] + self.client = Anthropic(api_key=api_key, base_url=base_url) + + def complete( + self, + prompt: str, + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + try: + response = self.client.messages.create( + model=self.model_id, + max_tokens=max_tokens, + temperature=temperature, + messages=[{"role": "user", "content": prompt}], + ) + + return CompletionResult( + success=True, + content=response.content[0].text, + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + model=self.model_id, + metadata={"stop_reason": response.stop_reason}, + ) + + except Exception as e: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=str(e), + ) + + def complete_vision( + self, + prompt: str, + image_data: bytes, + media_type: str = "image/jpeg", + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + if not self.supports_vision(): + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=f"Model {self.model} does not support vision", + ) + + try: + image_b64 = base64.standard_b64encode(image_data).decode("utf-8") + + response = self.client.messages.create( + model=self.model_id, + max_tokens=max_tokens, + temperature=temperature, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image_b64, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], + ) + + return CompletionResult( + success=True, + content=response.content[0].text, + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + model=self.model_id, + metadata={"stop_reason": response.stop_reason, "vision": True}, + ) + + except Exception as e: + return CompletionResult( + success=False, + content="", 
+ model=self.model_id, + error=str(e), + ) + + def supports_vision(self) -> bool: + return self.model_config.get("supports_vision", False) + + def get_model_info(self) -> ModelInfo: + return ModelInfo( + name=self.model, + provider="anthropic", + model_id=self.model_id, + input_cost_per_million=self.model_config["input_cost"], + output_cost_per_million=self.model_config["output_cost"], + supports_vision=self.model_config.get("supports_vision", False), + context_window=self.model_config.get("context_window", 200000), + ) + + @classmethod + def get_provider_name(cls) -> str: + return "anthropic" diff --git a/src/harvestor/providers/base.py b/src/harvestor/providers/base.py new file mode 100644 index 0000000..bccdb95 --- /dev/null +++ b/src/harvestor/providers/base.py @@ -0,0 +1,114 @@ +""" +Base provider abstraction for LLM providers. + +Defines the interface that all LLM providers must implement. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class ModelInfo: + """Information about an LLM model.""" + + name: str + provider: str + model_id: str + input_cost_per_million: float + output_cost_per_million: float + supports_vision: bool = False + max_tokens: int = 4096 + context_window: int = 128000 + + +@dataclass +class CompletionResult: + """Unified result from an LLM completion.""" + + success: bool + content: str + input_tokens: int = 0 + output_tokens: int = 0 + model: str = "" + error: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +class BaseLLMProvider(ABC): + """Abstract base class for LLM providers.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "", + base_url: Optional[str] = None, + ): + self.api_key = api_key + self.model = model + self.base_url = base_url + + @abstractmethod + def complete( + self, + prompt: str, + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + """ + Generate a completion for the given prompt. + + Args: + prompt: The input prompt + max_tokens: Maximum tokens to generate + temperature: Sampling temperature (0.0 for deterministic) + + Returns: + CompletionResult with the generated content + """ + pass + + @abstractmethod + def complete_vision( + self, + prompt: str, + image_data: bytes, + media_type: str = "image/jpeg", + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + """ + Generate a completion for an image + prompt. + + Args: + prompt: The input prompt + image_data: Raw image bytes + media_type: Image MIME type + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + + Returns: + CompletionResult with the generated content + """ + pass + + @abstractmethod + def supports_vision(self) -> bool: + """Check if this provider/model supports vision.""" + pass + + @abstractmethod + def get_model_info(self) -> ModelInfo: + """Get information about the current model.""" + pass + + @classmethod + @abstractmethod + def get_provider_name(cls) -> str: + """Get the provider name (e.g., 'anthropic', 'openai').""" + pass diff --git a/src/harvestor/providers/ollama.py b/src/harvestor/providers/ollama.py new file mode 100644 index 0000000..1bc03d5 --- /dev/null +++ b/src/harvestor/providers/ollama.py @@ -0,0 +1,218 @@ +""" +Ollama provider implementation for local LLM models. 
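+
+Usage sketch (assumes an Ollama server at http://localhost:11434 with the
+'llava' vision model already pulled; "invoice.jpg" is a placeholder file):
+
+    provider = OllamaProvider(model="llava")
+    with open("invoice.jpg", "rb") as f:
+        result = provider.complete_vision("Read the total amount.", f.read())
+    print(result.content)  # local models incur zero tracked cost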
+""" + +import base64 +import os +from typing import Optional + +import httpx + +from .base import BaseLLMProvider, CompletionResult, ModelInfo + +OLLAMA_MODELS = { + "llama3": { + "id": "llama3:latest", + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": False, + "context_window": 8192, + }, + "llama3.2": { + "id": "llama3.2:latest", + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": False, + "context_window": 128000, + }, + "mistral": { + "id": "mistral:latest", + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": False, + "context_window": 32000, + }, + "llava": { + "id": "llava:latest", + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": True, + "context_window": 4096, + }, + "llava-llama3": { + "id": "llava-llama3:latest", + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": True, + "context_window": 8192, + }, +} + +DEFAULT_OLLAMA_URL = "http://localhost:11434" + + +class OllamaProvider(BaseLLMProvider): + """Ollama provider for local models.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "llama3", + base_url: Optional[str] = None, + ): + base_url = base_url or os.getenv("OLLAMA_BASE_URL", DEFAULT_OLLAMA_URL) + super().__init__(api_key=api_key, model=model, base_url=base_url) + + if model not in OLLAMA_MODELS: + # Allow custom models not in the predefined list + self.model_config = { + "id": f"{model}:latest" if ":" not in model else model, + "input_cost": 0.0, + "output_cost": 0.0, + "supports_vision": False, + "context_window": 8192, + } + else: + self.model_config = OLLAMA_MODELS[model] + + self.model_id = self.model_config["id"] + self.client = httpx.Client(base_url=base_url, timeout=120.0) + + def complete( + self, + prompt: str, + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + try: + response = self.client.post( + "/api/generate", + json={ + "model": self.model_id, + "prompt": prompt, + "stream": False, + "options": { + "temperature": temperature, + "num_predict": max_tokens, + }, + }, + ) + response.raise_for_status() + data = response.json() + + return CompletionResult( + success=True, + content=data.get("response", ""), + input_tokens=data.get("prompt_eval_count", 0), + output_tokens=data.get("eval_count", 0), + model=self.model_id, + metadata={ + "total_duration": data.get("total_duration"), + "load_duration": data.get("load_duration"), + }, + ) + + except httpx.ConnectError: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=f"Cannot connect to Ollama at {self.base_url}. Is Ollama running?", + ) + except Exception as e: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=str(e), + ) + + def complete_vision( + self, + prompt: str, + image_data: bytes, + media_type: str = "image/jpeg", + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + if not self.supports_vision(): + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=f"Model {self.model} does not support vision. 
Use 'llava' or 'llava-llama3'.", + ) + + try: + image_b64 = base64.standard_b64encode(image_data).decode("utf-8") + + response = self.client.post( + "/api/generate", + json={ + "model": self.model_id, + "prompt": prompt, + "images": [image_b64], + "stream": False, + "options": { + "temperature": temperature, + "num_predict": max_tokens, + }, + }, + ) + response.raise_for_status() + data = response.json() + + return CompletionResult( + success=True, + content=data.get("response", ""), + input_tokens=data.get("prompt_eval_count", 0), + output_tokens=data.get("eval_count", 0), + model=self.model_id, + metadata={ + "total_duration": data.get("total_duration"), + "vision": True, + }, + ) + + except httpx.ConnectError: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=f"Cannot connect to Ollama at {self.base_url}. Is Ollama running?", + ) + except Exception as e: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=str(e), + ) + + def supports_vision(self) -> bool: + return self.model_config.get("supports_vision", False) + + def get_model_info(self) -> ModelInfo: + return ModelInfo( + name=self.model, + provider="ollama", + model_id=self.model_id, + input_cost_per_million=0.0, + output_cost_per_million=0.0, + supports_vision=self.model_config.get("supports_vision", False), + context_window=self.model_config.get("context_window", 8192), + ) + + @classmethod + def get_provider_name(cls) -> str: + return "ollama" + + def list_local_models(self) -> list[str]: + """List models available in the local Ollama installation.""" + try: + response = self.client.get("/api/tags") + response.raise_for_status() + data = response.json() + return [m["name"] for m in data.get("models", [])] + except Exception: + return [] diff --git a/src/harvestor/providers/openai.py b/src/harvestor/providers/openai.py new file mode 100644 index 0000000..6c59dcb --- /dev/null +++ b/src/harvestor/providers/openai.py @@ -0,0 +1,177 @@ +""" +OpenAI provider implementation. +""" + +import base64 +import os +from typing import Optional + +from openai import OpenAI + +from .base import BaseLLMProvider, CompletionResult, ModelInfo + +OPENAI_MODELS = { + "gpt-4o": { + "id": "gpt-4o", + "input_cost": 2.50, + "output_cost": 10.0, + "supports_vision": True, + "context_window": 128000, + }, + "gpt-4o-mini": { + "id": "gpt-4o-mini", + "input_cost": 0.15, + "output_cost": 0.60, + "supports_vision": True, + "context_window": 128000, + }, + "gpt-4-turbo": { + "id": "gpt-4-turbo", + "input_cost": 10.0, + "output_cost": 30.0, + "supports_vision": True, + "context_window": 128000, + }, + "gpt-4": { + "id": "gpt-4", + "input_cost": 30.0, + "output_cost": 60.0, + "supports_vision": False, + "context_window": 8192, + }, +} + + +class OpenAIProvider(BaseLLMProvider): + """OpenAI GPT provider.""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "gpt-4o-mini", + base_url: Optional[str] = None, + ): + api_key = api_key or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key required. Set OPENAI_API_KEY env var or pass api_key." + ) + + super().__init__(api_key=api_key, model=model, base_url=base_url) + + if model not in OPENAI_MODELS: + raise ValueError( + f"Unknown OpenAI model: {model}. 
" + f"Available: {list(OPENAI_MODELS.keys())}" + ) + + self.model_config = OPENAI_MODELS[model] + self.model_id = self.model_config["id"] + self.client = OpenAI(api_key=api_key, base_url=base_url) + + def complete( + self, + prompt: str, + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + try: + response = self.client.chat.completions.create( + model=self.model_id, + max_tokens=max_tokens, + temperature=temperature, + messages=[{"role": "user", "content": prompt}], + ) + + choice = response.choices[0] + usage = response.usage + + return CompletionResult( + success=True, + content=choice.message.content or "", + input_tokens=usage.prompt_tokens if usage else 0, + output_tokens=usage.completion_tokens if usage else 0, + model=self.model_id, + metadata={"finish_reason": choice.finish_reason}, + ) + + except Exception as e: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=str(e), + ) + + def complete_vision( + self, + prompt: str, + image_data: bytes, + media_type: str = "image/jpeg", + max_tokens: int = 2048, + temperature: float = 0.0, + ) -> CompletionResult: + if not self.supports_vision(): + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=f"Model {self.model} does not support vision", + ) + + try: + image_b64 = base64.standard_b64encode(image_data).decode("utf-8") + data_url = f"data:{media_type};base64,{image_b64}" + + response = self.client.chat.completions.create( + model=self.model_id, + max_tokens=max_tokens, + temperature=temperature, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": prompt}, + ], + } + ], + ) + + choice = response.choices[0] + usage = response.usage + + return CompletionResult( + success=True, + content=choice.message.content or "", + input_tokens=usage.prompt_tokens if usage else 0, + output_tokens=usage.completion_tokens if usage else 0, + model=self.model_id, + metadata={"finish_reason": choice.finish_reason, "vision": True}, + ) + + except Exception as e: + return CompletionResult( + success=False, + content="", + model=self.model_id, + error=str(e), + ) + + def supports_vision(self) -> bool: + return self.model_config.get("supports_vision", False) + + def get_model_info(self) -> ModelInfo: + return ModelInfo( + name=self.model, + provider="openai", + model_id=self.model_id, + input_cost_per_million=self.model_config["input_cost"], + output_cost_per_million=self.model_config["output_cost"], + supports_vision=self.model_config.get("supports_vision", False), + context_window=self.model_config.get("context_window", 128000), + ) + + @classmethod + def get_provider_name(cls) -> str: + return "openai" diff --git a/src/harvestor/schemas/base.py b/src/harvestor/schemas/base.py index e86cfda..3232344 100644 --- a/src/harvestor/schemas/base.py +++ b/src/harvestor/schemas/base.py @@ -13,9 +13,14 @@ class ExtractionStrategy(str, Enum): """Strategies for extracting data from documents.""" - LLM_HAIKU = "llm_haiku" # Claude Haiku - LLM_SONNET = "llm_sonnet" # Claude Sonnet - LLM_GPT35 = "llm_gpt35" # GPT-3.5-turbo + LLM_ANTHROPIC = "llm_anthropic" + LLM_OPENAI = "llm_openai" + LLM_OLLAMA = "llm_ollama" + + # Legacy aliases (for backwards compatibility) + LLM_HAIKU = "llm_anthropic" + LLM_SONNET = "llm_anthropic" + LLM_GPT35 = "llm_openai" @dataclass @@ -32,8 +37,8 @@ class ExtractionResult: confidence: float = 0.0 # 0.0 to 1.0 processing_time: float = 0.0 # seconds - # Cost tracking - 
cost: float = 0.0 # USD -> for llm calls + # Cost tracking (USD) + cost: float = 0.0 tokens_used: int = 0 # Error handling @@ -63,9 +68,9 @@ def is_free_method(self) -> bool: @dataclass class ValidationResult: - """Result from validation checks. Determine is the document is legit. + """Result from validation checks. - Will be implemented later on. + Determines if the document is legitimate. Will be implemented later. """ # Core validation diff --git a/tests/conftest.py b/tests/conftest.py index 7c06167..fa0eb55 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -103,6 +103,7 @@ class MockContent: class MockResponse: usage = MockUsage() content = [MockContent()] + stop_reason = "end_turn" return MockResponse() diff --git a/tests/test_cost_tracker.py b/tests/test_cost_tracker.py index 5f5c8f7..c049b00 100644 --- a/tests/test_cost_tracker.py +++ b/tests/test_cost_tracker.py @@ -6,7 +6,6 @@ CostLimitExceeded, CostTracker, cost_tracker, - ModelNotSupported, ) from harvestor.schemas.base import ExtractionStrategy @@ -21,29 +20,29 @@ def setup_method(self): def test_calculate_haiku_cost(self): """Test cost calculation for Claude Haiku.""" cost = cost_tracker.calculate_cost( - model="Claude Haiku 3", input_tokens=1000, output_tokens=500 + model="claude-haiku", input_tokens=1000, output_tokens=500 ) # Haiku: $0.25/MTok input, $1.25/MTok output expected = (1000 / 1_000_000 * 0.25) + (500 / 1_000_000 * 1.25) assert cost == pytest.approx(expected) - def test_calculate_sonnet_3_7_cost(self): - """Test cost calculation for Claude Sonnet 3.7.""" + def test_calculate_sonnet_cost(self): + """Test cost calculation for Claude Sonnet.""" cost = cost_tracker.calculate_cost( - model="Claude Sonnet 3.7", input_tokens=1000, output_tokens=500 + model="claude-sonnet", input_tokens=1000, output_tokens=500 ) - # Sonnet 3.7: $3/MTok input, $15/MTok output + # Sonnet: $3/MTok input, $15/MTok output expected = (1000 / 1_000_000 * 3.0) + (500 / 1_000_000 * 15.0) assert cost == pytest.approx(expected) - def test_unknown_model_uses_gpt4_pricing(self): - """Test that unknown models default to GPT-4 pricing.""" - with pytest.raises(ModelNotSupported): - cost_tracker.calculate_cost( - model="unknown-model", input_tokens=1000, output_tokens=500 - ) + def test_unknown_model_returns_zero_cost(self): + """Test that unknown models return zero cost (e.g., custom Ollama models).""" + cost = cost_tracker.calculate_cost( + model="unknown-model", input_tokens=1000, output_tokens=500 + ) + assert cost == 0.0 class TestCostTracking: @@ -58,8 +57,8 @@ def setup_method(self): def test_track_single_call(self): """Test tracking a single API call.""" cost = cost_tracker.track_call( - model="Claude Haiku 3", - strategy=ExtractionStrategy.LLM_HAIKU, + model="claude-haiku", + strategy=ExtractionStrategy.LLM_ANTHROPIC, input_tokens=1000, output_tokens=500, document_id="doc1", @@ -75,8 +74,8 @@ def test_track_multiple_calls(self): """Test tracking multiple API calls.""" for i in range(3): cost_tracker.track_call( - model="Claude Haiku 3", - strategy=ExtractionStrategy.LLM_HAIKU, + model="claude-haiku", + strategy=ExtractionStrategy.LLM_ANTHROPIC, input_tokens=1000, output_tokens=500, document_id=f"doc{i}", @@ -88,12 +87,12 @@ def test_track_multiple_calls(self): def test_per_document_limit_enforcement(self): """Test that per-document cost limit is enforced.""" - cost_tracker.set_limits(per_document_limit=0.01) + cost_tracker.set_limits(per_document_limit=0.001) # First call should succeed cost_tracker.track_call( - model="Claude Haiku 3", - 
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id="doc1",
@@ -102,8 +101,8 @@
         # Second call for same document would exceed limit
         with pytest.raises(CostLimitExceeded):
             cost_tracker.track_call(
-                model="Claude Sonnet 3.7",
-                strategy=ExtractionStrategy.LLM_SONNET,
+                model="claude-sonnet",
+                strategy=ExtractionStrategy.LLM_ANTHROPIC,
                 input_tokens=10000,
                 output_tokens=5000,
                 document_id="doc1",
@@ -111,12 +110,12 @@ def test_daily_limit_enforcement(self):
         """Test that daily cost limit is enforced."""
-        cost_tracker.set_limits(daily_limit=0.01)
+        cost_tracker.set_limits(daily_limit=0.001)
 
         # First small call should succeed
         cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=100,
             output_tokens=50,
             document_id="doc1",
@@ -125,8 +124,8 @@
         # Second call that would exceed daily limit
         with pytest.raises(CostLimitExceeded):
             cost_tracker.track_call(
-                model="Claude Sonnet 3.7",
-                strategy=ExtractionStrategy.LLM_SONNET,
+                model="claude-sonnet",
+                strategy=ExtractionStrategy.LLM_ANTHROPIC,
                 input_tokens=100000,
                 output_tokens=50000,
                 document_id="doc2",
@@ -137,16 +136,16 @@ def test_get_document_cost(self):
         doc_id = "test_doc"
 
         cost1 = cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id=doc_id,
         )
 
         cost2 = cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=500,
             output_tokens=250,
             document_id=doc_id,
@@ -168,16 +167,16 @@ def setup_method(self):
 
     def test_stats_by_model(self):
         """Test statistics grouped by model."""
         cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id="doc1",
         )
         cost_tracker.track_call(
-            model="Claude Sonnet 3.7",
-            strategy=ExtractionStrategy.LLM_SONNET,
+            model="claude-sonnet",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id="doc2",
@@ -185,15 +184,15 @@
         stats = cost_tracker.get_stats()
 
         assert len(stats.calls_by_model) == 2
-        assert "Claude Haiku 3" in stats.calls_by_model
-        assert "Claude Sonnet 3.7" in stats.calls_by_model
+        assert "claude-haiku" in stats.calls_by_model
+        assert "claude-sonnet" in stats.calls_by_model
 
     def test_average_cost_per_document(self):
         """Test average cost per document calculation."""
         for i in range(3):
             cost_tracker.track_call(
-                model="Claude Haiku 3",
-                strategy=ExtractionStrategy.LLM_HAIKU,
+                model="claude-haiku",
+                strategy=ExtractionStrategy.LLM_ANTHROPIC,
                 input_tokens=1000,
                 output_tokens=500,
                 document_id=f"doc{i}",
@@ -207,8 +206,8 @@ def test_cost_report_generation(self):
         """Test cost report generation."""
         # Track some successful calls
         cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id="doc1",
@@ -230,8 +229,8 @@ class TestCostTrackerReset:
     def test_reset_clears_all_data(self):
         """Test that reset clears all tracked data."""
         cost_tracker.track_call(
-            model="Claude Haiku 3",
-            strategy=ExtractionStrategy.LLM_HAIKU,
+            model="claude-haiku",
+            strategy=ExtractionStrategy.LLM_ANTHROPIC,
             input_tokens=1000,
             output_tokens=500,
             document_id="doc1",
diff --git a/tests/test_harvestor.py b/tests/test_harvestor.py
index a4d0fe7..fb9d0de 100644
--- a/tests/test_harvestor.py
+++ b/tests/test_harvestor.py
@@ -14,13 +14,13 @@ class TestHarvestorInitialization:
     def test_init_with_api_key(self):
         """Test initialization with explicit API key."""
         harvestor = Harvestor(api_key="sk-test-key")
-        assert harvestor.api_key == "sk-test-key"
+        assert harvestor.llm_parser.provider.api_key == "sk-test-key"
 
     def test_init_with_env_api_key(self, monkeypatch):
         """Test initialization with API key from environment."""
         monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-env-key")
         harvestor = Harvestor()
-        assert harvestor.api_key == "sk-env-key"
+        assert harvestor.llm_parser.provider.api_key == "sk-env-key"
 
     def test_init_without_api_key_raises_error(self, monkeypatch):
         """Test that initialization without API key raises error."""
@@ -31,8 +31,8 @@ def test_init_with_custom_model(self):
         """Test initialization with custom model."""
-        harvestor = Harvestor(api_key="sk-test-key", model="Claude Sonnet 3.7")
-        assert harvestor.model_name == "Claude Sonnet 3.7"
+        harvestor = Harvestor(api_key="sk-test-key", model="claude-sonnet")
+        assert harvestor.model_name == "claude-sonnet"
 
     def test_init_sets_cost_limits(self):
         """Test that initialization sets cost limits."""
@@ -49,7 +49,7 @@ class TestTextExtraction:
     """Test text extraction from different file types."""
 
-    @patch("harvestor.parsers.llm_parser.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_text_basic(
         self, mock_anthropic, sample_invoice_text, mock_anthropic_response, api_key
     ):
@@ -66,18 +66,7 @@
         assert isinstance(result, HarvestResult)
         assert result.success is True
         assert result.document_type == "invoice"
-        assert result.total_cost > 0
-
-    def test_extract_text_from_txt_file(self, tmp_path, api_key):
-        """Test text extraction from .txt file."""
-        # Create a text file
-        txt_file = tmp_path / "test.txt"
-        txt_file.write_text("Test content")
-
-        harvestor = Harvestor(api_key=api_key)
-        text = harvestor._extract_text_from_file(txt_file)
-
-        assert text == "Test content"
+        assert result.total_cost >= 0
 
     def test_extract_text_from_bytes_txt(self, api_key):
         """Test text extraction from bytes (.txt)."""
@@ -88,21 +77,18 @@
         assert text == "Hello, world!"
 
-    def test_unsupported_file_extension_raises_error(self, tmp_path, api_key):
+    def test_unsupported_file_extension_raises_error(self, api_key):
         """Test that unsupported file extensions raise ValueError."""
-        unsupported_file = tmp_path / "test.xyz"
-        unsupported_file.write_text("content")
-
         harvestor = Harvestor(api_key=api_key)
 
         with pytest.raises(ValueError, match="Unsupported file type"):
-            harvestor._extract_text_from_file(unsupported_file)
+            harvestor._extract_text_from_bytes(b"content", ".xyz")
 
 
 class TestBatchProcessing:
     """Test batch processing functionality."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_batch(
         self, mock_anthropic, tmp_path, mock_anthropic_response, api_key
     ):
@@ -127,7 +113,7 @@
         assert all(isinstance(r, HarvestResult) for r in results)
         assert all(r.success for r in results)
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
    def test_harvest_batch_with_failures(self, mock_anthropic, tmp_path, api_key):
         """Test batch processing handles failures gracefully."""
         mock_client = MagicMock()
@@ -136,6 +122,7 @@
             MagicMock(
                 usage=MagicMock(input_tokens=100, output_tokens=50),
                 content=[MagicMock(text='{"invoice_number": "123"}')],
+                stop_reason="end_turn",
             ),  # Second succeeds
         ]
         mock_anthropic.return_value = mock_client
@@ -159,7 +146,7 @@ class TestDocumentIDGeneration:
     """Test document ID generation."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_document_id_from_filename(
         self, mock_anthropic, tmp_path, mock_anthropic_response, api_key
     ):
@@ -176,7 +163,7 @@
         assert result.document_id == "invoice_12345"
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_custom_document_id(
         self, mock_anthropic, tmp_path, mock_anthropic_response, api_key
     ):
@@ -199,7 +186,7 @@ class TestHarvestResult:
     """Test HarvestResult properties and methods."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_result_structure(
         self, mock_anthropic, tmp_path, mock_anthropic_response, api_key
     ):
@@ -224,7 +211,7 @@
         assert hasattr(result, "file_path")
         assert hasattr(result, "file_size_bytes")
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_result_cost_efficiency(
         self, mock_anthropic, tmp_path, mock_anthropic_response, api_key
     ):
@@ -255,7 +242,7 @@ def test_missing_api_key_error_message(self, monkeypatch):
         assert "api key" in str(exc_info.value).lower()
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_api_error_returns_failed_result(self, mock_anthropic, tmp_path, api_key):
         """Test that API errors return failed HarvestResult."""
         mock_client = MagicMock()
@@ -270,7 +257,6 @@ def test_api_error_returns_failed_result(self, mock_anthropic, tmp_path, api_key
         assert result.success is False
         assert result.error is not None
-        assert "extraction failed" in result.error.lower()
 
     def test_nonexistent_file_returns_error(self, api_key):
         """Test that non-existent file returns error result."""
diff --git a/tests/test_input_types.py b/tests/test_input_types.py
index 60c24c2..1f07c8c 100644
--- a/tests/test_input_types.py
+++ b/tests/test_input_types.py
@@ -12,7 +12,7 @@ class TestFilePathInput:
     """Test file path inputs (str and Path objects)."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_string_path(
         self,
         mock_anthropic,
@@ -35,7 +35,7 @@
         assert result.file_path == str(sample_invoice_image_path)
         assert result.file_size_bytes > 0
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_path_object(
         self,
         mock_anthropic,
@@ -67,7 +67,7 @@ def test_harvest_with_nonexistent_path(self, api_key):
 
 
 class TestBytesInput:
     """Test raw bytes input."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_bytes(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -86,7 +86,7 @@
         assert result.file_size_bytes == len(sample_invoice_bytes)
         mock_client.messages.create.assert_called_once()
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_bytes_without_filename(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -102,7 +102,7 @@
         assert result.success is False
         assert "unsupported file type" in result.error.lower()
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_bytes_different_formats(
         self, mock_anthropic, mock_anthropic_response, api_key
     ):
@@ -125,7 +125,7 @@ class TestFileLikeInput:
     """Test file-like object inputs (BytesIO, opened files)."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_bytesio(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -144,7 +144,7 @@
         assert result.success is True
         assert result.file_size_bytes == len(sample_invoice_bytes)
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_opened_file(
         self,
         mock_anthropic,
@@ -167,7 +167,7 @@
         # Should auto-detect filename from f.name
         assert result.document_id is not None
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_with_fileobj_without_name_attribute(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -190,7 +190,7 @@ class TestInputTypeEquivalence:
     """Test that all input types produce equivalent results."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_all_input_types_produce_same_result(
         self,
         mock_anthropic,
@@ -242,7 +242,7 @@ class TestConvenienceFunction:
     """Test the harvest() convenience function."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_function_with_path(
         self,
         mock_anthropic,
@@ -260,7 +260,7 @@
         assert isinstance(result, HarvestResult)
         assert result.success is True
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_function_with_bytes(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -279,7 +279,7 @@
         assert isinstance(result, HarvestResult)
         assert result.success is True
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_harvest_function_with_fileobj(
         self, mock_anthropic, sample_invoice_fileobj, mock_anthropic_response, api_key
     ):
@@ -302,7 +302,7 @@ class TestImageFormatDetection:
     """Test image format detection and media type mapping."""
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_jpg_maps_to_jpeg_mime(
         self, mock_anthropic, sample_invoice_bytes, mock_anthropic_response, api_key
     ):
@@ -324,7 +324,7 @@
         image_source = messages[0]["content"][0]["source"]
         assert image_source["media_type"] == "image/jpeg"
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     @pytest.mark.parametrize(
         "filename,expected_type",
         [
@@ -382,7 +382,7 @@ def test_unsupported_file_format(self, api_key):
         assert result.success is False
         assert "unsupported file type" in result.error.lower()
 
-    @patch("harvestor.core.harvestor.Anthropic")
+    @patch("harvestor.providers.anthropic.Anthropic")
     def test_api_error_handling(self, mock_anthropic, sample_invoice_bytes, api_key):
         """Test error handling when API call fails."""
         mock_client = MagicMock()
diff --git a/uv.lock b/uv.lock
index 52baabd..9f84b22 100644
--- a/uv.lock
+++ b/uv.lock
@@ -271,6 +271,7 @@ source = { editable = "." }
 dependencies = [
     { name = "anthropic" },
     { name = "click" },
+    { name = "httpx" },
     { name = "langchain" },
     { name = "langchain-anthropic" },
     { name = "langchain-openai" },
@@ -305,6 +306,7 @@ dev = [
 requires-dist = [
     { name = "anthropic", specifier = ">=0.18.0" },
     { name = "click", specifier = ">=8.1.0" },
+    { name = "httpx", specifier = ">=0.27.0" },
     { name = "langchain", specifier = ">=0.1.0" },
     { name = "langchain-anthropic", specifier = ">=0.1.0" },
     { name = "langchain-openai", specifier = ">=0.0.5" },

From c13993ca121b9fc0d738fd91d56c4a984d55ac1a Mon Sep 17 00:00:00 2001
From: THUAUD Simon
Date: Mon, 2 Feb 2026 23:58:24 +0100
Subject: [PATCH 2/2] chore: remove stale .env.template settings and pytest.ini leftovers

---
 .env.template | 22 ----------------------
 pytest.ini    |  4 ----
 2 files changed, 26 deletions(-)

diff --git a/.env.template b/.env.template
index bdc706a..47b7b6f 100644
--- a/.env.template
+++ b/.env.template
@@ -1,25 +1,3 @@
 # LLM API Keys
 ANTHROPIC_API_KEY=sk-ant-your-key-here
 OPENAI_API_KEY=sk-your-key-here
-
-# Database
-DATABASE_URL=sqlite:///./data/harvestor.db
-
-# Cost Limits
-MAX_COST_PER_DOCUMENT=0.50
-DAILY_COST_LIMIT=100.00
-
-# Models
-DEFAULT_EXTRACTION_MODEL=claude-haiku-4-5-20251001
-DEFAULT_VALIDATION_MODEL=claude-sonnet-4-5-20250929
-
-# OCR Settings
-ENABLE_TESSERACT_PREPROCESSING=true
-OCR_DPI=300
-OCR_LANGUAGES=eng+fra+deu+spa
-
-# Features
-USE_LAYOUT_ANALYSIS=true
-USE_TABLE_EXTRACTION=true
-USE_KEYWORD_PROXIMITY=true
-ENABLE_CACHING=true
diff --git a/pytest.ini b/pytest.ini
index aa2e4a4..e8d4162 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -24,9 +24,5 @@ markers =
     slow: Slow running tests
     vision: Tests that use vision API
 
-# Coverage options (if using pytest-cov)
-# Uncomment when pytest-cov is installed
-# addopts = --cov=src/harvestor --cov-report=term-missing --cov-report=html
-
 # Minimum Python version
 minversion = 3.10
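
Note on the cost-tracker change above: calculate_cost() no longer raises ModelNotSupported for an unrecognized name; it bills the call at zero, which is what lets custom Ollama tags run for free. A minimal sketch of that fallback, assuming a pricing table keyed by the friendly model names used in the tests (the PRICING name and layout here are illustrative stand-ins, not the actual cost_tracker internals):

    # Hedged sketch of the zero-cost fallback exercised by
    # test_unknown_model_returns_zero_cost; PRICING is an assumed stand-in
    # for whatever table cost_tracker actually keeps.
    import math

    PRICING = {
        # model: (USD per MTok input, USD per MTok output)
        "claude-haiku": (0.25, 1.25),
        "claude-sonnet": (3.0, 15.0),
    }

    def calculate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
        rates = PRICING.get(model)
        if rates is None:
            # Unknown model, e.g. a local Ollama tag: free rather than an error.
            return 0.0
        input_rate, output_rate = rates
        return (input_tokens / 1_000_000 * input_rate
                + output_tokens / 1_000_000 * output_rate)

    assert math.isclose(calculate_cost("claude-haiku", 1000, 500), 0.000875)
    assert calculate_cost("llava", 1000, 500) == 0.0

The asserts mirror the Haiku arithmetic in test_calculate_haiku_cost: 1000/1M * $0.25 + 500/1M * $1.25 = $0.000875.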
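Note on the test updates: every @patch target moved from harvestor.core.harvestor.Anthropic (or harvestor.parsers.llm_parser.Anthropic) to harvestor.providers.anthropic.Anthropic, because the SDK client is now constructed inside the provider module, and mocked responses need a stop_reason attribute (conftest.py uses "end_turn", presumably read by the new provider layer). A sketch of the resulting pattern, with the actual harvest call deliberately elided:

    # Hedged sketch of the mocking pattern the updated tests follow; the
    # Harvestor construction and harvest call are abridged on purpose.
    from unittest.mock import MagicMock, patch

    mock_response = MagicMock(
        usage=MagicMock(input_tokens=100, output_tokens=50),
        content=[MagicMock(text='{"invoice_number": "123"}')],
        stop_reason="end_turn",  # newly required field on mocked responses
    )

    with patch("harvestor.providers.anthropic.Anthropic") as mock_anthropic:
        mock_anthropic.return_value.messages.create.return_value = mock_response
        # ... build Harvestor(api_key="sk-test-key") and call harvest_file() here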