diff --git a/akd/tools/factory.py b/akd/tools/factory.py index 3f81052e..fd0ffdfa 100644 --- a/akd/tools/factory.py +++ b/akd/tools/factory.py @@ -4,7 +4,7 @@ from akd.agents.relevancy import MultiRubricRelevancyAgent from akd.tools.relevancy import EnhancedRelevancyCheckerConfig, RubricWeights -from .scrapers.composite import CompositeWebScraper, ResearchArticleResolver +from .scrapers.composite import CompositeScraper, ResearchArticleResolver from .scrapers.pdf_scrapers import SimplePDFScraper from .scrapers.resolvers import ( ADSResolver, @@ -20,6 +20,11 @@ WebScraperToolBase, ) from .search import SearchTool, SearxNGSearchTool, SearxNGSearchToolConfig +from .source_validator import ( + SourceValidator, + SourceValidatorConfig, + create_source_validator, +) def create_default_scraper( @@ -28,7 +33,7 @@ def create_default_scraper( ) -> WebScraperToolBase: config = config or WebpageScraperToolConfig() config.debug = debug - return CompositeWebScraper( + return CompositeScraper( SimpleWebScraper(config), Crawl4AIWebScraper(config), SimplePDFScraper(config), @@ -54,6 +59,33 @@ def create_default_article_resolver( ) +def create_default_source_validator( + config: Optional[SourceValidatorConfig] = None, + whitelist_file_path: Optional[str] = None, + max_concurrent_requests: int = 10, + debug: bool = False, +) -> SourceValidator: + """ + Create a source validator with default parameters. + + Args: + config: Optional SourceValidatorConfig. If provided, other parameters are ignored. + whitelist_file_path: Path to source whitelist JSON file. If None, uses default path in akd/docs/pubs_whitelist.json. + max_concurrent_requests: Maximum number of concurrent API requests. + debug: Enable debug logging. + + Returns: + Configured SourceValidator instance. 
+ """ + if config is None: + return create_source_validator( + whitelist_file_path=whitelist_file_path, + max_concurrent_requests=max_concurrent_requests, + debug=debug, + ) + return SourceValidator(config, debug=debug) + + def create_strict_literature_config_for_relevancy( n_iter: int = 1, relevance_threshold: float = 0.7, diff --git a/akd/tools/source_validator.py b/akd/tools/source_validator.py index 76ad7298..a5654636 100644 --- a/akd/tools/source_validator.py +++ b/akd/tools/source_validator.py @@ -8,33 +8,133 @@ from __future__ import annotations -import json +import asyncio import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional +import urllib.parse +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -import aiohttp +import ftfy +import orjson +from crossref_commons.retrieval import get_publication_as_json from loguru import logger -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator +from rapidfuzz import fuzz, process +from unidecode import unidecode from akd._base import InputSchema, OutputSchema -from akd.structures import SearchResultItem from akd.tools._base import BaseTool, BaseToolConfig from akd.utils import get_akd_root if TYPE_CHECKING: - pass + from akd.structures import SearchResultItem class SourceInfo(BaseModel): """Schema for source information from CrossRef API.""" - title: str = Field(..., description="Source title") + title: str = Field(..., description="Source title", min_length=1) publisher: Optional[str] = Field(None, description="Publisher name") issn: Optional[List[str]] = Field(None, description="List of ISSNs") is_open_access: Optional[bool] = Field(None, description="Open access status") - doi: str = Field(..., description="DOI of the article") + doi: str = Field(..., description="DOI of the article", min_length=7) url: Optional[str] = Field(None, description="Original URL") + @field_validator("doi") + @classmethod + def validate_doi(cls, v: str) -> str: + """ 
+ Validate DOI format. + + Args: + v: DOI string to validate + + Returns: + Validated DOI string + + Raises: + ValueError: If DOI format is invalid + """ + if not v.startswith("10.") or "/" not in v: + raise ValueError('DOI must start with "10." and contain "/"') + return v + + @field_validator("issn") + @classmethod + def validate_issn(cls, v: Optional[List[str]]) -> Optional[List[str]]: + """ + Validate ISSN format with improved checksum validation. + + Args: + v: List of ISSN strings to validate + + Returns: + List of valid ISSN strings or None + """ + if v is None: + return v + + def is_valid_issn(issn: str) -> bool: + """Validate ISSN format and checksum.""" + issn = issn.strip() + + # Basic format check (XXXX-XXXX) + if not re.match(r"^\d{4}-\d{4}$", issn): + return False + + # Remove hyphen for checksum calculation + issn_digits = issn.replace("-", "") + + # ISSN checksum validation + try: + total = 0 + for i, digit in enumerate(issn_digits[:7]): + total += int(digit) * (8 - i) + + check_digit = issn_digits[7] + calculated_check = 11 - (total % 11) + + if calculated_check == 10: + return check_digit.upper() == "X" + elif calculated_check == 11: + return check_digit == "0" + else: + return check_digit == str(calculated_check) + + except (ValueError, IndexError): + return False + + validated_issns = [] + for issn in v: + if is_valid_issn(issn): + validated_issns.append(issn.strip()) + else: + # Log invalid ISSN but don't fail validation + logger.debug(f"Invalid ISSN format or checksum: {issn}") + + return validated_issns if validated_issns else None + + @field_validator("title") + @classmethod + def validate_title(cls, v: str) -> str: + """ + Clean and validate title. 
+ + Args: + v: Title string to validate + + Returns: + Cleaned title string + + Raises: + ValueError: If title is empty + """ + if not v or not v.strip(): + raise ValueError("Title cannot be empty") + + # Use ftfy to clean the title + cleaned = ftfy.fix_text(v.strip()) + return cleaned + class ValidationResult(BaseModel): """Schema for validation result.""" @@ -55,15 +155,18 @@ class ValidationResult(BaseModel): confidence_score: float = Field( ..., description="Confidence in validation (0.0-1.0)", + ge=0.0, + le=1.0, ) class SourceValidatorInputSchema(InputSchema): """Input schema for source validation tool.""" - search_results: List[SearchResultItem] = Field( + search_results: List["SearchResultItem"] = Field( ..., description="List of search results to validate", + min_length=1, ) whitelist_file_path: Optional[str] = Field( None, @@ -84,29 +187,16 @@ class SourceValidatorOutputSchema(OutputSchema): class SourceValidatorConfig(BaseToolConfig): """Configuration for source validator tool.""" - crossref_base_url: str = Field( - default="https://api.crossref.org/works", - description="Base URL for CrossRef API", - ) whitelist_file_path: Optional[str] = Field( - default_factory=lambda: str( - get_akd_root() / "docs" / "pubs_whitelist.json", - ), + default=None, description="Path to source whitelist JSON file", ) - timeout_seconds: int = Field( - default=30, - description="Timeout for API requests in seconds", - ) max_concurrent_requests: int = Field( default=10, - description="Maximum number of concurrent API requests", - ) - user_agent: str = Field( - default="SourceValidator/1.0", - description="User agent for API requests", + description="Maximum number of concurrent CrossRef API requests", + gt=0, + le=50, ) - debug: bool = Field(default=False, description="Enable debug logging") class SourceValidator( @@ -126,56 +216,166 @@ class SourceValidator( output_schema = SourceValidatorOutputSchema config_schema = SourceValidatorConfig + # Class-level cache for whitelist 
data to avoid reloading + _whitelist_cache: Dict[str, Dict[str, Any]] = {} + def __init__( self, config: Optional[SourceValidatorConfig] = None, debug: bool = False, - ): - """Initialize the source validator tool.""" - config = config or SourceValidatorConfig() - config.debug = debug + ) -> None: + """ + Initialize the source validator tool. + + Args: + config: Configuration for the tool + debug: Enable debug logging + + Raises: + RuntimeError: If initialization fails + """ + if config is None: + config = SourceValidatorConfig() + + # Set default whitelist path if not provided + if config.whitelist_file_path is None: + config.whitelist_file_path = str( + get_akd_root() / "docs" / "pubs_whitelist.json" + ) + super().__init__(config, debug) - # Load whitelist on initialization - self._whitelist = self._load_whitelist() - - # DOI extraction patterns - self._doi_patterns = [ - # Standard DOI URLs - r"(?:https?://)?(?:dx\.)?doi\.org/(?:10\.\d+/.+)", - r"(?:https?://)?(?:www\.)?dx\.doi\.org/(?:10\.\d+/.+)", - # DOI in URL path - r"(?:https?://[^/]+)?.*?(?:doi/|DOI:|doi:|DOI/)(\d{2}\.\d+/.+?)(?:[&?#]|$)", - # Bare DOI pattern - r"\b(10\.\d+/.+?)(?:\s|$|[&?#])", - # DOI in query parameters - r"[\?&]doi=([^&\s]+)", + # Load and validate whitelist on initialization with proper error handling + try: + self._whitelist = self._load_whitelist() + if not self._whitelist.get("data"): + raise ValueError("Whitelist data is empty or invalid") + except Exception as e: + logger.error(f"Failed to initialize source validator: {e}") + raise RuntimeError(f"Source validator initialization failed: {e}") from e + + # Create searchable index for rapid fuzzy matching + self._journal_index = self._build_journal_index() + + # Pre-compile DOI extraction patterns for performance + self._compiled_doi_patterns = self._compile_doi_patterns() + + def _compile_doi_patterns(self) -> List[re.Pattern[str]]: + """ + Compile optimized DOI extraction patterns for performance. 
+ + Returns: + List of compiled regex patterns for DOI extraction + """ + return [ + # Combined standard DOI URLs (consolidating similar patterns) + re.compile( + r"(?:https?://)?(?:(?:dx\.|www\.)?doi\.org|dx\.doi\.org)/(10\.\d+/[^\s&?#]+)", + re.IGNORECASE, + ), + # DOI in URL path (consolidated path patterns) + re.compile( + r"(?:doi[:/]|DOI[:/]|/doi/|/DOI/)(10\.\d+/[^\s&?#]+)", + re.IGNORECASE, + ), + # DOI in query parameters (consolidated) + re.compile(r"[\?&]doi=([^&\s#]+)", re.IGNORECASE), + # URL-encoded DOI patterns + re.compile(r"(?:doi\.org%2F|doi%3A)(10\.[\d%]+%2F[^&\s#]+)", re.IGNORECASE), + # Publisher-specific patterns + re.compile(r"/article/(?:pii/)?[^/]*/?(10\.\d+/[^\s&?#]+)", re.IGNORECASE), + # Bare DOI pattern (most restrictive, used last) + re.compile(r"\b(10\.\d{4,}/[^\s&?#]{6,})(?=[\s&?#]|$)", re.IGNORECASE), ] - def _load_whitelist(self) -> Dict[str, Any]: - """Load source whitelist from JSON file.""" - whitelist_path = ( - self.config.whitelist_file_path - or get_akd_root() / "docs" / "pubs_whitelist.json" + @classmethod + def from_params( + cls, + whitelist_file_path: Optional[str] = None, + max_concurrent_requests: int = 10, + debug: bool = False, + ) -> "SourceValidator": + """ + Create a source validator from specific parameters. + + Args: + whitelist_file_path: Path to source whitelist JSON file + max_concurrent_requests: Maximum number of concurrent CrossRef API requests + debug: Enable debug logging + + Returns: + Configured SourceValidator instance + """ + config = SourceValidatorConfig( + whitelist_file_path=whitelist_file_path, + max_concurrent_requests=max_concurrent_requests, ) + return cls(config, debug=debug) + + @classmethod + def clear_whitelist_cache(cls) -> None: + """Clear the whitelist cache. Useful for testing or when whitelist files change.""" + cls._whitelist_cache.clear() + + def _load_whitelist(self) -> Dict[str, Any]: + """ + Load source whitelist from JSON file with caching and proper error handling. 
+ + Returns: + Whitelist data dictionary + + Raises: + FileNotFoundError: If whitelist file doesn't exist + ValueError: If whitelist structure is invalid + """ + whitelist_path = self.config.whitelist_file_path + + if not whitelist_path: + raise ValueError("Whitelist file path not configured") + + # Check cache first + if whitelist_path in self._whitelist_cache: + if self.debug: + logger.info(f"Using cached whitelist for: {whitelist_path}") + return self._whitelist_cache[whitelist_path] try: - with open(whitelist_path, "r", encoding="utf-8") as f: - whitelist_data = json.load(f) + with open(whitelist_path, "rb") as f: # orjson requires binary mode + whitelist_data = orjson.loads(f.read()) + + # Validate whitelist structure + if not isinstance(whitelist_data, dict): + raise ValueError("Whitelist must be a JSON object") + + if "data" not in whitelist_data: + raise ValueError('Whitelist must contain "data" key') + + if not isinstance(whitelist_data["data"], dict): + raise ValueError('Whitelist "data" must be an object') + + # Cache the validated whitelist + self._whitelist_cache[whitelist_path] = whitelist_data if self.debug: + categories_count = len(whitelist_data.get("data", {})) logger.info( - f"Loaded whitelist with {len(whitelist_data.get('data', {}))} categories", + f"Loaded and cached whitelist with {categories_count} categories" ) return whitelist_data - except Exception as e: - logger.error(f"Failed to load whitelist from {whitelist_path}: {e}") - return {"data": {}, "metadata": {}} + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Whitelist file not found: {whitelist_path}" + ) from e + except (orjson.JSONDecodeError, ValueError) as e: + raise ValueError( + f"Invalid JSON in whitelist file: {whitelist_path}. Error: {e}" + ) from e def _extract_doi_from_url(self, url: str) -> Optional[str]: """ - Extract DOI from URL using multiple patterns. + Extract DOI from URL using pre-compiled patterns. 
Args: url: URL to extract DOI from @@ -183,63 +383,143 @@ def _extract_doi_from_url(self, url: str) -> Optional[str]: Returns: Extracted DOI or None if not found """ + if not url: + return None + url_str = str(url).strip() + if not url_str: + return None - for pattern in self._doi_patterns: - matches = re.findall(pattern, url_str, re.IGNORECASE) - if matches: - doi = matches[0] if isinstance(matches[0], str) else matches[0][0] - # Clean up the DOI - doi = doi.strip().rstrip(".,;)") - # Ensure DOI starts with 10. - if not doi.startswith("10."): - continue - return doi + # Try with original URL first + for pattern in self._compiled_doi_patterns: + match = pattern.search(url_str) + if match: + doi = match.group(1).strip().rstrip(".,;)") + + # Handle URL-encoded DOIs + if "%2F" in doi or "%3A" in doi: + doi = urllib.parse.unquote(doi) + + # Validate DOI format with improved checks + if self._is_valid_doi_format(doi): + return doi + + # If no DOI found, try URL-decoding the entire URL and search again + try: + decoded_url = urllib.parse.unquote(url_str) + if decoded_url != url_str: + for pattern in self._compiled_doi_patterns: + match = pattern.search(decoded_url) + if match: + doi = match.group(1).strip().rstrip(".,;)") + if self._is_valid_doi_format(doi): + return doi + except Exception as e: + logger.debug(f"Error during URL decoding: {e}") return None - async def _fetch_crossref_metadata( - self, - session: aiohttp.ClientSession, - doi: str, + def _is_valid_doi_format(self, doi: str) -> bool: + """ + Validate DOI format with comprehensive checks according to DOI standards. 
+ + Args: + doi: DOI string to validate + + Returns: + True if DOI format is valid, False otherwise + """ + if not doi or not isinstance(doi, str): + return False + + # Basic structure check + if not doi.startswith("10."): + return False + + if "/" not in doi: + return False + + # Minimum length check (e.g., "10.1/a" is theoretical minimum) + if len(doi) < 6: + return False + + # Split into prefix and suffix + parts = doi.split("/", 1) + if len(parts) != 2: + return False + + prefix, suffix = parts + + # Validate prefix: must be "10." followed by 4 or more digits + if not re.match(r"^10\.\d{4,}$", prefix): + return False + + # Validate suffix: must not be empty and contain valid characters + if not suffix or len(suffix.strip()) == 0: + return False + + # Check for invalid characters in suffix (very permissive as per DOI spec) + # DOI suffix can contain most printable characters except spaces and some control chars + if re.search(r"[\s\x00-\x1f\x7f]", suffix): + return False + + # Additional check: suffix should have reasonable length + if len(suffix) > 1000: # Extremely long suffixes are suspicious + return False + + return True + + async def _fetch_crossref_metadata_simple( + self, doi: str ) -> Optional[Dict[str, Any]]: """ - Fetch source metadata from CrossRef API. + Fetch CrossRef metadata using crossref-commons library with improved error handling. 
Args: - session: aiohttp session - doi: DOI to fetch metadata for + doi: DOI to look up Returns: - Metadata dictionary or None if failed + CrossRef metadata dictionary or None if not found """ - url = f"{self.config.crossref_base_url}/{doi}" - headers = {"User-Agent": self.config.user_agent, "Accept": "application/json"} + import asyncio + import concurrent.futures try: - async with session.get( - url, - headers=headers, - timeout=aiohttp.ClientTimeout(total=self.config.timeout_seconds), - ) as response: - if response.status == 200: - data = await response.json() - return data.get("message", {}) - elif response.status == 404: - if self.debug: - logger.warning(f"DOI not found in CrossRef: {doi}") + # Use ThreadPoolExecutor for better control over thread safety + # instead of asyncio.to_thread() which may have issues with crossref-commons + def fetch_doi_sync() -> Optional[Dict[str, Any]]: + """Synchronous wrapper for crossref-commons call.""" + try: + result = get_publication_as_json(doi) + return result + except Exception as e: + logger.warning(f"CrossRef request failed for DOI {doi}: {e}") return None - else: - logger.warning( - f"CrossRef API error {response.status} for DOI: {doi}", + + # Use ThreadPoolExecutor for better thread management + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(fetch_doi_sync) + try: + # Add timeout to prevent hanging requests + result = await asyncio.wait_for( + asyncio.wrap_future(future), + timeout=30.0, # 30 second timeout for CrossRef requests ) + + if result and self.debug: + logger.info( + f"Successfully fetched CrossRef metadata for DOI: {doi}" + ) + + return result + + except asyncio.TimeoutError: + logger.warning(f"CrossRef request timeout after 30s for DOI: {doi}") + future.cancel() return None - except aiohttp.ClientError as e: - logger.warning(f"Network error fetching CrossRef data for {doi}: {e}") - return None except Exception as e: - logger.error(f"Unexpected error 
fetching CrossRef data for {doi}: {e}") + logger.warning(f"CrossRef API error for DOI {doi}: {e}") return None def _parse_crossref_response( @@ -258,6 +538,9 @@ def _parse_crossref_response( Returns: SourceInfo object + + Raises: + ValueError: If required data is missing or invalid """ # Extract source title container_title = data.get("container-title", []) @@ -277,18 +560,7 @@ def _parse_crossref_response( ) # Extract open access information - is_open_access = None - license_info = data.get("license", []) - if license_info: - # Check for common open access license indicators - for license_item in license_info: - license_url = license_item.get("URL", "").lower() - if any( - oa_indicator in license_url - for oa_indicator in ["creativecommons", "cc-by", "open", "public"] - ): - is_open_access = True - break + is_open_access = self._determine_open_access_status(data) return SourceInfo( title=title, @@ -299,12 +571,130 @@ def _parse_crossref_response( url=original_url, ) + def _determine_open_access_status(self, data: Dict[str, Any]) -> Optional[bool]: + """ + Determine open access status from CrossRef data. 
+ + Args: + data: CrossRef API response data + + Returns: + True if open access, False if not, None if unknown + """ + license_info = data.get("license", []) + if license_info: + # Check for legitimate open access license indicators + open_access_indicators = [ + "creativecommons.org", + "cc-by", + "cc-zero", + "cc0", + "cc-by-sa", + "cc-by-nc", + "cc-by-nc-sa", + "cc-by-nc-nd", + "cc-by-nd", + "opensource.org", + "gnu.org/licenses", + "apache.org/licenses", + "mit-license", + "bsd-license", + "/publicdomain/", + "unlicense.org", + ] + for license_item in license_info: + license_url = license_item.get("URL", "").lower() + if any( + indicator in license_url for indicator in open_access_indicators + ): + return True + + return None + + def _normalize_journal_title(self, title: str) -> str: + """ + Normalize journal title using ftfy, unidecode and common abbreviations. + + Args: + title: Raw journal title + + Returns: + Normalized title + """ + if not title: + return "" + + # Use ftfy to fix text encoding issues first + cleaned = ftfy.fix_text(title) + + # Use unidecode for unicode normalization + normalized = unidecode(cleaned).lower().strip() + + # Common journal abbreviations (reduced set - most important ones) + abbreviations = { + "&": "and", + "j.": "journal", + "rev.": "review", + "res.": "research", + "lett.": "letters", + "sci.": "science", + "phys.": "physics", + "geophys.": "geophysical", + "astrophys.": "astrophysical", + "astron.": "astronomical", + "proc.": "proceedings", + } + + for abbrev, full in abbreviations.items(): + normalized = normalized.replace(abbrev, full) + + # Clean up punctuation and whitespace + normalized = re.sub(r"[^\w\s]", " ", normalized) + normalized = re.sub(r"\s+", " ", normalized) + + return normalized.strip() + + def _calculate_similarity_score(self, str1: str, str2: str) -> float: + """ + Calculate similarity using rapidfuzz with weighted scoring for better accuracy. 
+ + Args: + str1: First string + str2: Second string + + Returns: + Similarity score between 0.0 and 1.0 + """ + if not str1 or not str2: + return 0.0 + + # Normalize both strings + norm1 = self._normalize_journal_title(str1) + norm2 = self._normalize_journal_title(str2) + + if norm1 == norm2: + return 1.0 + + # Use rapidfuzz for multiple similarity metrics with weighted scoring + ratio = fuzz.ratio(norm1, norm2) / 100 + partial = fuzz.partial_ratio(norm1, norm2) / 100 + token_sort = fuzz.token_sort_ratio(norm1, norm2) / 100 + token_set = fuzz.token_set_ratio(norm1, norm2) / 100 + + # Weighted average with emphasis on token_set for journal matching + # token_set is most important for journal names with reordered words + weighted_score = ( + ratio * 0.2 + partial * 0.2 + token_sort * 0.25 + token_set * 0.35 + ) + + return weighted_score + def _validate_against_whitelist( self, source_info: SourceInfo, - ) -> tuple[bool, Optional[str], float]: + ) -> Tuple[bool, Optional[str], float]: """ - Validate source against whitelist. + Validate source against whitelist using rapidfuzz for efficient fuzzy matching. 
Args: source_info: Source information from CrossRef @@ -312,54 +702,63 @@ def _validate_against_whitelist( Returns: Tuple of (is_whitelisted, category, confidence_score) """ - if not self._whitelist.get("data"): + if not self._journal_index: return False, None, 0.0 - source_title = source_info.title.lower().strip() - - # Search through all categories in whitelist - for category_name, category_data in self._whitelist["data"].items(): - sources = category_data.get("journals", []) + source_title = self._normalize_journal_title(source_info.title) + if not source_title: + return False, None, 0.0 - for source_entry in sources: - if not source_entry or not isinstance(source_entry, dict): - continue + # Use rapidfuzz for efficient fuzzy matching across all journals + try: + result = process.extractOne( + source_title, + self._journal_index.keys(), + scorer=fuzz.token_set_ratio, + score_cutoff=75, # Minimum score threshold + ) - whitelisted_title = ( - (source_entry.get("Journal Name") or "").lower().strip() - ) - if not whitelisted_title: - continue + if result: + matched_title, score, _ = result + original_title, category = self._journal_index[matched_title] + confidence = score / 100.0 - # Exact title match - if source_title == whitelisted_title: - return True, category_name, 1.0 + if self.debug: + logger.info( + f'Matched "{source_info.title}" -> "{original_title}" ' + f"(score: {score}, category: {category})" + ) - # Fuzzy title match (partial match) - if ( - whitelisted_title in source_title - or source_title in whitelisted_title - ): - # Check if it's a meaningful match (not just common words) - if len(whitelisted_title) > 10 or len(source_title) > 10: - return True, category_name, 0.8 + return True, category, confidence - # TODO: Could add ISSN matching here if we had ISSN data in whitelist + except Exception as e: + logger.warning(f"Error during fuzzy matching: {e}") return False, None, 0.0 async def _validate_single_result( self, - session: aiohttp.ClientSession, 
- result: Any, + result: "SearchResultItem", ) -> ValidationResult: - """Validate a single search result.""" - validation_errors = [] + """ + Validate a single search result. + + Args: + result: Search result to validate + + Returns: + ValidationResult with detailed validation information + """ + validation_errors: List[str] = [] + + # Extract DOI with multiple fallback strategies + doi = None + if hasattr(result, "doi") and result.doi: + doi = result.doi + + if not doi: + doi = self._extract_doi_from_url(str(result.url)) - # Extract DOI - doi = getattr(result, "doi", None) or self._extract_doi_from_url( - str(result.url), - ) if not doi and hasattr(result, "pdf_url") and result.pdf_url: doi = self._extract_doi_from_url(str(result.pdf_url)) @@ -374,7 +773,7 @@ async def _validate_single_result( ) # Fetch metadata from CrossRef - crossref_data = await self._fetch_crossref_metadata(session, doi) + crossref_data = await self._fetch_crossref_metadata_simple(doi) if not crossref_data: validation_errors.append( @@ -389,12 +788,30 @@ async def _validate_single_result( ) # Parse source information - source_info = self._parse_crossref_response(crossref_data, doi, str(result.url)) + try: + source_info = self._parse_crossref_response( + crossref_data, doi, str(result.url) + ) + except Exception as e: + validation_errors.append(f"Failed to parse CrossRef response: {e}") + return ValidationResult( + source_info=None, + is_whitelisted=False, + whitelist_category=None, + validation_errors=validation_errors, + confidence_score=0.0, + ) # Validate against whitelist - is_whitelisted, category, confidence = self._validate_against_whitelist( - source_info, - ) + try: + is_whitelisted, category, confidence = self._validate_against_whitelist( + source_info + ) + except Exception as e: + validation_errors.append(f"Error during whitelist validation: {e}") + is_whitelisted, category, confidence = False, None, 0.0 + if self.debug: + logger.error(f"Whitelist validation error for DOI {doi}: 
{e}") return ValidationResult( source_info=source_info, @@ -407,55 +824,58 @@ async def _validate_single_result( async def _arun( self, params: SourceValidatorInputSchema, - **kwargs, + **kwargs: Any, ) -> SourceValidatorOutputSchema: """ Run the source validation tool. Args: params: Input parameters + **kwargs: Additional keyword arguments Returns: Validation results """ - # Update whitelist path if provided - if params.whitelist_file_path: - self.config.whitelist_file_path = params.whitelist_file_path - self._whitelist = self._load_whitelist() - - validated_results = [] - - # Create aiohttp session with semaphore for concurrent requests - connector = aiohttp.TCPConnector(limit=self.config.max_concurrent_requests) - timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) + # Validate input + if not params.search_results: + return SourceValidatorOutputSchema( + validated_results=[], + summary={ + "total_processed": 0, + "whitelisted_count": 0, + "whitelisted_percentage": 0.0, + "error_count": 0, + "category_breakdown": {}, + "avg_confidence": 0.0, + }, + ) - async with aiohttp.ClientSession( - connector=connector, - timeout=timeout, - ) as session: - # Process results concurrently but with limited concurrency - import asyncio + validated_results: List[ValidationResult] = [] + # Process results with controlled concurrency using crossref-commons + # No need for HTTP session management as crossref-commons handles it internally + try: semaphore = asyncio.Semaphore(self.config.max_concurrent_requests) async def validate_with_semaphore( - result: Any, + result: "SearchResultItem", ) -> ValidationResult: async with semaphore: - return await self._validate_single_result(session, result) + return await self._validate_single_result(result) # Process all results concurrently tasks = [ validate_with_semaphore(result) for result in params.search_results ] - validated_results = await asyncio.gather(*tasks, return_exceptions=True) - # Handle any exceptions - 
final_results = [] - for i, result in enumerate(validated_results): + # Handle results and exceptions properly + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results and handle exceptions + for i, result in enumerate(results): if isinstance(result, Exception): logger.error(f"Error validating result {i}: {result}") - final_results.append( + validated_results.append( ValidationResult( source_info=None, is_whitelisted=False, @@ -465,49 +885,110 @@ async def validate_with_semaphore( ), ) else: - final_results.append(result) + validated_results.append(result) - validated_results = final_results + except Exception as e: + logger.error(f"Error during validation: {e}") + # Create error results for all inputs + validated_results = [ + ValidationResult( + source_info=None, + is_whitelisted=False, + whitelist_category=None, + validation_errors=[f"Validation error: {str(e)}"], + confidence_score=0.0, + ) + for _ in params.search_results + ] # Generate summary statistics + summary = self._generate_summary_statistics(validated_results) + + if self.debug: + logger.info(f"Validation summary: {summary}") + + return SourceValidatorOutputSchema( + validated_results=validated_results, + summary=summary, + ) + + def _generate_summary_statistics( + self, validated_results: List[ValidationResult] + ) -> Dict[str, Any]: + """ + Generate summary statistics from validation results. 
+ + Args: + validated_results: List of validation results + + Returns: + Dictionary containing summary statistics + """ total_results = len(validated_results) whitelisted_count = sum(1 for r in validated_results if r.is_whitelisted) error_count = sum(1 for r in validated_results if r.validation_errors) # Category breakdown - category_counts = {} + category_counts: Dict[str, int] = {} for result in validated_results: if result.whitelist_category: category_counts[result.whitelist_category] = ( category_counts.get(result.whitelist_category, 0) + 1 ) - summary = { + # Calculate average confidence (only for non-error results) + confidence_scores = [ + r.confidence_score for r in validated_results if not r.validation_errors + ] + avg_confidence = ( + sum(confidence_scores) / len(confidence_scores) + if confidence_scores + else 0.0 + ) + + return { "total_processed": total_results, "whitelisted_count": whitelisted_count, "whitelisted_percentage": (whitelisted_count / total_results * 100) if total_results > 0 - else 0, + else 0.0, "error_count": error_count, "category_breakdown": category_counts, - "avg_confidence": sum(r.confidence_score for r in validated_results) - / total_results - if total_results > 0 - else 0, + "avg_confidence": avg_confidence, } + def _build_journal_index(self) -> Dict[str, Tuple[str, str]]: + """ + Build a searchable index of journal titles for fast fuzzy matching. 
+ + Returns: + Dictionary mapping normalized titles to (original_title, category) + """ + index: Dict[str, Tuple[str, str]] = {} + + for category_name, category_data in self._whitelist.get("data", {}).items(): + sources = category_data.get("journals", []) + + for source_entry in sources: + if not source_entry or not isinstance(source_entry, dict): + continue + + original_title = source_entry.get("Journal Name", "") + if not original_title: + continue + + normalized_title = self._normalize_journal_title(original_title) + if normalized_title: + index[normalized_title] = (original_title, category_name) + if self.debug: - logger.info(f"Validation summary: {summary}") + logger.info(f"Built journal index with {len(index)} entries") - return SourceValidatorOutputSchema( - validated_results=validated_results, - summary=summary, - ) + return index def create_source_validator( whitelist_file_path: Optional[str] = None, - timeout_seconds: int = 30, max_concurrent_requests: int = 10, debug: bool = False, ) -> SourceValidator: @@ -516,8 +997,7 @@ def create_source_validator( Args: whitelist_file_path: Path to source whitelist JSON file - timeout_seconds: Timeout for API requests - max_concurrent_requests: Maximum concurrent requests + max_concurrent_requests: Maximum concurrent CrossRef API requests debug: Enable debug logging Returns: @@ -525,8 +1005,6 @@ def create_source_validator( """ config = SourceValidatorConfig( whitelist_file_path=whitelist_file_path, - timeout_seconds=timeout_seconds, max_concurrent_requests=max_concurrent_requests, - debug=debug, ) return SourceValidator(config, debug=debug) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/simple_source_validation_example.py b/examples/simple_source_validation_example.py new file mode 100644 index 00000000..be25967b --- /dev/null +++ b/examples/simple_source_validation_example.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Simple source 
validation example. + +This example demonstrates the core source validation functionality +without complex search dependencies. +""" + +import asyncio +import sys +from pathlib import Path + +# Add the project root to Python path +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from akd.structures import SearchResultItem +from akd.tools.source_validator import create_source_validator + + +def create_test_search_results(): + """Create test search results with real DOIs for validation.""" + return [ + # Real AGU source paper - should be whitelisted + SearchResultItem( + url="https://doi.org/10.1029/2019GL084947", + title="Arctic Sea Ice Decline and Recovery in Observations and the NASA GISS Climate Model", + content="Analysis of Arctic sea ice trends using observations and climate model...", + query="arctic sea ice climate", + category="science", + doi="10.1029/2019GL084947", + published_date="2019-09-15", + engine="test", + ), + # Real Nature Geoscience paper - should be whitelisted + SearchResultItem( + url="https://doi.org/10.1038/s41561-020-0566-6", + title="Global carbon dioxide emissions from inland waters", + content="Assessment of CO2 emissions from rivers, lakes, and reservoirs...", + query="carbon emissions water", + category="science", + published_date="2020-04-20", + engine="test", + ), + # Real Astronomy & Astrophysics paper - should be whitelisted + SearchResultItem( + url="https://doi.org/10.1051/0004-6361/202038711", + title="Gaia early data release 3: The celestial reference frame", + content="Precise astrometric measurements from Gaia mission...", + query="gaia astrometry", + category="science", + published_date="2021-01-10", + engine="test", + ), + # Paper without DOI - should fail + SearchResultItem( + url="https://example.com/no-doi-paper", + title="Paper without DOI identifier", + content="This paper has no DOI and should fail validation...", + query="random research", + category="science", + 
published_date="2023-08-01", + engine="test", + ), + ] + + +async def main(): + """Main function demonstrating source validation.""" + print("Simple Source Validation Example") + print("=" * 50) + + # Create validator instance + print("Creating source validator...") + validator = create_source_validator(debug=True) + + # Create test search results + print("\nCreating test search results...") + test_results = create_test_search_results() + + print(f"Created {len(test_results)} test search results:") + for i, result in enumerate(test_results, 1): + print(f" {i}. {result.title}") + print(f" DOI: {result.doi or 'None'}") + print(f" URL: {result.url}") + + # Run validation + print("\nRunning validation against source whitelist...") + validation_input = {"search_results": test_results} + + try: + validation_results = await validator.arun(validation_input) + + print("\n" + "=" * 80) + print("VALIDATION RESULTS") + print("=" * 80) + + # Print summary + summary = validation_results.summary + print("\nSUMMARY:") + print(f" Total papers processed: {summary['total_processed']}") + print(f" Whitelisted papers: {summary['whitelisted_count']}") + print(f" Success rate: {summary['whitelisted_percentage']:.1f}%") + print(f" Papers with errors: {summary['error_count']}") + print(f" Average confidence: {summary['avg_confidence']:.2f}") + + if summary.get("category_breakdown"): + print("\n Categories found:") + for category, count in summary["category_breakdown"].items(): + print(f" {category}: {count} papers") + + # Print detailed results + print("\nDETAILED RESULTS:") + print("-" * 80) + + for i, result in enumerate(validation_results.validated_results, 1): + print(f"\n{i}. 
{test_results[i - 1].title}") + + if result.validation_errors: + print(" Status: ❌ FAILED") + print(f" Errors: {', '.join(result.validation_errors)}") + elif result.is_whitelisted: + print(" Status: ✅ WHITELISTED") + print(f" Category: {result.whitelist_category}") + print(f" Confidence: {result.confidence_score:.2f}") + else: + print(" Status: ⚠️ NOT WHITELISTED") + print(f" Confidence: {result.confidence_score:.2f}") + + if result.source_info: + si = result.source_info + print(f" Source: {si.title}") + print(f" Publisher: {si.publisher or 'Unknown'}") + print(f" DOI: {si.doi}") + print(f" Open Access: {si.is_open_access or 'Unknown'}") + if si.issn: + print(f" ISSN: {', '.join(si.issn)}") + + print("\n" + "=" * 80) + + # Determine success + if summary["whitelisted_count"] > 0: + print("✅ Validation completed successfully - found whitelisted sources!") + elif summary["error_count"] < summary["total_processed"]: + print( + "⚠️ Validation completed - no whitelisted sources found but CrossRef lookups worked" + ) + else: + print( + "❌ All validations failed - check network connectivity or DOI formats" + ) + + except Exception as e: + print(f"Error during validation: {e}") + import traceback + + traceback.print_exc() + return 1 + + return 0 + + +if __name__ == "__main__": + exit_code = asyncio.run(main()) + sys.exit(exit_code) diff --git a/examples/source_validation_test.py b/examples/source_validation_pipeline_test.py similarity index 94% rename from examples/source_validation_test.py rename to examples/source_validation_pipeline_test.py index d58922f0..d4701439 100644 --- a/examples/source_validation_test.py +++ b/examples/source_validation_pipeline_test.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 """ -Example script demonstrating the journal validation pipeline. +Example script demonstrating the source validation pipeline. This example shows how to: 1. Search for research papers using SearxNG and Semantic Scholar 2. Extract DOIs from the search results -3. 
Validate journals against a controlled whitelist using CrossRef API
+3. Validate sources against a controlled whitelist using CrossRef API
 4. Generate validation reports
 
 Usage:
-    python examples/journal_validation_pipeline.py
+    python examples/source_validation_pipeline_test.py
 """
 
 import asyncio
@@ -23,7 +23,7 @@
 
 from akd.structures import SearchResultItem
 from akd.tools.factory import create_default_search_tool
-from akd.tools.source_validator import create_journal_validator
+from akd.tools.source_validator import create_source_validator
 
 # Import SemanticScholarSearchTool separately since it's not in factory
 try:
@@ -77,8 +77,8 @@ def __init__(self, debug: bool = False):
             print(f"Warning: Could not initialize Semantic Scholar tool: {e}")
             self.semantic_scholar_tool = None
 
-        # Initialize journal validator
-        self.journal_validator = create_journal_validator(debug=debug)
+        # Initialize source validator
+        self.source_validator = create_source_validator(debug=debug)
 
     async def search_multiple_sources(
         self, queries: List[str], max_results_per_query: int = 5
@@ -158,7 +158,7 @@ async def validate_search_results(
         """
         validation_input = {"search_results": search_results}
 
-        validation_results = await self.journal_validator.arun(validation_input)
+        validation_results = await self.source_validator.arun(validation_input)
 
         return {
             "results": validation_results.validated_results,
@@ -208,15 +208,15 @@ def print_validation_report(self, validation_data: dict):
                 print("   Status: ⚠️ NOT WHITELISTED")
                 print(f"   Confidence: {result.confidence_score:.2f}")
 
-            if result.journal_info:
-                ji = result.journal_info
-                print(f"   Journal: {ji.title}")
-                print(f"   Publisher: {ji.publisher or 'Unknown'}")
-                print(f"   DOI: {ji.doi}")
-                print(f"   Open Access: {ji.is_open_access or 'Unknown'}")
-                print(f"   Original URL: {ji.url}")
-                if ji.issn:
-                    print(f"   ISSN: {', '.join(ji.issn)}")
+            if result.source_info:
+                si = result.source_info
+                print(f"   Source: {si.title}")
+                print(f"   Publisher: {si.publisher or 'Unknown'}")
+                
print(f" DOI: {si.doi}") + print(f" Open Access: {si.is_open_access or 'Unknown'}") + print(f" Original URL: {si.url}") + if si.issn: + print(f" ISSN: {', '.join(si.issn)}") print("\n" + "=" * 80) @@ -466,12 +466,12 @@ async def _enhanced_doi_extraction(self, search_result) -> str: return search_result.doi if hasattr(search_result, "url"): - doi = self.journal_validator._extract_doi_from_url(str(search_result.url)) + doi = self.source_validator._extract_doi_from_url(str(search_result.url)) if doi: return doi if hasattr(search_result, "pdf_url") and search_result.pdf_url: - doi = self.journal_validator._extract_doi_from_url( + doi = self.source_validator._extract_doi_from_url( str(search_result.pdf_url) ) if doi: @@ -520,7 +520,7 @@ async def enhanced_validate_search_results( # Use standard validation with enhanced results validation_input = {"search_results": enhanced_results} - validation_results = await self.journal_validator.arun(validation_input) + validation_results = await self.source_validator.arun(validation_input) return { "results": validation_results.validated_results, @@ -660,7 +660,7 @@ async def demo_enhanced_doi_resolution(): # Show DOI resolution statistics total_cases = len(test_cases) resolved_cases = sum( - 1 for result in validation_data["results"] if result.journal_info + 1 for result in validation_data["results"] if result.source_info ) print("\nDOI RESOLUTION STATISTICS:") diff --git a/pyproject.toml b/pyproject.toml index 6dc61315..e0286c4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,13 @@ dependencies = [ "readability-lxml>=0.8.1", "requests>=2.32.3", "wikipedia>=1.4.0", + "rapidfuzz>=3.10.0", + "crossref-commons>=0.15.0", + "unidecode>=1.3.8", + "httpx>=0.27.0", + "requests-cache>=1.2.0", + "orjson>=3.10.0", + "ftfy>=6.3.0", ] [tool.poetry] diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..bb2b01fc --- /dev/null +++ b/setup.py @@ -0,0 +1,21 @@ +from setuptools import find_packages, setup + +setup( + 
name="research_workflow", + version="0.1.0", + packages=find_packages(where="src"), + package_dir={"": "src"}, + install_requires=[ + "langgraph>=0.0.10", + "pydantic>=2.0.0", + "loguru>=0.7.0", + "openai>=1.0.0", + "anthropic>=0.5.0", + "aiohttp>=3.9.0", + "requests>=2.31.0", + "beautifulsoup4>=4.12.0", + "pandas>=2.0.0", + "numpy>=1.24.0", + ], + python_requires=">=3.9", +)