From 58042e0e1ccf86f072f1d09f68d2e4a5493533c4 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Mon, 16 Feb 2026 22:14:09 -0800 Subject: [PATCH 1/3] feat: batch arXiv API calls and add retry resilience (v0.3.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Batch category queries into single OR query (N API calls → 1) - Match page_size to min(max_results, 100) to avoid over-fetching - Add tenacity exponential backoff (5s→15s→45s→90s) on arxiv.HTTPError - Add ArxivRateLimitError with friendly message for HTTP 429 - Remove ThreadPoolExecutor from fetch_recent_papers (single call now) - Update mock local_client and all tests for new categories list signature - Add 8 new unit tests for batched queries, page_size, retry, and 429s - Bump version to 0.3.1 --- CHANGELOG.md | 15 ++- pyproject.toml | 2 +- src/mocks/local_client.py | 82 ++++++------ src/paperweight/__init__.py | 3 +- src/paperweight/main.py | 49 +++++-- src/paperweight/scraper.py | 199 ++++++++++++++++++----------- tests/test_local_mirror.py | 143 +++++++++++++++------ tests/test_scraper.py | 246 ++++++++++++++++++++++++++++-------- 8 files changed, 521 insertions(+), 218 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e503e1..5cec423 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.1] - 2026-02-16 + +### Changed +- arXiv categories are now batched into a single OR query (`cat:cs.AI OR cat:cs.CL`), reducing N parallel API calls to 1 +- `page_size` now matches `min(max_results, 100)` instead of always requesting 100 results +- `ThreadPoolExecutor` removed from `fetch_recent_papers()` since only one API call is made + +### Added +- Exponential backoff (via `tenacity`) on `arxiv.HTTPError` with waits of 5 → 15 → 45 → 90 s +- `ArxivRateLimitError` exception with user-friendly message for HTTP 429 responses +- 8 new unit tests covering batched 
queries, page_size matching, and retry behavior + ## [0.3.0] - 2026-02-15 ### Added @@ -87,7 +99,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Email notification system - YAML-based configuration -[Unreleased]: https://github.com/seanbrar/paperweight/compare/v0.3.0...HEAD +[Unreleased]: https://github.com/seanbrar/paperweight/compare/v0.3.1...HEAD +[0.3.1]: https://github.com/seanbrar/paperweight/compare/v0.3.0...v0.3.1 [0.3.0]: https://github.com/seanbrar/paperweight/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/seanbrar/paperweight/compare/v0.1.2...v0.2.0 [0.1.2]: https://github.com/seanbrar/paperweight/compare/v0.1.1...v0.1.2 diff --git a/pyproject.toml b/pyproject.toml index 4c8db48..dd5b30e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "academic-paperweight" -version = "0.3.0" +version = "0.3.1" description = "Automated retrieval, filtering, and LLM-powered summarization of arXiv papers based on your research interests." readme = "README.md" requires-python = ">=3.11, <3.14" diff --git a/src/mocks/local_client.py b/src/mocks/local_client.py index 4d9e836..94e7eff 100644 --- a/src/mocks/local_client.py +++ b/src/mocks/local_client.py @@ -20,8 +20,7 @@ def mock_fetch_paper_content( - paper_id: str, - files_dir: Path = DEFAULT_FILES_DIR + paper_id: str, files_dir: Path = DEFAULT_FILES_DIR ) -> Tuple[Optional[bytes], Optional[str]]: """Mock replacement for paperweight.scraper.fetch_paper_content. @@ -37,7 +36,7 @@ def mock_fetch_paper_content( or (None, None) if no file found. 
""" # Normalize paper_id - strip version if present for base lookup - base_id = paper_id.split('v')[0] if 'v' in paper_id else paper_id + base_id = paper_id.split("v")[0] if "v" in paper_id else paper_id # Try different ID patterns (with/without version) id_patterns = [paper_id] @@ -48,7 +47,7 @@ def mock_fetch_paper_content( if paper_id == base_id: # Look for any versioned file for f in files_dir.glob(f"{base_id}v*.tar.gz"): - id_patterns.insert(0, f.stem.replace('.tar', '')) + id_patterns.insert(0, f.stem.replace(".tar", "")) break for f in files_dir.glob(f"{base_id}v*.pdf"): if f.stem not in id_patterns: @@ -70,17 +69,17 @@ def mock_fetch_paper_content( def mock_fetch_arxiv_papers( - category: str, + categories: List[str], start_date: Any, max_results: Optional[int] = None, - db_path: Path = DEFAULT_DB_PATH + db_path: Path = DEFAULT_DB_PATH, ) -> List[Dict[str, Any]]: """Mock replacement for paperweight.scraper.fetch_arxiv_papers. Reads paper metadata from local SQLite database instead of arXiv API. Args: - category: The arXiv category to filter by (e.g., "cs.AI") + categories: arXiv categories to filter by (e.g., ``['cs.AI', 'cs.CL']``) start_date: Not used in mock (we return all matching papers) max_results: Maximum number of results to return db_path: Path to the SQLite database @@ -95,8 +94,10 @@ def mock_fetch_arxiv_papers( conn.row_factory = sqlite3.Row cursor = conn.cursor() - sql = "SELECT * FROM papers WHERE categories LIKE ?" - params: List[Any] = [f"%{category}%"] + # Build category filter with OR logic + cat_conditions = " OR ".join(["categories LIKE ?" for _ in categories]) + sql = f"SELECT * FROM papers WHERE ({cat_conditions})" + params: List[Any] = [f"%{cat}%" for cat in categories] if max_results: sql += " LIMIT ?" 
@@ -107,12 +108,14 @@ def mock_fetch_arxiv_papers( papers = [] for row in rows: - papers.append({ - "title": row["title"], - "link": f"http://arxiv.org/abs/{row['id']}", - "date": datetime.fromisoformat(row["published"]).date(), - "abstract": row["abstract"], - }) + papers.append( + { + "title": row["title"], + "link": f"http://arxiv.org/abs/{row['id']}", + "date": datetime.fromisoformat(row["published"]).date(), + "abstract": row["abstract"], + } + ) conn.close() return papers @@ -132,18 +135,17 @@ def patch_scraper_for_local_mirror(monkeypatch, files_dir: Path = DEFAULT_FILES_ def patched_scraper(monkeypatch): patch_scraper_for_local_mirror(monkeypatch) """ + def local_fetch_paper_content(paper_id): return mock_fetch_paper_content(paper_id, files_dir) monkeypatch.setattr( - "paperweight.scraper.fetch_paper_content", - local_fetch_paper_content + "paperweight.scraper.fetch_paper_content", local_fetch_paper_content ) # Also patch the retry-decorated wrapper if needed monkeypatch.setattr( - "paperweight.scraper.fetch_arxiv_papers", - mock_fetch_arxiv_papers + "paperweight.scraper.fetch_arxiv_papers", mock_fetch_arxiv_papers ) @@ -159,7 +161,7 @@ def __init__( page_size: int = 100, delay_seconds: float = 3, num_retries: int = 3, - mirror_path: Path = DEFAULT_MIRROR_PATH + mirror_path: Path = DEFAULT_MIRROR_PATH, ): self.page_size = page_size self.delay_seconds = delay_seconds @@ -174,17 +176,15 @@ def __init__( ) def results( - self, - search: arxiv.Search, - offset: int = 0 + self, search: arxiv.Search, offset: int = 0 ) -> Generator[arxiv.Result, None, None]: """Execute search against local SQLite database.""" conn = sqlite3.connect(self.mirror_db_path) conn.row_factory = sqlite3.Row cursor = conn.cursor() - query_str = getattr(search, 'query', '') - id_list = getattr(search, 'id_list', []) + query_str = getattr(search, "query", "") + id_list = getattr(search, "id_list", []) sql = "SELECT * FROM papers WHERE 1=1" params: List[Any] = [] @@ -211,7 +211,7 @@ def 
results( params.append(f"%{term}%") params.append(f"%{term}%") - max_results = getattr(search, 'max_results', None) + max_results = getattr(search, "max_results", None) if max_results: sql += " LIMIT ?" params.append(int(max_results)) @@ -231,29 +231,29 @@ class Author: def __init__(self, name: str): self.name = name - authors = [Author(n.strip()) for n in row['authors'].split(',')] - paper_id = row['id'] + authors = [Author(n.strip()) for n in row["authors"].split(",")] + paper_id = row["id"] res = arxiv.Result( entry_id=f"http://arxiv.org/abs/{paper_id}", - updated=datetime.fromisoformat(row['updated']), - published=datetime.fromisoformat(row['published']), - title=row['title'], + updated=datetime.fromisoformat(row["updated"]), + published=datetime.fromisoformat(row["published"]), + title=row["title"], authors=authors, - summary=row['abstract'], + summary=row["abstract"], comment=None, journal_ref=None, - doi=row['doi'], - primary_category=row['categories'].split(',')[0].strip(), - categories=[cat.strip() for cat in row['categories'].split(',')], - links=[] + doi=row["doi"], + primary_category=row["categories"].split(",")[0].strip(), + categories=[cat.strip() for cat in row["categories"].split(",")], + links=[], ) # Monkey-patch download methods to use local files - local_pdf_path = row['local_file_path'] - local_source_path = row['local_source_path'] + local_pdf_path = row["local_file_path"] + local_source_path = row["local_source_path"] - def mock_download_pdf(dirpath: str = './', filename: str = '') -> str: + def mock_download_pdf(dirpath: str = "./", filename: str = "") -> str: if not filename: filename = f"{paper_id}.pdf" target_path = Path(dirpath) / filename @@ -263,7 +263,7 @@ def mock_download_pdf(dirpath: str = './', filename: str = '') -> str: return str(target_path) raise FileNotFoundError(f"Mock PDF file missing for {paper_id}") - def mock_download_source(dirpath: str = './', filename: str = '') -> str: + def mock_download_source(dirpath: str = "./", 
filename: str = "") -> str: if not filename: filename = f"{paper_id}.tar.gz" target_path = Path(dirpath) / filename @@ -275,6 +275,6 @@ def mock_download_source(dirpath: str = './', filename: str = '') -> str: res.download_pdf = mock_download_pdf # type: ignore res.download_source = mock_download_source # type: ignore - res.pdf_url = row['pdf_url'] + res.pdf_url = row["pdf_url"] return res diff --git a/src/paperweight/__init__.py b/src/paperweight/__init__.py index 16c4e8a..60de9cc 100644 --- a/src/paperweight/__init__.py +++ b/src/paperweight/__init__.py @@ -15,11 +15,12 @@ setup_and_get_papers, summarize_scored_papers, ) -from paperweight.scraper import get_recent_papers # noqa: E402 +from paperweight.scraper import ArxivRateLimitError, get_recent_papers # noqa: E402 from paperweight.utils import load_config # noqa: E402 __all__ = [ "__version__", + "ArxivRateLimitError", "get_recent_papers", "load_config", "process_and_summarize_papers", diff --git a/src/paperweight/main.py b/src/paperweight/main.py index 7a70c22..b59af4c 100644 --- a/src/paperweight/main.py +++ b/src/paperweight/main.py @@ -27,7 +27,11 @@ ) from paperweight.processor import process_papers from paperweight.progress import ProgressReporter -from paperweight.scraper import get_recent_papers, hydrate_papers_with_content +from paperweight.scraper import ( + ArxivRateLimitError, + get_recent_papers, + hydrate_papers_with_content, +) from paperweight.storage import ( create_run, finish_run, @@ -80,7 +84,9 @@ """ -def setup_and_get_papers(force_refresh, include_content=True, config_path="config.yaml", profile=None): +def setup_and_get_papers( + force_refresh, include_content=True, config_path="config.yaml", profile=None +): """Set up the application and fetch papers. 
Args: @@ -162,7 +168,9 @@ def summarize_scored_papers(processed_papers, config): return None summary_concurrency = config.get("concurrency", {}).get("summary") - summaries = get_abstracts(processed_papers, config["analyzer"], summary_concurrency=summary_concurrency) + summaries = get_abstracts( + processed_papers, config["analyzer"], summary_concurrency=summary_concurrency + ) for paper, summary in zip(processed_papers, summaries): paper["summary"] = ( summary if summary else paper.get("abstract", "No summary available") @@ -256,6 +264,8 @@ def _get_error_message(error): Returns: Human-readable error description string. """ + if isinstance(error, ArxivRateLimitError): + return str(error) if isinstance(error, requests.RequestException): return "Network error occurred" if isinstance(error, yaml.YAMLError): @@ -302,7 +312,9 @@ def _deliver_output(processed_papers, config, args): sort_order=args.sort_order, feed_title=feed_config.get("title", "paperweight"), feed_id=feed_config.get("id", "https://github.com/seanbrar/paperweight"), - feed_link=feed_config.get("link", "https://github.com/seanbrar/paperweight"), + feed_link=feed_config.get( + "link", "https://github.com/seanbrar/paperweight" + ), ) write_output(feed_xml, args.output) return @@ -311,7 +323,9 @@ def _deliver_output(processed_papers, config, args): if not notifier_config: raise ValueError("Email delivery requested but notifier config is missing.") - notification_sent = compile_and_send_notifications(processed_papers, notifier_config) + notification_sent = compile_and_send_notifications( + processed_papers, notifier_config + ) if notification_sent: logger.info("Notifications compiled and sent successfully") else: @@ -448,7 +462,9 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: def _write_minimal_config(path: str, force: bool = False) -> None: target = Path(path) if target.exists() and not force: - raise ValueError(f"Config file already exists: {target}. 
Use --force to overwrite.") + raise ValueError( + f"Config file already exists: {target}. Use --force to overwrite." + ) base_template = Path("config-base.yaml") content = ( @@ -543,7 +559,11 @@ def _run_pipeline(args: argparse.Namespace) -> int: # noqa: C901 config_path=args.config, profile=getattr(args, "profile", None), ) - if args.max_items and args.max_items > 0 and len(recent_papers) > args.max_items: + if ( + args.max_items + and args.max_items > 0 + and len(recent_papers) > args.max_items + ): logger.info( "Applying max-items compute cap: processing first %s of %s fetched papers", args.max_items, @@ -563,7 +583,9 @@ def _run_pipeline(args: argparse.Namespace) -> int: # noqa: C901 if not triaged_papers: logger.info("AI triage selected no papers. Exiting.") triaged_papers = [] - progress.phase_end("triaging...", f"{len(triaged_papers)}/{len(recent_papers)} selected") + progress.phase_end( + "triaging...", f"{len(triaged_papers)}/{len(recent_papers)} selected" + ) # 3. Score (title + abstract keywords — no content needed) progress.phase("scoring...") @@ -578,7 +600,9 @@ def _run_pipeline(args: argparse.Namespace) -> int: # noqa: C901 else: progress.phase_end( "scoring...", - f"{len(scored_papers)} papers above threshold" if scored_papers else "0 papers above threshold", + f"{len(scored_papers)} papers above threshold" + if scored_papers + else "0 papers above threshold", ) # 4. 
Hydrate ONLY if analyzer needs full content (summary mode) @@ -611,6 +635,7 @@ def _run_pipeline(args: argparse.Namespace) -> int: # noqa: C901 run_status = "success" except ( + ArxivRateLimitError, requests.RequestException, yaml.YAMLError, KeyError, @@ -651,7 +676,11 @@ def main(argv: list[str] | None = None) -> int: print(f"paperweight init: {exc}", file=sys.stderr) return 1 if args.command == "doctor": - return _doctor(args.config, strict=getattr(args, "strict", False), profile=getattr(args, "profile", None)) + return _doctor( + args.config, + strict=getattr(args, "strict", False), + profile=getattr(args, "profile", None), + ) return _run_pipeline(args) diff --git a/src/paperweight/scraper.py b/src/paperweight/scraper.py index 3e168f1..2a4d25b 100644 --- a/src/paperweight/scraper.py +++ b/src/paperweight/scraper.py @@ -12,13 +12,18 @@ import logging import os import tarfile -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ( + ThreadPoolExecutor, + as_completed, +) # ThreadPoolExecutor still used by fetch_paper_contents from datetime import date, datetime, timedelta from typing import Any, Dict, List, Optional import arxiv +import arxiv as _arxiv_module import requests from tenacity import ( + RetryCallState, retry, retry_if_exception_type, stop_after_attempt, @@ -36,13 +41,43 @@ logger = logging.getLogger(__name__) +class ArxivRateLimitError(RuntimeError): + """Raised when arXiv returns HTTP 429 (Too Many Requests). + + Provides a user-friendly error message instead of a raw stack trace. + """ + + def __init__(self, original: Exception | None = None): + message = ( + "arXiv rate-limited our request (HTTP 429).\n" + " The API allows ≤1 request every 3 seconds.\n" + " Please wait a few minutes and try again, " + "or reduce arxiv.max_results." 
+ ) + super().__init__(message) + self.original = original + + +def _log_arxiv_retry(retry_state: RetryCallState) -> None: + """Log a tenacity retry attempt for arXiv API calls.""" + wait = retry_state.next_action.sleep if retry_state.next_action else 0 + logger.warning( + "arXiv request failed (attempt %d), retrying in %.0fs…", + retry_state.attempt_number, + wait, + ) + + def fetch_arxiv_papers( - category: str, start_date: date, max_results: Optional[int] = None + categories: List[str], start_date: date, max_results: Optional[int] = None ) -> List[Dict[str, Any]]: - """Fetch papers from arXiv API for a specific category and date range. + """Fetch papers from arXiv API for one or more categories. + + Categories are batched into a single API query using OR syntax + (e.g. ``cat:cs.AI OR cat:cs.CL``) to minimize HTTP requests. Args: - category: The arXiv category to fetch papers from (e.g., 'cs.AI'). + categories: arXiv categories to fetch (e.g., ``['cs.AI', 'cs.CL']``). start_date: The date from which to start fetching papers. max_results: Optional maximum number of results to return. @@ -50,19 +85,22 @@ def fetch_arxiv_papers( List of dictionaries containing paper metadata. Raises: - requests.ConnectionError: If connection to arXiv API fails. - requests.Timeout: If the request times out. + ArxivRateLimitError: If arXiv returns HTTP 429 after all retries. 
""" - logger.debug(f"Fetching arXiv papers for category '{category}' since {start_date}") + logger.debug( + "Fetching arXiv papers for categories %s since %s", categories, start_date + ) + + # Build a single batched query: "cat:cs.AI OR cat:cs.CL OR …" + query = " OR ".join(f"cat:{c}" for c in categories) - # Construct the query - query = f"cat:{category}" + # Match page_size to max_results so we don't over-fetch + effective_page_size = min(max_results, 100) if max_results else 100 - # Configure the client client = arxiv.Client( - page_size=100, + page_size=effective_page_size, delay_seconds=3.0, - num_retries=3 + num_retries=3, ) search = arxiv.Search( @@ -72,18 +110,41 @@ def fetch_arxiv_papers( sort_order=arxiv.SortOrder.Descending, ) - papers = [] + return _fetch_with_backoff(client, search, start_date, max_results) + + +@retry( + stop=stop_after_attempt(4), + wait=wait_exponential(multiplier=3, min=5, max=90), + retry=retry_if_exception_type(_arxiv_module.HTTPError), + before_sleep=_log_arxiv_retry, + reraise=True, +) +def _fetch_with_backoff( + client: arxiv.Client, + search: arxiv.Search, + start_date: date, + max_results: Optional[int], +) -> List[Dict[str, Any]]: + """Consume ``client.results()`` with tenacity retry on HTTP errors. + + The arxiv.py library raises ``arxiv.HTTPError`` for non-200 responses + (including 429). This wrapper adds exponential backoff + (5 s → 15 s → 45 s) on top of the library's own flat retry. 
+ """ + papers: List[Dict[str, Any]] = [] try: - # Iterate through the results for result in client.results(search): submitted_date = result.published.date() - logger.debug(f"Paper '{result.title}' submitted on {submitted_date}") + logger.debug("Paper '%s' submitted on %s", result.title, submitted_date) if submitted_date < start_date: logger.debug( - f"Stopping fetch: paper date {submitted_date} is before start date {start_date}" + "Stopping fetch: paper date %s is before start date %s", + submitted_date, + start_date, ) break @@ -101,19 +162,23 @@ def fetch_arxiv_papers( } ) - # Safety break if max_results is set multiple times or if the generator doesn't stop - if max_results is not None and max_results > 0 and len(papers) >= max_results: + if ( + max_results is not None + and max_results > 0 + and len(papers) >= max_results + ): break - except Exception as e: - # Map arxiv errors or other unexpected errors - logger.error(f"Error fetching papers: {e}") - # We might want to re-raise or handle gracefully depending on the exact error - # For now, consistent with previous behavior, let's allow tenacity or caller to handle - raise + except _arxiv_module.HTTPError as exc: + if getattr(exc, "status", None) == 429: + raise ArxivRateLimitError(original=exc) from exc + raise logger.info( - f"Successfully fetched {len(papers)} papers for category '{category}' since {start_date}" + "Successfully fetched %d papers for query '%s' since %s", + len(papers), + search.query, + start_date, ) return papers @@ -121,62 +186,46 @@ def fetch_arxiv_papers( def fetch_recent_papers(config, start_days=1): """Fetch papers published within the last specified number of days. + All configured categories are combined into a single arXiv API query + using OR syntax to minimize HTTP requests. + Args: + config: Application configuration dictionary. start_days: Number of days to look back for papers. Returns: List of dictionaries containing paper metadata. 
""" categories = config["arxiv"]["categories"] - max_results = config["arxiv"].get("max_results", 0) # Default to 0 if not set + max_results = config["arxiv"].get("max_results", 0) end_date = datetime.now().date() start_date = end_date - timedelta(days=start_days) - logger.info(f"Fetching papers from {start_date} to {end_date}") + logger.info("Fetching papers from %s to %s", start_date, end_date) + logger.info("Categories: %s (single batched query)", categories) - def _fetch_category(category): - logger.info(f"Processing category: {category}") - try: - return category, fetch_arxiv_papers( - category, - start_date, - max_results=max_results if max_results > 0 else None, - ) - except ValueError as ve: - logger.error(f"Error fetching papers for category {category}: {ve}") - return category, [] - - all_papers = [] - processed_ids: set = set() - - workers = min(len(categories), 4) if categories else 1 - with ThreadPoolExecutor(max_workers=workers) as executor: - futures = {executor.submit(_fetch_category, cat): cat for cat in categories} - # Collect in submission order for deterministic results - results_by_cat = {} - for future in as_completed(futures): - cat, papers = future.result() - results_by_cat[cat] = papers - - for category in categories: - papers = results_by_cat.get(category, []) - new_papers = [ - paper - for paper in papers - if paper["link"].split("/abs/")[-1] not in processed_ids - ] - processed_ids.update( - paper["link"].split("/abs/")[-1] for paper in new_papers - ) + papers = fetch_arxiv_papers( + categories, + start_date, + max_results=max_results if max_results > 0 else None, + ) - if max_results > 0: - new_papers = new_papers[:max_results] + # Deduplicate by arXiv ID (papers can appear in multiple categories) + seen_ids: set = set() + unique_papers: list = [] + for paper in papers: + paper_id = paper["link"].split("/abs/")[-1] + if paper_id not in seen_ids: + seen_ids.add(paper_id) + unique_papers.append(paper) - all_papers.extend(new_papers) - 
logger.debug(f"Added {len(new_papers)} new papers from category {category}") + if max_results > 0: + unique_papers = unique_papers[:max_results] - logger.info(f"Fetched a total of {len(all_papers)} papers") - return all_papers + logger.info( + "Fetched %d unique papers (from %d raw)", len(unique_papers), len(papers) + ) + return unique_papers @retry( @@ -346,7 +395,9 @@ def _hydrate_papers_with_content(papers, config, db_enabled): artifacts = [] if db_enabled: - artifacts = _store_artifacts(paper_id, method, content, text, storage_base) + artifacts = _store_artifacts( + paper_id, method, content, text, storage_base + ) paper_with_content = dict(paper) paper_with_content.update( @@ -359,7 +410,9 @@ def _hydrate_papers_with_content(papers, config, db_enabled): ) papers_with_content.append(paper_with_content) - logger.info("Hydrated %s/%s papers with full content", len(papers_with_content), len(papers)) + logger.info( + "Hydrated %s/%s papers with full content", len(papers_with_content), len(papers) + ) return papers_with_content @@ -490,7 +543,9 @@ def get_recent_papers(config, force_refresh=False, include_content=True): # noq else: days = (current_date - last_processed_date).days if days == 0: - logger.info("Already processed papers for today. No new papers to fetch.") + logger.info( + "Already processed papers for today. No new papers to fetch." 
+ ) return [] elif days > 7: # If more than a week has passed, limit to 7 days to avoid overload @@ -572,7 +627,9 @@ def _store_artifacts(paper_id, method, content, text, storage_base): try: _write_bytes(raw_path, content) artifacts.append( - _artifact_record("source" if method == "source" else "pdf", raw_path, content) + _artifact_record( + "source" if method == "source" else "pdf", raw_path, content + ) ) except OSError as e: logger.error("Failed to write source artifact %s: %s", raw_path, e) diff --git a/tests/test_local_mirror.py b/tests/test_local_mirror.py index 52de332..f7c0ee2 100644 --- a/tests/test_local_mirror.py +++ b/tests/test_local_mirror.py @@ -52,7 +52,9 @@ def test_mock_fetch_paper_content_source(self, local_mirror_files: Path): assert method == "source" assert len(content) > 0 - def test_mock_fetch_paper_content_pdf_fallback(self, local_mirror_files: Path, tmp_path: Path): + def test_mock_fetch_paper_content_pdf_fallback( + self, local_mirror_files: Path, tmp_path: Path + ): """Test PDF fallback when source is missing.""" # Create a PDF-only test case test_pdf = tmp_path / "test_paper.pdf" @@ -73,10 +75,10 @@ def test_mock_fetch_paper_content_not_found(self, tmp_path: Path): def test_mock_fetch_arxiv_papers(self, local_mirror_db: Path): """Fetch papers by category from local database.""" papers = mock_fetch_arxiv_papers( - category="cs.AI", + categories=["cs.AI"], start_date=date(2024, 1, 1), max_results=5, - db_path=local_mirror_db + db_path=local_mirror_db, ) # May be empty if no cs.AI papers, but should not error @@ -118,9 +120,19 @@ def test_base_id_matches_versioned_paper(self, tmp_path: Path): doi TEXT, local_file_path TEXT, local_source_path TEXT)""") conn.execute( "INSERT INTO papers VALUES (?,?,?,?,?,?,?,?,?,?,?)", - ("1706.03762v7", "Attention Is All You Need", "Abstract", - "Vaswani et al.", "cs.CL,cs.LG", "2017-06-12", "2017-06-12", - "http://arxiv.org/pdf/1706.03762v7", None, None, None), + ( + "1706.03762v7", + "Attention Is All You 
Need", + "Abstract", + "Vaswani et al.", + "cs.CL,cs.LG", + "2017-06-12", + "2017-06-12", + "http://arxiv.org/pdf/1706.03762v7", + None, + None, + None, + ), ) conn.commit() conn.close() @@ -152,8 +164,19 @@ def test_versioned_id_exact_match(self, tmp_path: Path): doi TEXT, local_file_path TEXT, local_source_path TEXT)""") conn.execute( "INSERT INTO papers VALUES (?,?,?,?,?,?,?,?,?,?,?)", - ("1706.03762v7", "Test Paper", "Abstract", "Author", "cs.AI", - "2017-06-12", "2017-06-12", "http://example.com", None, None, None), + ( + "1706.03762v7", + "Test Paper", + "Abstract", + "Author", + "cs.AI", + "2017-06-12", + "2017-06-12", + "http://example.com", + None, + None, + None, + ), ) conn.commit() conn.close() @@ -185,13 +208,35 @@ def test_base_id_finds_multiple_versions(self, tmp_path: Path): doi TEXT, local_file_path TEXT, local_source_path TEXT)""") conn.execute( "INSERT INTO papers VALUES (?,?,?,?,?,?,?,?,?,?,?)", - ("1706.03762v1", "Paper v1", "Abstract", "Author", "cs.AI", - "2017-06-12", "2017-06-12", "http://example.com", None, None, None), + ( + "1706.03762v1", + "Paper v1", + "Abstract", + "Author", + "cs.AI", + "2017-06-12", + "2017-06-12", + "http://example.com", + None, + None, + None, + ), ) conn.execute( "INSERT INTO papers VALUES (?,?,?,?,?,?,?,?,?,?,?)", - ("1706.03762v7", "Paper v7", "Abstract", "Author", "cs.AI", - "2017-06-12", "2017-12-01", "http://example.com", None, None, None), + ( + "1706.03762v7", + "Paper v7", + "Abstract", + "Author", + "cs.AI", + "2017-06-12", + "2017-12-01", + "http://example.com", + None, + None, + None, + ), ) conn.commit() conn.close() @@ -215,7 +260,9 @@ def test_client_search_by_category(self, mock_arxiv_client: MockArxivClient): # Results depend on what's in the mirror assert isinstance(results, list) - def test_client_search_by_id(self, mock_arxiv_client: MockArxivClient, local_mirror_db: Path): + def test_client_search_by_id( + self, mock_arxiv_client: MockArxivClient, local_mirror_db: Path + ): """Search by ID 
returns matching paper.""" import arxiv @@ -237,14 +284,18 @@ def test_client_search_by_id(self, mock_arxiv_client: MockArxivClient, local_mir assert len(results) == 1 assert paper_id in results[0].entry_id - def test_result_has_download_methods(self, mock_arxiv_client: MockArxivClient, local_mirror_db: Path): + def test_result_has_download_methods( + self, mock_arxiv_client: MockArxivClient, local_mirror_db: Path + ): """arxiv.Result objects have mocked download methods.""" import arxiv # Get a paper with files conn = sqlite3.connect(local_mirror_db) cursor = conn.cursor() - cursor.execute("SELECT id FROM papers WHERE local_file_path IS NOT NULL LIMIT 1") + cursor.execute( + "SELECT id FROM papers WHERE local_file_path IS NOT NULL LIMIT 1" + ) row = cursor.fetchone() conn.close() @@ -259,8 +310,8 @@ def test_result_has_download_methods(self, mock_arxiv_client: MockArxivClient, l result = results[0] # Check methods exist - assert hasattr(result, 'download_pdf') - assert hasattr(result, 'download_source') + assert hasattr(result, "download_pdf") + assert hasattr(result, "download_source") assert callable(result.download_pdf) assert callable(result.download_source) @@ -324,10 +375,10 @@ def test_fetch_process_score( # Get papers from local database papers_raw = mock_fetch_arxiv_papers( - category="cs", # Broad category + categories=["cs"], # Broad category start_date=date(2024, 1, 1), max_results=5, - db_path=local_mirror_db + db_path=local_mirror_db, ) if not papers_raw: @@ -342,15 +393,17 @@ def test_fetch_process_score( if content and method: try: text = extract_text_from_source(content, method) - papers_with_content.append({ - "id": paper_id, - "title": paper["title"], - "link": paper["link"], - "date": paper["date"], - "abstract": paper["abstract"], - "content": text or "", - "content_type": method, - }) + papers_with_content.append( + { + "id": paper_id, + "title": paper["title"], + "link": paper["link"], + "date": paper["date"], + "abstract": 
paper["abstract"], + "content": text or "", + "content_type": method, + } + ) except Exception: # Skip papers that fail extraction continue @@ -382,10 +435,10 @@ def test_pipeline_with_production_config( # Get papers papers_raw = mock_fetch_arxiv_papers( - category="cs.AI", + categories=["cs.AI"], start_date=date(2024, 1, 1), max_results=10, - db_path=local_mirror_db + db_path=local_mirror_db, ) if not papers_raw: @@ -400,15 +453,17 @@ def test_pipeline_with_production_config( if content and method: try: text = extract_text_from_source(content, method) - papers_with_content.append({ - "id": paper_id, - "title": paper["title"], - "link": paper["link"], - "date": paper["date"], - "abstract": paper["abstract"], - "content": text or paper["abstract"], - "content_type": method, - }) + papers_with_content.append( + { + "id": paper_id, + "title": paper["title"], + "link": paper["link"], + "date": paper["date"], + "abstract": paper["abstract"], + "content": text or paper["abstract"], + "content_type": method, + } + ) except Exception: continue @@ -426,12 +481,14 @@ def test_pipeline_with_production_config( class TestNoNetworkCalls: """Verify that local mirror tests make no real network calls.""" - def test_mock_functions_are_offline(self, local_mirror_files: Path, local_mirror_db: Path): + def test_mock_functions_are_offline( + self, local_mirror_files: Path, local_mirror_db: Path + ): """Ensure mock functions don't make HTTP requests.""" with patch("requests.get") as mock_get, patch("requests.post") as mock_post: # Call mock functions mock_fetch_paper_content("1706.03762", local_mirror_files) - mock_fetch_arxiv_papers("cs.AI", date.today(), 5, local_mirror_db) + mock_fetch_arxiv_papers(["cs.AI"], date.today(), 5, local_mirror_db) # Verify no HTTP calls mock_get.assert_not_called() @@ -472,7 +529,9 @@ def test_attention_paper_high_relevance( # Get metadata from DB conn = sqlite3.connect(local_mirror_db) cursor = conn.cursor() - cursor.execute("SELECT title, abstract FROM 
papers WHERE id LIKE ?", ("1706.03762%",)) + cursor.execute( + "SELECT title, abstract FROM papers WHERE id LIKE ?", ("1706.03762%",) + ) row = cursor.fetchone() conn.close() diff --git a/tests/test_scraper.py b/tests/test_scraper.py index 87863d5..0cabec8 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -5,6 +5,7 @@ from paperweight.db import DatabaseConnectionError from paperweight.scraper import ( + ArxivRateLimitError, _write_metadata_cache, extract_text_from_source, fetch_arxiv_papers, @@ -13,120 +14,261 @@ ) -@patch('paperweight.scraper.arxiv.Client') +# --------------------------------------------------------------------------- +# fetch_arxiv_papers — batched OR query +# --------------------------------------------------------------------------- + + +@patch("paperweight.scraper.arxiv.Client") def test_fetch_arxiv_papers(MockClient): mock_client_instance = MockClient.return_value # Mock results result1 = MagicMock() - result1.title = 'Test Paper 1' - result1.entry_id = 'http://arxiv.org/abs/2401.12345' + result1.title = "Test Paper 1" + result1.entry_id = "http://arxiv.org/abs/2401.12345" result1.published = datetime(2024, 1, 15) - result1.summary = 'This is test abstract 1.' + result1.summary = "This is test abstract 1." result2 = MagicMock() - result2.title = 'Test Paper 2' - result2.entry_id = 'http://arxiv.org/abs/2401.67890' + result2.title = "Test Paper 2" + result2.entry_id = "http://arxiv.org/abs/2401.67890" result2.published = datetime(2024, 1, 14, 12, 0, 0) - result2.summary = 'This is test abstract 2.' + result2.summary = "This is test abstract 2." 
mock_client_instance.results.return_value = [result1, result2] start_date = datetime(2024, 1, 14).date() - papers = fetch_arxiv_papers('cs.AI', start_date, max_results=2) + papers = fetch_arxiv_papers(["cs.AI"], start_date, max_results=2) assert len(papers) == 2 - assert papers[0]['title'] == 'Test Paper 1' - assert papers[1]['title'] == 'Test Paper 2' - assert papers[0]['date'] == datetime(2024, 1, 15).date() - assert papers[1]['date'] == datetime(2024, 1, 14).date() + assert papers[0]["title"] == "Test Paper 1" + assert papers[1]["title"] == "Test Paper 2" + assert papers[0]["date"] == datetime(2024, 1, 15).date() + assert papers[1]["date"] == datetime(2024, 1, 14).date() -@patch('paperweight.scraper.arxiv.Client') +@patch("paperweight.scraper.arxiv.Client") def test_fetch_arxiv_papers_error(MockClient): mock_client_instance = MockClient.return_value mock_client_instance.results.side_effect = Exception("General Error") with pytest.raises(Exception, match="General Error"): - fetch_arxiv_papers('cs.AI', date.today(), max_results=10) + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=10) -@patch('paperweight.scraper.arxiv.Client') +@patch("paperweight.scraper.arxiv.Client") def test_fetch_arxiv_papers_max_results(MockClient): mock_client_instance = MockClient.return_value result1 = MagicMock() - result1.title = 'Test Paper 1' - result1.entry_id = 'http://arxiv.org/abs/2401.12345' + result1.title = "Test Paper 1" + result1.entry_id = "http://arxiv.org/abs/2401.12345" result1.published = datetime(2024, 1, 15) - result1.summary = 'Summary 1' + result1.summary = "Summary 1" result2 = MagicMock() - result2.title = 'Test Paper 2' - result2.entry_id = 'http://arxiv.org/abs/2401.67890' + result2.title = "Test Paper 2" + result2.entry_id = "http://arxiv.org/abs/2401.67890" result2.published = datetime(2024, 1, 14) - result2.summary = 'Summary 2' + result2.summary = "Summary 2" result3 = MagicMock() - result3.title = 'Test Paper 3' - result3.entry_id = 
'http://arxiv.org/abs/2401.11111' + result3.title = "Test Paper 3" + result3.entry_id = "http://arxiv.org/abs/2401.11111" result3.published = datetime(2024, 1, 13) - result3.summary = 'Summary 3' + result3.summary = "Summary 3" - # We simulate the iterator returning these mock_client_instance.results.return_value = [result1, result2, result3] - start_date = datetime(2024, 1, 13).date() # Test with max_results=2 - # We need to reset the mock if we want to run multiple calls in one test safely regarding return values if they were stateful iterators, - # but here it returns a list which is iterable multiple times. - - papers = fetch_arxiv_papers('cs.AI', start_date, max_results=2) + papers = fetch_arxiv_papers(["cs.AI"], start_date, max_results=2) assert len(papers) == 2 - assert papers[0]['title'] == 'Test Paper 1' - assert papers[1]['title'] == 'Test Paper 2' + assert papers[0]["title"] == "Test Paper 1" + assert papers[1]["title"] == "Test Paper 2" # Test with max_results=None - papers = fetch_arxiv_papers('cs.AI', start_date, max_results=None) + papers = fetch_arxiv_papers(["cs.AI"], start_date, max_results=None) assert len(papers) == 3 - assert papers[2]['title'] == 'Test Paper 3' + assert papers[2]["title"] == "Test Paper 3" # Test with max_results=0 - papers = fetch_arxiv_papers('cs.AI', start_date, max_results=0) + papers = fetch_arxiv_papers(["cs.AI"], start_date, max_results=0) assert len(papers) == 3 - assert papers[2]['title'] == 'Test Paper 3' + assert papers[2]["title"] == "Test Paper 3" + + +# --------------------------------------------------------------------------- +# Batched OR query construction +# --------------------------------------------------------------------------- + + +@patch("paperweight.scraper.arxiv.Client") +@patch("paperweight.scraper.arxiv.Search") +def test_batched_or_query_construction(MockSearch, MockClient): + """Multiple categories are combined into a single OR query.""" + mock_client_instance = MockClient.return_value + 
mock_client_instance.results.return_value = [] + + fetch_arxiv_papers(["cs.AI", "cs.CL", "cs.LG"], date.today(), max_results=10) + + MockSearch.assert_called_once() + call_kwargs = MockSearch.call_args + assert call_kwargs[1]["query"] == "cat:cs.AI OR cat:cs.CL OR cat:cs.LG" + + +@patch("paperweight.scraper.arxiv.Client") +@patch("paperweight.scraper.arxiv.Search") +def test_single_category_query(MockSearch, MockClient): + """A single category produces a simple cat: query (no OR).""" + mock_client_instance = MockClient.return_value + mock_client_instance.results.return_value = [] + + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=10) + + MockSearch.assert_called_once() + call_kwargs = MockSearch.call_args + assert call_kwargs[1]["query"] == "cat:cs.AI" + + +# --------------------------------------------------------------------------- +# page_size matching +# --------------------------------------------------------------------------- + + +@patch("paperweight.scraper.arxiv.Client") +def test_page_size_matches_max_results(MockClient): + """page_size should equal max_results when max_results < 100.""" + mock_client_instance = MockClient.return_value + mock_client_instance.results.return_value = [] + + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=15) + + MockClient.assert_called_once_with( + page_size=15, + delay_seconds=3.0, + num_retries=3, + ) + + +@patch("paperweight.scraper.arxiv.Client") +def test_page_size_caps_at_100(MockClient): + """page_size should cap at 100 even when max_results > 100.""" + mock_client_instance = MockClient.return_value + mock_client_instance.results.return_value = [] + + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=200) + + MockClient.assert_called_once_with( + page_size=100, + delay_seconds=3.0, + num_retries=3, + ) + + +@patch("paperweight.scraper.arxiv.Client") +def test_page_size_defaults_to_100_when_no_limit(MockClient): + """page_size should be 100 when max_results is None.""" + mock_client_instance = 
MockClient.return_value + mock_client_instance.results.return_value = [] + + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=None) + + MockClient.assert_called_once_with( + page_size=100, + delay_seconds=3.0, + num_retries=3, + ) + + +# --------------------------------------------------------------------------- +# Single-call fetch_recent_papers +# --------------------------------------------------------------------------- + + +def test_fetch_recent_papers_single_api_call(monkeypatch): + """fetch_recent_papers should call fetch_arxiv_papers exactly once.""" + config = { + "arxiv": {"categories": ["cs.AI", "cs.CL", "cs.LG"], "max_results": 10}, + } + call_count = {"n": 0} + + def fake_fetch(categories, start_date, max_results=None): + call_count["n"] += 1 + assert categories == ["cs.AI", "cs.CL", "cs.LG"] + return [] + + monkeypatch.setattr("paperweight.scraper.fetch_arxiv_papers", fake_fetch) + fetch_from = __import__("paperweight.scraper", fromlist=["fetch_recent_papers"]) + fetch_from.fetch_recent_papers(config, start_days=1) + assert call_count["n"] == 1, "Expected exactly 1 API call for batched categories" + + +# --------------------------------------------------------------------------- +# Rate-limit / retry +# --------------------------------------------------------------------------- + + +def test_rate_limit_error_friendly_message(): + """ArxivRateLimitError should have a user-friendly message.""" + err = ArxivRateLimitError() + assert "429" in str(err) + assert "rate-limited" in str(err).lower() + assert "wait" in str(err).lower() + + +@patch("paperweight.scraper.arxiv.Client") +def test_429_raises_rate_limit_error(MockClient): + """HTTP 429 from arXiv should raise ArxivRateLimitError, not raw HTTPError.""" + import arxiv as _arxiv + + mock_client_instance = MockClient.return_value + http_err = _arxiv.HTTPError("http://example.com", 0, 429) + mock_client_instance.results.side_effect = http_err + + with pytest.raises(ArxivRateLimitError, 
match="429"): + fetch_arxiv_papers(["cs.AI"], date.today(), max_results=10) + + +# --------------------------------------------------------------------------- +# Existing tests (unchanged logic, updated signatures) +# --------------------------------------------------------------------------- def test_extract_text_from_latex_source(): """Extract text from LaTeX source content.""" - latex_content = b''' + latex_content = b""" \\documentclass{article} \\begin{document} This is a test LaTeX document. \\end{document} - ''' - latex_text = extract_text_from_source(latex_content, 'source') + """ + latex_text = extract_text_from_source(latex_content, "source") assert "This is a test LaTeX document." in latex_text + def test_extract_text_from_source_invalid_type(): with pytest.raises(ValueError, match="Invalid source type: invalid_type"): - extract_text_from_source(b'content', 'invalid_type') + extract_text_from_source(b"content", "invalid_type") + def test_get_recent_papers_db_unreachable(): config = { - 'db': { - 'enabled': True, - 'host': 'localhost', - 'port': 5432, - 'database': 'paperweight', - 'user': 'paperweight', - 'password': 'pass', - 'sslmode': 'prefer' + "db": { + "enabled": True, + "host": "localhost", + "port": 5432, + "database": "paperweight", + "user": "paperweight", + "password": "pass", + "sslmode": "prefer", } } - with patch('paperweight.scraper.connect_db', side_effect=Exception("boom")): - with pytest.raises(DatabaseConnectionError, match="Database enabled but unreachable"): + with patch("paperweight.scraper.connect_db", side_effect=Exception("boom")): + with pytest.raises( + DatabaseConnectionError, match="Database enabled but unreachable" + ): get_recent_papers(config) @@ -171,7 +313,9 @@ def test_get_recent_papers_without_content(monkeypatch): monkeypatch.setattr("paperweight.scraper.get_last_processed_date", lambda: None) monkeypatch.setattr("paperweight.scraper.save_last_processed_date", lambda _d: None) - 
monkeypatch.setattr("paperweight.scraper.fetch_recent_papers", lambda _c, _d: fake_papers) + monkeypatch.setattr( + "paperweight.scraper.fetch_recent_papers", lambda _c, _d: fake_papers + ) fetch_content = MagicMock() monkeypatch.setattr("paperweight.scraper.fetch_paper_contents", fetch_content) From 7d0327042b2f636aa330e429144b10695e0b294b Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Tue, 17 Feb 2026 00:11:37 -0800 Subject: [PATCH 2/3] feat: RSS-first fetching and decouple --force-refresh from 7-day window Daily lookups now try arXiv RSS feeds before falling back to the API, eliminating rate-limit exposure for the most common usage pattern. --force-refresh no longer forces a 7-day backfill; it fetches today's papers via the fast RSS path. The 7-day bootstrap is reserved for true first runs (no prior watermark). Quick Start updated to recommend plain `paperweight run`. --- CHANGELOG.md | 5 +- README.md | 10 +- src/paperweight/scraper.py | 182 +++++++++++++++++++-- tests/test_scraper.py | 316 ++++++++++++++++++++++++++++++++++++- uv.lock | 2 +- 5 files changed, 498 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cec423..2ab44f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - arXiv categories are now batched into a single OR query (`cat:cs.AI OR cat:cs.CL`), reducing N parallel API calls to 1 - `page_size` now matches `min(max_results, 100)` instead of always requesting 100 results - `ThreadPoolExecutor` removed from `fetch_recent_papers()` since only one API call is made +- `--force-refresh` now fetches today's papers (1-day window) instead of a full 7-day backfill; the 7-day bootstrap is reserved for first runs only +- Quick Start no longer recommends `--force-refresh` for the initial run since `paperweight run` already backfills automatically ### Added +- RSS feed fetcher (`fetch_rss_papers`) for daily lookups — no rate limits, 
sub-second metadata fetch +- RSS-first routing in `fetch_recent_papers`: daily runs try RSS before falling back to the arXiv API - Exponential backoff (via `tenacity`) on `arxiv.HTTPError` with waits of 5 → 15 → 45 → 90 s - `ArxivRateLimitError` exception with user-friendly message for HTTP 429 responses -- 8 new unit tests covering batched queries, page_size matching, and retry behavior ## [0.3.0] - 2026-02-15 diff --git a/README.md b/README.md index eb87980..94b9fdb 100644 --- a/README.md +++ b/README.md @@ -47,11 +47,15 @@ source .venv/bin/activate ## Quick start (works without API keys) ```bash -paperweight init # create config.yaml with safe defaults -paperweight doctor # check your setup for issues -paperweight run --force-refresh # fetch papers and produce a digest +paperweight init # create config.yaml with safe defaults +paperweight doctor # check your setup for issues +paperweight run # fetch papers and produce a digest ``` +The first run automatically backfills a week of papers. After that, the same +`paperweight run` fetches only what's new. Use `--force-refresh` to re-fetch +if you've already run today. + Notes: - Default analyzer mode is `abstract` (no API key required). 
diff --git a/src/paperweight/scraper.py b/src/paperweight/scraper.py index 2a4d25b..6e8fc02 100644 --- a/src/paperweight/scraper.py +++ b/src/paperweight/scraper.py @@ -7,17 +7,21 @@ import gzip import hashlib +import html import io import json import logging import os +import re import tarfile from concurrent.futures import ( ThreadPoolExecutor, as_completed, ) # ThreadPoolExecutor still used by fetch_paper_contents from datetime import date, datetime, timedelta +from email.utils import parsedate_to_datetime from typing import Any, Dict, List, Optional +from xml.etree import ElementTree as ET import arxiv import arxiv as _arxiv_module @@ -40,6 +44,8 @@ logger = logging.getLogger(__name__) +_RSS_BASE_URL = "https://rss.arxiv.org/rss/" + class ArxivRateLimitError(RuntimeError): """Raised when arXiv returns HTTP 429 (Too Many Requests). @@ -183,11 +189,150 @@ def _fetch_with_backoff( return papers +def _strip_html_tags(text): + """Remove HTML tags from a string.""" + return re.sub(r"<[^>]+>", "", text) + + +def _parse_rss_description(description): + """Extract abstract text from an RSS item's description field. + + The description typically contains HTML with an "Abstract:" marker. + """ + if not description: + return "" + text = _strip_html_tags(html.unescape(description)) + marker = "Abstract:" + idx = text.find(marker) + if idx != -1: + return text[idx + len(marker) :].strip() + return text.strip() + + +def _parse_rss_item(item, ns): + """Parse a single RSS element into a paper dict. + + Returns None for announce_type == "replace" (updates to old papers). 
+ """ + # Check for replace announcements + announce_type_el = item.find("arxiv:announce_type", ns) + if announce_type_el is not None and announce_type_el.text == "replace": + return None + + title = item.findtext("title", default="", namespaces=ns).strip() + link = item.findtext("link", default="", namespaces=ns).strip() + description = item.findtext("description", default="", namespaces=ns) + abstract = _parse_rss_description(description) + + # Parse date + pub_date_text = item.findtext("pubDate", default="", namespaces=ns) + if pub_date_text: + try: + paper_date = parsedate_to_datetime(pub_date_text).date() + except (ValueError, TypeError): + paper_date = datetime.now().date() + else: + paper_date = datetime.now().date() + + # Parse authors from dc:creator + creator_el = item.find("dc:creator", ns) + if creator_el is not None and creator_el.text: + # dc:creator contains a comma+space-separated list like "Author One, Author Two" + # but can also use <a href=...> tags — strip those + raw_authors = _strip_html_tags(html.unescape(creator_el.text)) + authors = [a.strip() for a in raw_authors.split(",") if a.strip()] + else: + authors = [] + + # Collect categories + categories = [] + for cat_el in item.findall("category", ns): + if cat_el.text: + categories.append(cat_el.text.strip()) + + # Extract arXiv ID from link + arxiv_id, _ = split_arxiv_id(link) + pdf_url = f"https://arxiv.org/pdf/{arxiv_id}" + + return { + "title": title, + "link": link, + "date": paper_date, + "abstract": abstract, + "authors": authors, + "categories": categories, + "pdf_url": pdf_url, + "id": arxiv_id, + } + + +@retry( + stop=stop_after_attempt(2), + wait=wait_exponential(multiplier=2, min=2, max=8), + retry=retry_if_exception_type((requests.ConnectionError, requests.Timeout)), + reraise=True, +) +def _fetch_single_rss_feed(url): + """Fetch a single RSS feed URL with retry on connection errors.""" + response = requests.get(url, timeout=15) + response.raise_for_status() + return 
response.text + + +def fetch_rss_papers(categories): + """Fetch today's papers from arXiv RSS feeds. + + Iterates over categories, fetches each feed, parses items, + and deduplicates by arXiv ID. Per-category errors are logged and skipped. + + Args: + categories: List of arXiv category strings (e.g., ['cs.AI', 'cs.CL']). + + Returns: + List of paper dicts (same schema as fetch_arxiv_papers output). + """ + seen_ids = set() + papers = [] + + for category in categories: + url = f"{_RSS_BASE_URL}{category}" + try: + xml_text = _fetch_single_rss_feed(url) + root = ET.fromstring(xml_text) + except Exception: + logger.warning("RSS fetch failed for category %s, skipping", category) + continue + + # Build namespace map from the root element + ns = {} + for prefix, uri in [ + ("dc", "http://purl.org/dc/elements/1.1/"), + ("arxiv", "http://arxiv.org/schemas/atom"), + ]: + ns[prefix] = uri + + channel = root.find("channel") + if channel is None: + continue + + for item in channel.findall("item"): + paper = _parse_rss_item(item, ns) + if paper is None: + continue + if paper["id"] not in seen_ids: + seen_ids.add(paper["id"]) + papers.append(paper) + + logger.info("RSS fetched %d unique papers from %d categories", len(papers), len(categories)) + return papers + + def fetch_recent_papers(config, start_days=1): """Fetch papers published within the last specified number of days. - All configured categories are combined into a single arXiv API query - using OR syntax to minimize HTTP requests. + For daily lookups (start_days <= 1), RSS feeds are tried first since they + have no rate limits. Falls back to the arXiv API on failure or empty results. + For multi-day ranges (start_days > 1), the arXiv API is used directly. Args: config: Application configuration dictionary. 
@@ -202,13 +347,26 @@ def fetch_recent_papers(config, start_days=1): start_date = end_date - timedelta(days=start_days) logger.info("Fetching papers from %s to %s", start_date, end_date) - logger.info("Categories: %s (single batched query)", categories) - papers = fetch_arxiv_papers( - categories, - start_date, - max_results=max_results if max_results > 0 else None, - ) + papers = [] + + # RSS-first path for daily lookups (no rate limits) + if start_days <= 1: + logger.info("Categories: %s (RSS feeds)", categories) + try: + papers = fetch_rss_papers(categories) + except Exception: + logger.warning("RSS fetch failed, falling back to arXiv API") + papers = [] + + # arXiv API fallback (or primary path for multi-day ranges) + if not papers: + logger.info("Categories: %s (arXiv API query)", categories) + papers = fetch_arxiv_papers( + categories, + start_date, + max_results=max_results if max_results > 0 else None, + ) # Deduplicate by arXiv ID (papers can appear in multiple categories) seen_ids: set = set() @@ -536,10 +694,14 @@ def get_recent_papers(config, force_refresh=False, include_content=True): # noq logger.info("Loaded %d papers from metadata cache", len(recent_papers)) if recent_papers is None: - if last_processed_date is None or force_refresh: - # If never run before, fetch papers from the last 7 days + if last_processed_date is None: + # Bootstrap: first run ever — backfill a week of papers days = 7 logger.info("First run detected. 
Fetching papers from the last 7 days.") + elif force_refresh: + # User wants the freshest data; 1-day window enables the fast RSS path + days = 1 + logger.info("Force refresh: fetching today's papers.") else: days = (current_date - last_processed_date).days if days == 0: diff --git a/tests/test_scraper.py b/tests/test_scraper.py index 0cabec8..b05db35 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -1,14 +1,19 @@ from datetime import date, datetime from unittest.mock import MagicMock, patch +from xml.etree import ElementTree as ET import pytest from paperweight.db import DatabaseConnectionError from paperweight.scraper import ( ArxivRateLimitError, + _parse_rss_description, + _parse_rss_item, _write_metadata_cache, extract_text_from_source, fetch_arxiv_papers, + fetch_recent_papers, + fetch_rss_papers, get_recent_papers, hydrate_papers_with_content, ) @@ -188,7 +193,7 @@ def test_page_size_defaults_to_100_when_no_limit(MockClient): def test_fetch_recent_papers_single_api_call(monkeypatch): - """fetch_recent_papers should call fetch_arxiv_papers exactly once.""" + """fetch_recent_papers should call fetch_arxiv_papers exactly once (multi-day path).""" config = { "arxiv": {"categories": ["cs.AI", "cs.CL", "cs.LG"], "max_results": 10}, } @@ -201,7 +206,7 @@ def fake_fetch(categories, start_date, max_results=None): monkeypatch.setattr("paperweight.scraper.fetch_arxiv_papers", fake_fetch) fetch_from = __import__("paperweight.scraper", fromlist=["fetch_recent_papers"]) - fetch_from.fetch_recent_papers(config, start_days=1) + fetch_from.fetch_recent_papers(config, start_days=3) assert call_count["n"] == 1, "Expected exactly 1 API call for batched categories" @@ -366,3 +371,310 @@ def fake_fetch(_config, _days): assert papers[0]["title"] == "Cached Paper" assert not fetch_called["called"] fetch_content.assert_not_called() + + +# --------------------------------------------------------------------------- +# RSS description parsing +# 
---------------------------------------------------------------------------
+
+_RSS_NS = {
+    "dc": "http://purl.org/dc/elements/1.1/",
+    "arxiv": "http://arxiv.org/schemas/atom",
+}
+
+
+def test_parse_rss_description_extracts_abstract():
+    desc = "<p>Abstract: This is the abstract text.</p>"
+    assert _parse_rss_description(desc) == "This is the abstract text."
+
+
+def test_parse_rss_description_handles_html_entities():
+    desc = "Abstract: x &lt; y &amp; z"
+    assert _parse_rss_description(desc) == "x < y & z"
+
+
+def test_parse_rss_description_empty_input():
+    assert _parse_rss_description("") == ""
+    assert _parse_rss_description(None) == ""
+
+
+def test_parse_rss_description_no_marker_falls_back():
+    desc = "Just some text without the marker."
+    assert _parse_rss_description(desc) == "Just some text without the marker."
+
+
+# ---------------------------------------------------------------------------
+# RSS item parsing
+# ---------------------------------------------------------------------------
+
+
+def _make_item_xml(
+    title="Test Paper",
+    link="https://arxiv.org/abs/2401.12345",
+    description="Abstract: Some abstract",
+    pub_date="Mon, 15 Jan 2024 00:00:00 GMT",
+    creator="Author One, Author Two",
+    categories=("cs.AI",),
+    announce_type="new",
+):
+    """Build a minimal RSS <item> element for testing."""
+    parts = [
+        f"<item>",
+        f"<title>{title}</title>",
+        f"<link>{link}</link>",
+        f"<description>{description}</description>",
+    ]
+    if pub_date is not None:
+        parts.append(f"<pubDate>{pub_date}</pubDate>")
+    if creator is not None:
+        parts.append(
+            f'<dc:creator xmlns:dc="http://purl.org/dc/elements/1.1/">{creator}</dc:creator>'
+        )
+    for cat in categories:
+        parts.append(f"<category>{cat}</category>")
+    if announce_type is not None:
+        parts.append(
+            f'<arxiv:announce_type xmlns:arxiv="http://arxiv.org/schemas/atom">{announce_type}</arxiv:announce_type>'
+        )
+    parts.append("</item>")
+    return ET.fromstring("".join(parts))
+
+
+def test_parse_rss_item_complete():
+    item = _make_item_xml()
+    paper = _parse_rss_item(item, _RSS_NS)
+    assert paper is not None
+    assert paper["title"] == "Test Paper"
+    assert paper["id"] == "2401.12345"
+    assert paper["abstract"] == "Some abstract"
+    assert paper["authors"] == ["Author One", "Author Two"]
+    assert paper["categories"] == ["cs.AI"]
+    assert paper["pdf_url"] == "https://arxiv.org/pdf/2401.12345"
+    assert paper["date"] == date(2024, 1, 15)
+
+
+def test_parse_rss_item_replace_returns_none():
+    item = _make_item_xml(announce_type="replace")
+    assert _parse_rss_item(item, _RSS_NS) is None
+
+
+def test_parse_rss_item_missing_pubdate():
+    item = _make_item_xml(pub_date=None)
+    paper = _parse_rss_item(item, _RSS_NS)
+    assert paper is not None
+    assert paper["date"] == datetime.now().date()
+
+
+def test_parse_rss_item_missing_creator():
+    item = _make_item_xml(creator=None)
+    paper = _parse_rss_item(item, _RSS_NS)
+    assert paper is not None
+    assert paper["authors"] == []
+
+
+def test_parse_rss_item_multiple_categories():
+    item = _make_item_xml(categories=("cs.AI", "cs.LG", "stat.ML"))
+    paper = _parse_rss_item(item, _RSS_NS)
+    assert paper["categories"] == ["cs.AI", "cs.LG", "stat.ML"]
+
+
+# ---------------------------------------------------------------------------
+# RSS fetch integration (mocked HTTP)
+# ---------------------------------------------------------------------------
+
+
+def _wrap_rss_feed(items_xml):
+    """Wrap <item> XML strings in a minimal RSS feed."""
+    return (
+        '<?xml version="1.0" encoding="UTF-8"?>'
+        '<rss version="2.0" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:arxiv="http://arxiv.org/schemas/atom">'
+        "<channel>"
+        f"{''.join(items_xml)}"
+        "</channel></rss>"
+    )
+
+
+_ITEM_A = (
+    "<item>"
+    "<title>Paper A</title>"
+    "<link>https://arxiv.org/abs/2401.00001</link>"
+    "<description>Abstract: Abstract A</description>"
+    "<pubDate>Mon, 15 Jan 2024 00:00:00 GMT</pubDate>"
+    '<dc:creator xmlns:dc="http://purl.org/dc/elements/1.1/">Auth A</dc:creator>'
+    "<category>cs.AI</category>"
+    '<arxiv:announce_type xmlns:arxiv="http://arxiv.org/schemas/atom">new</arxiv:announce_type>'
+    "</item>"
+)
+
+_ITEM_B = (
+    "<item>"
+    "<title>Paper B</title>"
+    "<link>https://arxiv.org/abs/2401.00002</link>"
+    "<description>Abstract: Abstract B</description>"
+    "<pubDate>Mon, 15 Jan 2024 00:00:00 GMT</pubDate>"
+    '<dc:creator xmlns:dc="http://purl.org/dc/elements/1.1/">Auth B</dc:creator>'
+    "<category>cs.CL</category>"
+    '<arxiv:announce_type xmlns:arxiv="http://arxiv.org/schemas/atom">new</arxiv:announce_type>'
+    "</item>"
+)
+
+
+@patch("paperweight.scraper._fetch_single_rss_feed")
+def test_fetch_rss_single_category(mock_fetch):
+    mock_fetch.return_value = _wrap_rss_feed([_ITEM_A])
+    papers = fetch_rss_papers(["cs.AI"])
+    assert len(papers) == 1
+    assert papers[0]["title"] == "Paper A"
+    assert papers[0]["id"] == "2401.00001"
+
+
+@patch("paperweight.scraper._fetch_single_rss_feed")
+def test_fetch_rss_deduplicates_across_categories(mock_fetch):
+    """Same paper in two category feeds should appear only once."""
+    mock_fetch.return_value = _wrap_rss_feed([_ITEM_A])
+    papers = fetch_rss_papers(["cs.AI", "cs.LG"])
+    assert len(papers) == 1
+
+
+@patch("paperweight.scraper._fetch_single_rss_feed")
+def test_fetch_rss_one_category_fails(mock_fetch):
+    """If one category feed fails, other categories still return papers."""
+    def side_effect(url):
+        if "cs.AI" in url:
+            raise ConnectionError("boom")
+        return _wrap_rss_feed([_ITEM_B])
+
+    mock_fetch.side_effect = side_effect
+    papers = fetch_rss_papers(["cs.AI", "cs.CL"])
+    assert len(papers) == 1
+    assert papers[0]["title"] == "Paper B"
+
+
+@patch("paperweight.scraper._fetch_single_rss_feed")
+def test_fetch_rss_all_categories_fail(mock_fetch):
+    """If all feeds fail, return empty list (no exception)."""
+    mock_fetch.side_effect = ConnectionError("boom")
+    papers = fetch_rss_papers(["cs.AI", "cs.CL"])
+    assert papers == []
+
+
+# ---------------------------------------------------------------------------
+# Routing: fetch_recent_papers RSS vs API
+# ---------------------------------------------------------------------------
+
+
+@patch("paperweight.scraper.fetch_arxiv_papers")
+@patch("paperweight.scraper.fetch_rss_papers")
+def test_routing_daily_uses_rss(mock_rss, mock_api):
+    """start_days=1 → RSS called, API not called."""
+    mock_rss.return_value = [
+        {
+            "title": "RSS Paper",
+            "link": "https://arxiv.org/abs/2401.00001",
+            "date": date.today(),
+            "abstract": "Abstract",
"authors": [], + "categories": ["cs.AI"], + "pdf_url": "https://arxiv.org/pdf/2401.00001", + "id": "2401.00001", + } + ] + config = {"arxiv": {"categories": ["cs.AI"], "max_results": 10}} + papers = fetch_recent_papers(config, start_days=3) + assert len(papers) == 1 + mock_rss.assert_not_called() + mock_api.assert_called_once() + + +@patch("paperweight.scraper.fetch_arxiv_papers") +@patch("paperweight.scraper.fetch_rss_papers") +def test_routing_rss_fails_falls_back_to_api(mock_rss, mock_api): + """RSS exception → falls back to API.""" + mock_rss.side_effect = Exception("RSS broken") + mock_api.return_value = [ + { + "title": "API Paper", + "link": "https://arxiv.org/abs/2401.00001", + "date": date.today(), + "abstract": "Abstract", + "authors": [], + "categories": ["cs.AI"], + "pdf_url": "https://arxiv.org/pdf/2401.00001", + "id": "2401.00001", + } + ] + config = {"arxiv": {"categories": ["cs.AI"], "max_results": 10}} + papers = fetch_recent_papers(config, start_days=1) + assert len(papers) == 1 + assert papers[0]["title"] == "API Paper" + mock_api.assert_called_once() + + +@patch("paperweight.scraper.fetch_arxiv_papers") +@patch("paperweight.scraper.fetch_rss_papers") +def test_routing_rss_empty_falls_back_to_api(mock_rss, mock_api): + """RSS returns empty → falls back to API.""" + mock_rss.return_value = [] + mock_api.return_value = [ + { + "title": "API Paper", + "link": "https://arxiv.org/abs/2401.00001", + "date": date.today(), + "abstract": "Abstract", + "authors": [], + "categories": ["cs.AI"], + "pdf_url": "https://arxiv.org/pdf/2401.00001", + "id": "2401.00001", + } + ] + config = {"arxiv": {"categories": ["cs.AI"], "max_results": 10}} + papers = fetch_recent_papers(config, start_days=1) + assert len(papers) == 1 + assert papers[0]["title"] == "API Paper" + mock_api.assert_called_once() + + +@patch("paperweight.scraper.fetch_arxiv_papers") +@patch("paperweight.scraper.fetch_rss_papers") +def test_routing_max_results_applied_to_rss(mock_rss, mock_api): + 
"""max_results cap is applied to RSS results.""" + mock_rss.return_value = [ + { + "title": f"Paper {i}", + "link": f"https://arxiv.org/abs/2401.{i:05d}", + "date": date.today(), + "abstract": "Abstract", + "authors": [], + "categories": ["cs.AI"], + "pdf_url": f"https://arxiv.org/pdf/2401.{i:05d}", + "id": f"2401.{i:05d}", + } + for i in range(5) + ] + config = {"arxiv": {"categories": ["cs.AI"], "max_results": 2}} + papers = fetch_recent_papers(config, start_days=1) + assert len(papers) == 2 + mock_api.assert_not_called() diff --git a/uv.lock b/uv.lock index 973ae9f..4819365 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ [[package]] name = "academic-paperweight" -version = "0.3.0" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "arxiv" }, From 7e52957357f098f8b65e81e9a7092962a97dd31a Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Tue, 17 Feb 2026 00:16:12 -0800 Subject: [PATCH 3/3] style: fix linting and formatting issues --- src/paperweight/analyzer.py | 39 ++++-- src/paperweight/db.py | 4 +- src/paperweight/notifier.py | 12 +- src/paperweight/scraper.py | 4 +- src/paperweight/utils.py | 37 +++-- tests/api/test_database.py | 1 + tests/conftest.py | 8 +- tests/test_analyzer.py | 7 +- tests/test_cli_integration.py | 6 +- tests/test_config.py | 254 +++++++++++++++++++--------------- tests/test_contracts.py | 4 +- tests/test_notifier.py | 30 ++-- tests/test_pipeline.py | 119 ++++++++-------- tests/test_processor.py | 79 ++++++----- tests/test_scraper.py | 4 +- 15 files changed, 355 insertions(+), 253 deletions(-) diff --git a/src/paperweight/analyzer.py b/src/paperweight/analyzer.py index 3e29275..31a4100 100644 --- a/src/paperweight/analyzer.py +++ b/src/paperweight/analyzer.py @@ -66,7 +66,9 @@ def get_abstracts(processed_papers, config, *, summary_concurrency=None): if analysis_type == "abstract": return [paper["abstract"] for paper in processed_papers] if analysis_type == "summary": - return 
summarize_papers(processed_papers, config, summary_concurrency=summary_concurrency) + return summarize_papers( + processed_papers, config, summary_concurrency=summary_concurrency + ) raise ValueError(f"Unknown analysis type: {analysis_type}") @@ -115,9 +117,7 @@ def _resolve_triage_model_config( analyzer_cfg = full_config.get("analyzer", {}) provider = ( - triage_cfg.get("llm_provider") - or analyzer_cfg.get("llm_provider") - or "openai" + triage_cfg.get("llm_provider") or analyzer_cfg.get("llm_provider") or "openai" ).lower() model = triage_cfg.get("model") or _default_model_for_provider(provider) api_key = ( @@ -206,7 +206,9 @@ async def _triage_one_paper_async(prompt, pollux_config, *, min_score): return _parse_triage_decision(answer, min_score=min_score) -async def _run_triage_async(prompts, pollux_config, *, min_score, concurrency=TRIAGE_CONCURRENCY): +async def _run_triage_async( + prompts, pollux_config, *, min_score, concurrency=TRIAGE_CONCURRENCY +): """Run triage prompts concurrently with a semaphore, returning decisions in order.""" semaphore = asyncio.Semaphore(concurrency) total = len(prompts) @@ -277,11 +279,18 @@ def triage_papers( prompts = [_build_triage_prompt(paper, profile_text) for paper in papers] - triage_concurrency = full_config.get("concurrency", {}).get("triage", TRIAGE_CONCURRENCY) + triage_concurrency = full_config.get("concurrency", {}).get( + "triage", TRIAGE_CONCURRENCY + ) try: decisions = asyncio.run( - _run_triage_async(prompts, pollux_config, min_score=min_score, concurrency=triage_concurrency) + _run_triage_async( + prompts, + pollux_config, + min_score=min_score, + concurrency=triage_concurrency, + ) ) except Exception as exc: logger.warning( @@ -357,9 +366,13 @@ async def _summarize_one_paper_async( return str(response) -def _resolve_summary_model_config(config: Dict[str, Any]) -> tuple[ProviderName, str, str]: +def _resolve_summary_model_config( + config: Dict[str, Any], +) -> tuple[ProviderName, str, str]: llm_provider = 
(config.get("llm_provider") or "openai").lower().strip() - api_key = config.get("api_key") or os.getenv(f"{llm_provider.upper()}_API_KEY") or "" + api_key = ( + config.get("api_key") or os.getenv(f"{llm_provider.upper()}_API_KEY") or "" + ) if llm_provider not in ("openai", "gemini") or not api_key: raise ValueError( "Summary analyzer requires a valid llm_provider (openai|gemini) and api_key." @@ -385,7 +398,9 @@ def summarize_papers( # noqa: C901 provider, model_name, api_key = _resolve_summary_model_config(config) max_input_tokens = _int_setting(config.get("max_input_tokens"), 7000, minimum=500) max_input_chars = _int_setting(config.get("max_input_chars"), 20_000, minimum=1000) - effective_concurrency = summary_concurrency if summary_concurrency is not None else SUMMARY_CONCURRENCY + effective_concurrency = ( + summary_concurrency if summary_concurrency is not None else SUMMARY_CONCURRENCY + ) pollux_config = Config( provider=provider, @@ -399,7 +414,9 @@ def summarize_papers( # noqa: C901 ), ) - async def _run_summary_batch() -> tuple[List[str | None], List[tuple[int, BaseException]]]: + async def _run_summary_batch() -> ( + tuple[List[str | None], List[tuple[int, BaseException]]] + ): semaphore = asyncio.Semaphore(effective_concurrency) results: List[str | None] = [None] * len(papers) failures: List[tuple[int, BaseException]] = [] diff --git a/src/paperweight/db.py b/src/paperweight/db.py index 3fd8ea2..924d723 100644 --- a/src/paperweight/db.py +++ b/src/paperweight/db.py @@ -14,9 +14,7 @@ def is_db_enabled(config: Dict[str, Any]) -> bool: @contextmanager -def connect_db( - db_config: Dict[str, Any], autocommit: bool = False -) -> Generator: +def connect_db(db_config: Dict[str, Any], autocommit: bool = False) -> Generator: """Create a database connection. 
Args: diff --git a/src/paperweight/notifier.py b/src/paperweight/notifier.py index fc62f32..1cc5841 100644 --- a/src/paperweight/notifier.py +++ b/src/paperweight/notifier.py @@ -124,7 +124,9 @@ def render_atom_feed( ET.SubElement(entry, f"{{{ns}}}category", {"term": cat}) ET.SubElement(entry, f"{{{ns}}}summary").text = summary - ET.SubElement(entry, f"{{{ns}}}content", {"type": "text"}).text = ( + ET.SubElement( + entry, f"{{{ns}}}content", {"type": "text"} + ).text = ( f"Score: {score:.2f}\nWhy: {rationale}\nLink: {link}\nSummary: {summary}" ) @@ -206,7 +208,9 @@ def send_email_notification(subject, body, config): if use_auth and from_password: server.login(from_email, from_password) elif use_auth and not from_password: - logger.warning("SMTP auth enabled but no password provided; skipping login.") + logger.warning( + "SMTP auth enabled but no password provided; skipping login." + ) text = msg.as_string() server.sendmail(from_email, to_email, text) server.quit() @@ -234,6 +238,8 @@ def compile_and_send_notifications(papers, config): sort_order = config.get("email", {}).get("sort_order", "relevance") papers = _sort_papers(papers, sort_order) subject = "New Papers from ArXiv" - body = render_text_digest(papers, sort_order=sort_order, heading="New Papers from ArXiv") + body = render_text_digest( + papers, sort_order=sort_order, heading="New Papers from ArXiv" + ) success = send_email_notification(subject, body, config) return success diff --git a/src/paperweight/scraper.py b/src/paperweight/scraper.py index 6e8fc02..92b8d1e 100644 --- a/src/paperweight/scraper.py +++ b/src/paperweight/scraper.py @@ -323,7 +323,9 @@ def fetch_rss_papers(categories): seen_ids.add(paper["id"]) papers.append(paper) - logger.info("RSS fetched %d unique papers from %d categories", len(papers), len(categories)) + logger.info( + "RSS fetched %d unique papers from %d categories", len(papers), len(categories) + ) return papers diff --git a/src/paperweight/utils.py b/src/paperweight/utils.py 
index bc02dee..485dca4 100644 --- a/src/paperweight/utils.py +++ b/src/paperweight/utils.py @@ -35,10 +35,18 @@ "important_words_weight": 0.5, "min_score": 3, }, - "analyzer": {"type": "abstract", "max_input_tokens": 7000, "max_input_chars": 20000}, + "analyzer": { + "type": "abstract", + "max_input_tokens": 7000, + "max_input_chars": 20000, + }, "triage": {"enabled": False}, "logging": {"level": "INFO"}, - "metadata_cache": {"enabled": True, "path": ".paperweight_cache.json", "ttl_hours": 4}, + "metadata_cache": { + "enabled": True, + "path": ".paperweight_cache.json", + "ttl_hours": 4, + }, "concurrency": {"content_fetch": 6, "triage": 3, "summary": 3}, } @@ -362,7 +370,14 @@ def _check_db_section(db): int(db["port"]) except (ValueError, TypeError) as e: raise ValueError("'port' in 'db' section must be a valid integer") from e - valid_sslmodes = {"disable", "allow", "prefer", "require", "verify-ca", "verify-full"} + valid_sslmodes = { + "disable", + "allow", + "prefer", + "require", + "verify-ca", + "verify-full", + } if db["sslmode"] not in valid_sslmodes: raise ValueError( f"Invalid sslmode '{db['sslmode']}'. 
Must be one of: {', '.join(sorted(valid_sslmodes))}" @@ -404,7 +419,9 @@ def _check_concurrency_section(concurrency): except (TypeError, ValueError): raise ValueError(f"'{key}' in 'concurrency' must be a valid integer") if val < lo or val > hi: - raise ValueError(f"'{key}' in 'concurrency' must be between {lo} and {hi}") + raise ValueError( + f"'{key}' in 'concurrency' must be between {lo} and {hi}" + ) def _check_profiles_section(profiles): @@ -537,13 +554,15 @@ def split_arxiv_id(raw_id): raw = (raw_id or "").strip() if "/abs/" in raw: raw = raw.split("/abs/")[-1] - raw = raw.replace("http://arxiv.org/abs/", "").replace( - "https://arxiv.org/abs/", "" - ) + raw = raw.replace("http://arxiv.org/abs/", "").replace("https://arxiv.org/abs/", "") new_style = re.match(r"^(?P<id>\d{4}\.\d{4,5})(?P<version>v\d+)?$", raw) if new_style: - return new_style.group("id"), new_style.group("version") or DEFAULT_ARXIV_VERSION + return new_style.group("id"), new_style.group( + "version" + ) or DEFAULT_ARXIV_VERSION legacy_style = re.match(r"^(?P<id>[a-z\-]+/\d{7})(?P<version>v\d+)?$", raw) if legacy_style: - return legacy_style.group("id"), legacy_style.group("version") or DEFAULT_ARXIV_VERSION + return legacy_style.group("id"), legacy_style.group( + "version" + ) or DEFAULT_ARXIV_VERSION return raw, DEFAULT_ARXIV_VERSION diff --git a/tests/api/test_database.py b/tests/api/test_database.py index 5a1248a..efa1daf 100644 --- a/tests/api/test_database.py +++ b/tests/api/test_database.py @@ -20,6 +20,7 @@ def parse_database_url(url: str) -> dict: """Parse a PostgreSQL URL into a config dict.""" # postgresql://user:pass@host:port/database from urllib.parse import urlparse + parsed = urlparse(url) return { "host": parsed.hostname, diff --git a/tests/conftest.py b/tests/conftest.py index 72cdd52..297ee4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,7 +79,13 @@ def base_test_config(tmp_path: Path) -> Dict[str, Any]: "max_results": 10, }, "processor": { - "keywords": ["machine learning", "neural 
network", "deep learning", "ai", "transformer"], + "keywords": [ + "machine learning", + "neural network", + "deep learning", + "ai", + "transformer", + ], "exclusion_keywords": [], # Don't exclude anything for testing "important_words": ["novel", "state-of-the-art"], "title_keyword_weight": 3, diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 8f6dce5..bdffe17 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -64,7 +64,9 @@ def test_summarize_requires_valid_provider_and_key( with pytest.raises(ValueError, match="Summary analyzer requires"): summarize_paper(paper, config) - def test_summarize_falls_back_to_abstract_when_model_returns_no_answers(self, mocker): + def test_summarize_falls_back_to_abstract_when_model_returns_no_answers( + self, mocker + ): mocker.patch("pollux.run", new=AsyncMock(return_value={"answers": []})) paper = { @@ -149,7 +151,8 @@ def test_triage_falls_back_for_entire_batch_when_llm_errors(self, mocker): assert len(shortlisted) == 1 assert shortlisted[0]["title"] == "Transformers for Agents" assert all( - "heuristic fallback" in paper["triage_rationale"].lower() for paper in papers + "heuristic fallback" in paper["triage_rationale"].lower() + for paper in papers ) def test_triage_falls_back_without_api_key(self): diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py index a6a57c8..0788bb6 100644 --- a/tests/test_cli_integration.py +++ b/tests/test_cli_integration.py @@ -61,7 +61,8 @@ def _stub_scraper(monkeypatch): lambda _ids: [("2401.12345", b"stub-bytes", "pdf")], ) monkeypatch.setattr( - "paperweight.scraper.extract_text_from_source", lambda _c, _m: "transformer agent" + "paperweight.scraper.extract_text_from_source", + lambda _c, _m: "transformer agent", ) monkeypatch.setattr("paperweight.scraper.get_last_processed_date", lambda: None) monkeypatch.setattr("paperweight.scraper.save_last_processed_date", lambda _d: None) @@ -101,7 +102,8 @@ def _stub_scraper_two_papers(monkeypatch): ], 
) monkeypatch.setattr( - "paperweight.scraper.extract_text_from_source", lambda _c, _m: "transformer agent" + "paperweight.scraper.extract_text_from_source", + lambda _c, _m: "transformer agent", ) monkeypatch.setattr("paperweight.scraper.get_last_processed_date", lambda: None) monkeypatch.setattr("paperweight.scraper.save_last_processed_date", lambda _d: None) diff --git a/tests/test_config.py b/tests/test_config.py index 29ea3d1..891dcd4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,23 +26,24 @@ # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture def valid_base_config(): """Minimal valid configuration for testing.""" return { - 'arxiv': {'categories': ['cs.AI']}, - 'processor': {}, - 'analyzer': {'type': 'abstract'}, - 'notifier': { - 'email': { - 'to': 'test@example.com', - 'from': 'sender@example.com', - 'password': 'pass', - 'smtp_server': 'smtp.example.com', - 'smtp_port': 587, + "arxiv": {"categories": ["cs.AI"]}, + "processor": {}, + "analyzer": {"type": "abstract"}, + "notifier": { + "email": { + "to": "test@example.com", + "from": "sender@example.com", + "password": "pass", + "smtp_server": "smtp.example.com", + "smtp_port": 587, } }, - 'logging': {'level': 'INFO'}, + "logging": {"level": "INFO"}, } @@ -50,26 +51,26 @@ def valid_base_config(): def sample_config(): """Sample config for load_config tests.""" return { - 'arxiv': {'categories': ['cs.AI'], 'max_results': 50}, - 'processor': {'keywords': ['AI']}, - 'analyzer': {'type': 'summary', 'llm_provider': 'openai'}, - 'notifier': { - 'email': { - 'to': 'test@example.com', - 'from': 'sender@example.com', - 'password': 'pass', - 'smtp_server': 'smtp.example.com', - 'smtp_port': 587, + "arxiv": {"categories": ["cs.AI"], "max_results": 50}, + "processor": {"keywords": ["AI"]}, + "analyzer": {"type": "summary", "llm_provider": "openai"}, + "notifier": { + "email": { + "to": "test@example.com", + "from": "sender@example.com", + 
"password": "pass", + "smtp_server": "smtp.example.com", + "smtp_port": 587, } }, - 'logging': {'level': 'INFO'}, + "logging": {"level": "INFO"}, } @pytest.fixture def config_file(sample_config): """Write sample_config to a temp file for load_config tests.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml.dump(sample_config, f) yield f.name os.unlink(f.name) @@ -79,6 +80,7 @@ def config_file(sample_config): # Config Loading Tests # --------------------------------------------------------------------------- + class TestLoadConfig: """Tests for load_config function.""" @@ -91,60 +93,64 @@ def test_missing_config_file(self, tmp_path, monkeypatch): def test_invalid_yaml_syntax(self, tmp_path, monkeypatch): """YAMLError on malformed YAML.""" monkeypatch.chdir(tmp_path) - (tmp_path / 'config.yaml').write_text("invalid: yaml: syntax:") + (tmp_path / "config.yaml").write_text("invalid: yaml: syntax:") with pytest.raises(yaml.YAMLError): load_config() def test_load_config_basic(self, config_file): """Successfully load a valid config file.""" - with patch.dict(os.environ, {'OPENAI_API_KEY': 'dummy_key'}): + with patch.dict(os.environ, {"OPENAI_API_KEY": "dummy_key"}): config = load_config(config_path=config_file) assert isinstance(config, dict) - assert 'arxiv' in config - assert config['arxiv']['max_results'] == 50 + assert "arxiv" in config + assert config["arxiv"]["max_results"] == 50 def test_env_var_override(self, config_file): """Environment variables override config file values.""" - with patch.dict(os.environ, { - 'PAPERWEIGHT_MAX_RESULTS': '100', - 'OPENAI_API_KEY': 'test_api_key', - }): + with patch.dict( + os.environ, + { + "PAPERWEIGHT_MAX_RESULTS": "100", + "OPENAI_API_KEY": "test_api_key", + }, + ): config = load_config(config_path=config_file) - assert config['arxiv']['max_results'] == 100 - assert config['analyzer']['api_key'] == 
'test_api_key' + assert config["arxiv"]["max_results"] == 100 + assert config["analyzer"]["api_key"] == "test_api_key" def test_missing_api_key_raises(self, config_file): """ValueError when LLM provider requires API key that's missing.""" - with patch('paperweight.utils.load_dotenv', return_value=None): + with patch("paperweight.utils.load_dotenv", return_value=None): with patch.dict(os.environ, {}, clear=True): with pytest.raises(ValueError, match="Missing API key for openai"): load_config(config_path=config_file) def test_abstract_type_no_api_key_required(self, config_file, sample_config): """Abstract analyzer type does not require API key.""" - sample_config['analyzer']['type'] = 'abstract' - with open(config_file, 'w') as f: + sample_config["analyzer"]["type"] = "abstract" + with open(config_file, "w") as f: yaml.dump(sample_config, f) with patch.dict(os.environ, {}, clear=True): config = load_config(config_path=config_file) - assert 'api_key' not in config['analyzer'] + assert "api_key" not in config["analyzer"] # --------------------------------------------------------------------------- # Config Validation Tests # --------------------------------------------------------------------------- + class TestCheckConfig: """Tests for check_config validation.""" def test_missing_required_section(self): """Missing top-level section raises ValueError.""" config = { - 'arxiv': {}, - 'processor': {}, - 'analyzer': {}, - 'notifier': {}, + "arxiv": {}, + "processor": {}, + "analyzer": {}, + "notifier": {}, } with pytest.raises(ValueError, match="Missing required section: 'logging'"): check_config(config) @@ -152,13 +158,15 @@ def test_missing_required_section(self): def test_missing_categories_subsection(self): """Missing categories in arxiv section raises ValueError.""" config = { - 'arxiv': {}, - 'processor': {}, - 'analyzer': {}, - 'notifier': {}, - 'logging': {}, + "arxiv": {}, + "processor": {}, + "analyzer": {}, + "notifier": {}, + "logging": {}, } - with 
pytest.raises(ValueError, match="Missing required subsection: 'categories' in 'arxiv'"): + with pytest.raises( + ValueError, match="Missing required subsection: 'categories' in 'arxiv'" + ): check_config(config) def test_valid_config_passes(self, valid_base_config): @@ -167,31 +175,36 @@ def test_valid_config_passes(self, valid_base_config): def test_valid_multiple_categories(self, valid_base_config): """Multiple valid arXiv categories pass validation.""" - valid_base_config['arxiv']['categories'] = ['cs.AI', 'math.CO', 'physics.APP'] + valid_base_config["arxiv"]["categories"] = ["cs.AI", "math.CO", "physics.APP"] assert check_config(valid_base_config) is None def test_notifier_is_optional(self, valid_base_config): """Notifier section can be omitted for stdout/atom delivery.""" - del valid_base_config['notifier'] + del valid_base_config["notifier"] assert check_config(valid_base_config) is None class TestInvalidCategories: """Tests for invalid arXiv category validation.""" - @pytest.mark.parametrize("invalid_category", [ - 'invalid', # No dot - 'cs.ai', # Lowercase after dot - 'CS.AI', # Uppercase before dot - 'cs.A', # Only one letter after dot - 'cs.AI.ML', # More than one dot - '123.AI', # Numbers before dot - 'cs.123', # Numbers after dot - ]) + @pytest.mark.parametrize( + "invalid_category", + [ + "invalid", # No dot + "cs.ai", # Lowercase after dot + "CS.AI", # Uppercase before dot + "cs.A", # Only one letter after dot + "cs.AI.ML", # More than one dot + "123.AI", # Numbers before dot + "cs.123", # Numbers after dot + ], + ) def test_invalid_category_formats(self, valid_base_config, invalid_category): """Invalid arXiv category format raises ValueError.""" - valid_base_config['arxiv']['categories'] = [invalid_category] - with pytest.raises(ValueError, match=f"Invalid arXiv category: {invalid_category}"): + valid_base_config["arxiv"]["categories"] = [invalid_category] + with pytest.raises( + ValueError, match=f"Invalid arXiv category: {invalid_category}" + ): 
check_config(valid_base_config) @@ -200,14 +213,19 @@ class TestAnalyzerValidation: def test_invalid_analyzer_type(self, valid_base_config): """Invalid analyzer type raises ValueError.""" - valid_base_config['analyzer']['type'] = 'invalid_type' + valid_base_config["analyzer"]["type"] = "invalid_type" with pytest.raises(ValueError, match="Invalid analyzer type: 'invalid_type'"): check_config(valid_base_config) def test_invalid_llm_provider(self, valid_base_config): """Invalid LLM provider raises ValueError.""" - valid_base_config['analyzer'] = {'type': 'summary', 'llm_provider': 'invalid_provider'} - with pytest.raises(ValueError, match="Invalid LLM provider: 'invalid_provider'"): + valid_base_config["analyzer"] = { + "type": "summary", + "llm_provider": "invalid_provider", + } + with pytest.raises( + ValueError, match="Invalid LLM provider: 'invalid_provider'" + ): check_config(valid_base_config) @@ -216,32 +234,34 @@ class TestEmailValidation: def test_missing_email_field(self, valid_base_config): """Missing required email field raises ValueError.""" - valid_base_config['notifier']['email'] = {'to': 'test@example.com'} + valid_base_config["notifier"]["email"] = {"to": "test@example.com"} with pytest.raises(ValueError, match="Missing required email field: 'from'"): check_config(valid_base_config) def test_no_auth_does_not_require_password(self, valid_base_config): """When use_auth=False, password is not required.""" - valid_base_config['notifier']['email'] = { - 'to': 'test@example.com', - 'from': 'sender@example.com', - 'smtp_server': 'smtp.example.com', - 'smtp_port': 587, - 'use_auth': False, + valid_base_config["notifier"]["email"] = { + "to": "test@example.com", + "from": "sender@example.com", + "smtp_server": "smtp.example.com", + "smtp_port": 587, + "use_auth": False, } assert check_config(valid_base_config) is None def test_auth_requires_password(self, valid_base_config): """When use_auth is True (default), password is required.""" - del 
valid_base_config['notifier']['email']['password'] - with pytest.raises(ValueError, match="Missing required email field: 'password'"): + del valid_base_config["notifier"]["email"]["password"] + with pytest.raises( + ValueError, match="Missing required email field: 'password'" + ): check_config(valid_base_config) def test_email_disabled_skips_required_fields(self, valid_base_config): """Email requirements are skipped when explicitly disabled.""" - valid_base_config['notifier'] = { - 'type': 'email', - 'email': {'enabled': False}, + valid_base_config["notifier"] = { + "type": "email", + "email": {"enabled": False}, } assert check_config(valid_base_config) is None @@ -251,7 +271,7 @@ class TestLoggingValidation: def test_invalid_logging_level(self, valid_base_config): """Invalid logging level raises ValueError.""" - valid_base_config['logging']['level'] = 'INVALID_LEVEL' + valid_base_config["logging"]["level"] = "INVALID_LEVEL" with pytest.raises(ValueError, match="Invalid logging level: 'INVALID_LEVEL'"): check_config(valid_base_config) @@ -261,16 +281,18 @@ class TestDatabaseValidation: def test_db_enabled_requires_integer_port(self, valid_base_config): """Database port must be integer when db is enabled.""" - valid_base_config['db'] = { - 'enabled': True, - 'host': 'localhost', - 'port': None, - 'database': 'paperweight', - 'user': 'paperweight', - 'password': 'pass', - 'sslmode': 'prefer', + valid_base_config["db"] = { + "enabled": True, + "host": "localhost", + "port": None, + "database": "paperweight", + "user": "paperweight", + "password": "pass", + "sslmode": "prefer", } - with pytest.raises(ValueError, match="'port' in 'db' section must be a valid integer"): + with pytest.raises( + ValueError, match="'port' in 'db' section must be a valid integer" + ): check_config(valid_base_config) @@ -278,42 +300,51 @@ def test_db_enabled_requires_integer_port(self, valid_base_config): # Utility Function Tests # 
--------------------------------------------------------------------------- + class TestEnvVarExpansion: """Tests for environment variable expansion.""" def test_expand_env_vars(self): """Expand $VAR and ${VAR} syntax in config values.""" - with patch.dict(os.environ, {'TEST_VAR': 'test_value', 'NESTED_VAR': 'nested_value'}): + with patch.dict( + os.environ, {"TEST_VAR": "test_value", "NESTED_VAR": "nested_value"} + ): config = { - 'simple': '$TEST_VAR', - 'nested': {'key': '${NESTED_VAR}', 'list': ['$TEST_VAR', '${NESTED_VAR}']}, - 'untouched': 123, + "simple": "$TEST_VAR", + "nested": { + "key": "${NESTED_VAR}", + "list": ["$TEST_VAR", "${NESTED_VAR}"], + }, + "untouched": 123, } expanded = expand_env_vars(config) - assert expanded['simple'] == 'test_value' - assert expanded['nested']['key'] == 'nested_value' - assert expanded['nested']['list'] == ['test_value', 'nested_value'] - assert expanded['untouched'] == 123 + assert expanded["simple"] == "test_value" + assert expanded["nested"]["key"] == "nested_value" + assert expanded["nested"]["list"] == ["test_value", "nested_value"] + assert expanded["untouched"] == 123 def test_override_with_env(self): """PAPERWEIGHT_* env vars override config values with type coercion.""" config = { - 'max_results': 50, - 'enable_feature': False, - 'api_url': 'https://api.example.com', - 'timeout': 30.5, + "max_results": 50, + "enable_feature": False, + "api_url": "https://api.example.com", + "timeout": 30.5, } - with patch.dict(os.environ, { - 'PAPERWEIGHT_MAX_RESULTS': '100', - 'PAPERWEIGHT_ENABLE_FEATURE': 'true', - 'PAPERWEIGHT_API_URL': 'https://new-api.example.com', - 'PAPERWEIGHT_TIMEOUT': '60.5', - }): + with patch.dict( + os.environ, + { + "PAPERWEIGHT_MAX_RESULTS": "100", + "PAPERWEIGHT_ENABLE_FEATURE": "true", + "PAPERWEIGHT_API_URL": "https://new-api.example.com", + "PAPERWEIGHT_TIMEOUT": "60.5", + }, + ): overridden = override_with_env(config) - assert overridden['max_results'] == 100 - assert 
overridden['enable_feature'] is True - assert overridden['api_url'] == 'https://new-api.example.com' - assert overridden['timeout'] == 60.5 + assert overridden["max_results"] == 100 + assert overridden["enable_feature"] is True + assert overridden["api_url"] == "https://new-api.example.com" + assert overridden["timeout"] == 60.5 class TestArxivSectionValidation: @@ -321,14 +352,18 @@ class TestArxivSectionValidation: def test_negative_max_results_raises(self): """Negative max_results raises ValueError.""" - with pytest.raises(ValueError, match="'max_results' in 'arxiv' section must be a non-negative integer"): - _check_arxiv_section({'categories': ['cs.AI'], 'max_results': -1}) + with pytest.raises( + ValueError, + match="'max_results' in 'arxiv' section must be a non-negative integer", + ): + _check_arxiv_section({"categories": ["cs.AI"], "max_results": -1}) # --------------------------------------------------------------------------- # Profile Tests # --------------------------------------------------------------------------- + class TestProfiles: """Tests for profile switching.""" @@ -374,6 +409,7 @@ def test_load_config_with_profile(self, tmp_path): # DEFAULT_CONFIG Merge Tests # --------------------------------------------------------------------------- + class TestDefaultConfigMerge: """Tests for DEFAULT_CONFIG merge behavior in load_config.""" diff --git a/tests/test_contracts.py b/tests/test_contracts.py index dddc920..b6bf0f1 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -22,9 +22,7 @@ def test_no_circular_imports(self): import paperweight package_path = paperweight.__path__ - module_names = [ - name for _, name, _ in pkgutil.iter_modules(package_path) - ] + module_names = [name for _, name, _ in pkgutil.iter_modules(package_path)] for module_name in module_names: full_name = f"paperweight.{module_name}" diff --git a/tests/test_notifier.py b/tests/test_notifier.py index 3a5a163..b2f3d76 100644 --- a/tests/test_notifier.py +++ 
b/tests/test_notifier.py @@ -11,37 +11,37 @@ ) -@patch('paperweight.notifier.smtplib.SMTP') +@patch("paperweight.notifier.smtplib.SMTP") def test_send_email_notification(mock_smtp): mock_server = MagicMock() mock_smtp.return_value = mock_server config = { - 'email': { - 'from': 'sender@example.com', - 'to': 'recipient@example.com', - 'password': 'password123', - 'smtp_server': 'smtp.example.com', - 'smtp_port': 587 + "email": { + "from": "sender@example.com", + "to": "recipient@example.com", + "password": "password123", + "smtp_server": "smtp.example.com", + "smtp_port": 587, } } send_email_notification("Test Subject", "Test Body", config) mock_server.starttls.assert_called_once() - mock_server.login.assert_called_once_with('sender@example.com', 'password123') + mock_server.login.assert_called_once_with("sender@example.com", "password123") mock_server.sendmail.assert_called_once() mock_server.quit.assert_called_once() -@patch('paperweight.notifier.send_email_notification') +@patch("paperweight.notifier.send_email_notification") def test_compile_and_send_notifications_empty_list(mock_send_email): config = { - 'email': { - 'from': 'sender@example.com', - 'to': 'recipient@example.com', - 'password': 'password123', - 'smtp_server': 'smtp.example.com', - 'smtp_port': 587 + "email": { + "from": "sender@example.com", + "to": "recipient@example.com", + "password": "password123", + "smtp_server": "smtp.example.com", + "smtp_port": 587, } } compile_and_send_notifications([], config) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 64c7f7d..0f70c3e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -48,6 +48,7 @@ # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture def integration_config(tmp_path): """Load and patch config for integration testing.""" @@ -81,10 +82,10 @@ def integration_config(tmp_path): def mock_main_dependencies(mocker): """Mock all external dependencies for main() 
tests.""" # Mock sys.argv to prevent argparse from picking up pytest arguments - mocker.patch('sys.argv', ['paperweight']) + mocker.patch("sys.argv", ["paperweight"]) # Mock configuration and logging - mock_load_config = mocker.patch('paperweight.main.load_config') + mock_load_config = mocker.patch("paperweight.main.load_config") mock_load_config.return_value = { "logging": {"level": "INFO"}, "processor": {}, @@ -92,61 +93,59 @@ def mock_main_dependencies(mocker): "notifier": {"email": {}}, "db": {"enabled": False}, } - mock_setup_logging = mocker.patch('paperweight.main.setup_logging') + mock_setup_logging = mocker.patch("paperweight.main.setup_logging") # Mock paper fetching and processing - mock_get_recent_papers = mocker.patch('paperweight.main.get_recent_papers') + mock_get_recent_papers = mocker.patch("paperweight.main.get_recent_papers") mock_get_recent_papers.return_value = [{"id": "1234.5678", "title": "Test Paper"}] mock_triage_papers = mocker.patch("paperweight.main.triage_papers") mock_triage_papers.side_effect = lambda papers, _config: papers mock_hydrate_papers = mocker.patch("paperweight.main.hydrate_papers_with_content") mock_hydrate_papers.side_effect = lambda papers, _config: papers - mock_process_papers = mocker.patch('paperweight.main.process_papers') + mock_process_papers = mocker.patch("paperweight.main.process_papers") mock_process_papers.return_value = [ {"id": "1234.5678", "title": "Test Paper", "relevance_score": 0.8} ] - mock_get_abstracts = mocker.patch('paperweight.main.get_abstracts') + mock_get_abstracts = mocker.patch("paperweight.main.get_abstracts") mock_get_abstracts.return_value = ["Test summary"] # Mock digest rendering/writing - mock_render_text_digest = mocker.patch('paperweight.main.render_text_digest') + mock_render_text_digest = mocker.patch("paperweight.main.render_text_digest") mock_render_text_digest.return_value = "digest" - mock_render_json_digest = mocker.patch('paperweight.main.render_json_digest') + 
mock_render_json_digest = mocker.patch("paperweight.main.render_json_digest") mock_render_json_digest.return_value = "[]" - mock_write_output = mocker.patch('paperweight.main.write_output') - mock_render_atom_feed = mocker.patch('paperweight.main.render_atom_feed') + mock_write_output = mocker.patch("paperweight.main.write_output") + mock_render_atom_feed = mocker.patch("paperweight.main.render_atom_feed") mock_render_atom_feed.return_value = "" # Mock notifications - mock_notifications = mocker.patch( - 'paperweight.main.compile_and_send_notifications' - ) + mock_notifications = mocker.patch("paperweight.main.compile_and_send_notifications") mock_notifications.return_value = True # Mock database functions - mock_is_db_enabled = mocker.patch('paperweight.main.is_db_enabled') + mock_is_db_enabled = mocker.patch("paperweight.main.is_db_enabled") mock_is_db_enabled.return_value = False # Mock logger - mock_logger = mocker.patch('paperweight.main.logger') + mock_logger = mocker.patch("paperweight.main.logger") return { - 'load_config': mock_load_config, - 'setup_logging': mock_setup_logging, - 'get_recent_papers': mock_get_recent_papers, - 'triage_papers': mock_triage_papers, - 'hydrate_papers_with_content': mock_hydrate_papers, - 'process_papers': mock_process_papers, - 'get_abstracts': mock_get_abstracts, - 'render_text_digest': mock_render_text_digest, - 'render_json_digest': mock_render_json_digest, - 'write_output': mock_write_output, - 'render_atom_feed': mock_render_atom_feed, - 'notifications': mock_notifications, - 'logger': mock_logger, - 'is_db_enabled': mock_is_db_enabled, + "load_config": mock_load_config, + "setup_logging": mock_setup_logging, + "get_recent_papers": mock_get_recent_papers, + "triage_papers": mock_triage_papers, + "hydrate_papers_with_content": mock_hydrate_papers, + "process_papers": mock_process_papers, + "get_abstracts": mock_get_abstracts, + "render_text_digest": mock_render_text_digest, + "render_json_digest": mock_render_json_digest, 
+ "write_output": mock_write_output, + "render_atom_feed": mock_render_atom_feed, + "notifications": mock_notifications, + "logger": mock_logger, + "is_db_enabled": mock_is_db_enabled, } @@ -154,10 +153,11 @@ def mock_main_dependencies(mocker): # Full Pipeline Tests # --------------------------------------------------------------------------- + @pytest.mark.integration @pytest.mark.skipif( not os.getenv(LIVE_INTEGRATION_ENV), - reason=f"Set {LIVE_INTEGRATION_ENV}=1 to run live integration test." + reason=f"Set {LIVE_INTEGRATION_ENV}=1 to run live integration test.", ) def test_pipeline_end_to_end(integration_config): # noqa: C901 """Full pipeline: fetch, process, summarize, store, notify.""" @@ -187,7 +187,9 @@ def test_pipeline_end_to_end(integration_config): # noqa: C901 config_hash = hash_config(integration_config) pipeline_version = get_package_version() with connect_db(integration_config["db"]) as conn: - run_id = create_run(conn, config_hash, pipeline_version, "pytest_integration") + run_id = create_run( + conn, config_hash, pipeline_version, "pytest_integration" + ) conn.commit() # 1. Fetch @@ -217,7 +219,9 @@ def test_pipeline_end_to_end(integration_config): # noqa: C901 conn.commit() # 4. Notify - notification_sent = compile_and_send_notifications(processed, integration_config["notifier"]) + notification_sent = compile_and_send_notifications( + processed, integration_config["notifier"] + ) assert notification_sent, "Notification send failed" # 5. 
Verify Email @@ -338,74 +342,79 @@ def fake_fetch_paper_contents(paper_ids, max_workers=6): # Error Handling Tests (absorbed from test_main.py) # --------------------------------------------------------------------------- + class TestMainErrorHandling: """Tests for error handling in the main entry point.""" def test_config_yaml_error(self, mock_main_dependencies): """YAML parsing errors are logged.""" - mock_main_dependencies['load_config'].side_effect = yaml.YAMLError("Invalid YAML") + mock_main_dependencies["load_config"].side_effect = yaml.YAMLError( + "Invalid YAML" + ) main() - mock_main_dependencies['logger'].error.assert_called_with( + mock_main_dependencies["logger"].error.assert_called_with( "Configuration error: Invalid YAML" ) def test_network_error(self, mock_main_dependencies): """Network errors are logged.""" - mock_main_dependencies['load_config'].side_effect = requests.RequestException( + mock_main_dependencies["load_config"].side_effect = requests.RequestException( "Connection failed" ) main() - mock_main_dependencies['logger'].error.assert_called_with( + mock_main_dependencies["logger"].error.assert_called_with( "Network error occurred: Connection failed" ) def test_database_unreachable(self, mock_main_dependencies, mocker): """Database connection errors are logged.""" mocker.patch( - 'paperweight.main.setup_and_get_papers', + "paperweight.main.setup_and_get_papers", side_effect=DatabaseConnectionError("Database enabled but unreachable."), ) main() - mock_main_dependencies['logger'].error.assert_called_with( + mock_main_dependencies["logger"].error.assert_called_with( "Database error: Database enabled but unreachable." 
) def test_no_papers_found(self, mock_main_dependencies): """When no papers are found, notification is not called.""" - mock_main_dependencies['get_recent_papers'].return_value = [] + mock_main_dependencies["get_recent_papers"].return_value = [] main() - mock_main_dependencies['notifications'].assert_not_called() - mock_main_dependencies['logger'].info.assert_any_call( + mock_main_dependencies["notifications"].assert_not_called() + mock_main_dependencies["logger"].info.assert_any_call( "No new papers to process. Exiting." ) def test_default_delivery_writes_digest(self, mock_main_dependencies): """Default mode renders and writes stdout digest; abstract mode skips hydration.""" main() - mock_main_dependencies['get_recent_papers'].assert_called_once_with( - mock_main_dependencies['load_config'].return_value, include_content=False + mock_main_dependencies["get_recent_papers"].assert_called_once_with( + mock_main_dependencies["load_config"].return_value, include_content=False ) - mock_main_dependencies['triage_papers'].assert_called_once() - mock_main_dependencies['process_papers'].assert_called_once() + mock_main_dependencies["triage_papers"].assert_called_once() + mock_main_dependencies["process_papers"].assert_called_once() # Abstract mode skips content hydration entirely - mock_main_dependencies['hydrate_papers_with_content'].assert_not_called() - mock_main_dependencies['render_text_digest'].assert_called_once() - mock_main_dependencies['write_output'].assert_called_once() - mock_main_dependencies['notifications'].assert_not_called() + mock_main_dependencies["hydrate_papers_with_content"].assert_not_called() + mock_main_dependencies["render_text_digest"].assert_called_once() + mock_main_dependencies["write_output"].assert_called_once() + mock_main_dependencies["notifications"].assert_not_called() def test_email_delivery_uses_notifier(self, mock_main_dependencies, monkeypatch): """Email mode delegates to notifier adapter.""" - monkeypatch.setattr('sys.argv', 
['paperweight', '--delivery', 'email']) + monkeypatch.setattr("sys.argv", ["paperweight", "--delivery", "email"]) main() - mock_main_dependencies['notifications'].assert_called_once() + mock_main_dependencies["notifications"].assert_called_once() - def test_json_delivery_uses_json_renderer(self, mock_main_dependencies, monkeypatch): + def test_json_delivery_uses_json_renderer( + self, mock_main_dependencies, monkeypatch + ): """JSON mode renders JSON payload and writes output.""" - monkeypatch.setattr('sys.argv', ['paperweight', '--delivery', 'json']) + monkeypatch.setattr("sys.argv", ["paperweight", "--delivery", "json"]) main() - mock_main_dependencies['render_json_digest'].assert_called_once() - mock_main_dependencies['write_output'].assert_called_once() + mock_main_dependencies["render_json_digest"].assert_called_once() + mock_main_dependencies["write_output"].assert_called_once() diff --git a/tests/test_processor.py b/tests/test_processor.py index fd6fddf..05810b5 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -18,15 +18,15 @@ def processor_config(): """Standard processor configuration for tests.""" return { - 'keywords': ['AI', 'healthcare', 'quantum', 'computing'], - 'exclusion_keywords': ['biology'], - 'important_words': ['artificial intelligence'], - 'title_keyword_weight': 3, - 'abstract_keyword_weight': 2, - 'content_keyword_weight': 1, - 'exclusion_keyword_penalty': 5, - 'important_words_weight': 0.5, - 'min_score': 0, + "keywords": ["AI", "healthcare", "quantum", "computing"], + "exclusion_keywords": ["biology"], + "important_words": ["artificial intelligence"], + "title_keyword_weight": 3, + "abstract_keyword_weight": 2, + "content_keyword_weight": 1, + "exclusion_keyword_penalty": 5, + "important_words_weight": 0.5, + "min_score": 0, } @@ -36,17 +36,17 @@ class TestCalculatePaperScore: def test_score_breakdown_structure(self, processor_config): """Score calculation returns score and breakdown dict.""" paper = { - 'title': 'AI 
in Healthcare', - 'abstract': 'This paper discusses AI applications in healthcare.', - 'content': 'Artificial Intelligence has numerous applications in healthcare...' + "title": "AI in Healthcare", + "abstract": "This paper discusses AI applications in healthcare.", + "content": "Artificial Intelligence has numerous applications in healthcare...", } score, breakdown = calculate_paper_score(paper, processor_config) assert score > 0 - assert 'keyword_matching' in breakdown - assert 'exclusion_penalty' in breakdown - assert 'important_words' in breakdown + assert "keyword_matching" in breakdown + assert "exclusion_penalty" in breakdown + assert "important_words" in breakdown class TestProcessPapers: @@ -56,24 +56,24 @@ def test_papers_sorted_by_relevance(self, processor_config): """Papers are sorted by relevance score, highest first.""" papers = [ { - 'title': 'AI in Healthcare', - 'abstract': 'This paper discusses the applications of AI in healthcare.', - 'content': 'Artificial Intelligence has numerous applications in healthcare...' + "title": "AI in Healthcare", + "abstract": "This paper discusses the applications of AI in healthcare.", + "content": "Artificial Intelligence has numerous applications in healthcare...", }, { - 'title': 'Quantum Computing Advances', - 'abstract': 'Recent advancements in quantum computing are presented.', - 'content': 'Quantum computing has seen significant progress in recent years...' 
- } + "title": "Quantum Computing Advances", + "abstract": "Recent advancements in quantum computing are presented.", + "content": "Quantum computing has seen significant progress in recent years...", + }, ] - processor_config['min_score'] = 5 + processor_config["min_score"] = 5 processed = process_papers(papers, processor_config) assert len(processed) == 2 - assert processed[0]['relevance_score'] > processed[1]['relevance_score'] - assert 'score_breakdown' in processed[0] - assert 'normalized_score' in processed[0] + assert processed[0]["relevance_score"] > processed[1]["relevance_score"] + assert "score_breakdown" in processed[0] + assert "normalized_score" in processed[0] def test_empty_input_returns_empty(self, processor_config): """Empty paper list returns empty result.""" @@ -121,24 +121,29 @@ class TestNormalizeScores: def test_normalization_range(self): """Scores are normalized to 0-1 range.""" papers = [ - {'relevance_score': 10}, - {'relevance_score': 20}, - {'relevance_score': 30}, - {'relevance_score': 40}, + {"relevance_score": 10}, + {"relevance_score": 20}, + {"relevance_score": 30}, + {"relevance_score": 40}, ] normalized = normalize_scores(papers) - assert normalized[0]['normalized_score'] == 0.0 - assert normalized[-1]['normalized_score'] == 1.0 - assert 0.0 < normalized[1]['normalized_score'] < normalized[2]['normalized_score'] < 1.0 + assert normalized[0]["normalized_score"] == 0.0 + assert normalized[-1]["normalized_score"] == 1.0 + assert ( + 0.0 + < normalized[1]["normalized_score"] + < normalized[2]["normalized_score"] + < 1.0 + ) def test_equal_scores_normalize_to_one(self): """When all scores are equal, normalized scores are 1.0.""" papers = [ - {'relevance_score': 10}, - {'relevance_score': 10}, - {'relevance_score': 10} + {"relevance_score": 10}, + {"relevance_score": 10}, + {"relevance_score": 10}, ] normalized = normalize_scores(papers) - assert all(paper['normalized_score'] == 1.0 for paper in normalized) + assert 
all(paper["normalized_score"] == 1.0 for paper in normalized) diff --git a/tests/test_scraper.py b/tests/test_scraper.py index b05db35..0e7749f 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -18,7 +18,6 @@ hydrate_papers_with_content, ) - # --------------------------------------------------------------------------- # fetch_arxiv_papers — batched OR query # --------------------------------------------------------------------------- @@ -419,7 +418,7 @@ def _make_item_xml( ): """Build a minimal RSS element for testing.""" parts = [ - f"", + "", f"{title}", f"{link}", f"{description}", @@ -538,6 +537,7 @@ def test_fetch_rss_deduplicates_across_categories(mock_fetch): @patch("paperweight.scraper._fetch_single_rss_feed") def test_fetch_rss_one_category_fails(mock_fetch): """If one category feed fails, other categories still return papers.""" + def side_effect(url): if "cs.AI" in url: raise ConnectionError("boom")