diff --git a/.gitignore b/.gitignore index 680ca4b..43a24f8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ data/ artifacts/ .DS_Store .coverage +.paperweight_cache.json +*.bak # Build artifacts build/ diff --git a/CHANGELOG.md b/CHANGELOG.md index cdf6459..7e503e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.0] - 2026-02-15 + +### Added +- Profile switching via `--profile NAME` flag and `PAPERWEIGHT_PROFILE` env var +- Metadata cache (`metadata_cache` config section) to skip repeated arXiv API calls within a TTL window +- Progress logging during triage and summary LLM calls +- Per-call LLM timeout (45 s) for triage and summary to prevent hanging runs +- `--version` flag on the CLI +- Public API surface: `paperweight.__version__`, `load_config`, `get_recent_papers`, `score_papers`, etc. re-exported from `__init__.py` + +### Changed +- Triage now uses per-paper async calls (same pattern as summaries) instead of batch `run_many` +- Triage rationale is compact: prompt asks for max 20 words, output is whitespace-normalized and truncated +- `paperweight init` prints a clean error to stderr (no traceback) when config already exists; use `--force` to overwrite + ## [0.2.0] - 2026-02-14 ### Added @@ -72,7 +87,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Email notification system - YAML-based configuration -[Unreleased]: https://github.com/seanbrar/paperweight/compare/v0.2.0...HEAD +[Unreleased]: https://github.com/seanbrar/paperweight/compare/v0.3.0...HEAD +[0.3.0]: https://github.com/seanbrar/paperweight/compare/v0.2.0...v0.3.0 [0.2.0]: https://github.com/seanbrar/paperweight/compare/v0.1.2...v0.2.0 [0.1.2]: https://github.com/seanbrar/paperweight/compare/v0.1.1...v0.1.2 [0.1.1]: https://github.com/seanbrar/paperweight/compare/v0.1.0...v0.1.1 diff --git a/README.md b/README.md index 16a573f..66f4d13 
100644 --- a/README.md +++ b/README.md @@ -71,10 +71,16 @@ paperweight run --delivery email # strict checks for CI/release gates paperweight doctor --strict + +# activate a named profile +paperweight run --profile fast ``` Detailed command behavior: `docs/CLI.md` +`--max-items` is a processing cap (not a guaranteed output count): paperweight +processes at most N fetched papers, and may output fewer if filters remove them. + ## Configuration Core sections: @@ -83,6 +89,8 @@ Core sections: - `triage`: shortlist gate (title + abstract) - `processor`: scoring config - `analyzer`: `abstract` or `summary` +- `metadata_cache` (optional, speeds up repeated runs) +- `profiles` (optional, named config overlays) - `logging` - `notifier` (optional, only for email) diff --git a/config-base.yaml b/config-base.yaml index f8a38be..4c7b805 100644 --- a/config-base.yaml +++ b/config-base.yaml @@ -19,26 +19,41 @@ processor: content_keyword_weight: 1 exclusion_keyword_penalty: 5 important_words_weight: 0.5 - min_score: 10 + min_score: 3 analyzer: type: abstract # summary | abstract - llm_provider: openai # openai | gemini + # llm_provider: openai # openai | gemini # api_key: ${OPENAI_API_KEY} + # model: gpt-5-mini max_input_tokens: 7000 max_input_chars: 20000 -# AI-first shortlist gate (title + abstract) -triage: +# AI triage — opt-in. Uncomment and set enabled: true to use. +# Requires an LLM provider and API key. +# triage: +# enabled: true +# llm_provider: openai # openai | gemini +# # api_key: ${OPENAI_API_KEY} +# # model: gpt-5-mini +# min_score: 60 +# max_selected: 25 + +# Metadata cache — avoids repeated arXiv API calls within the TTL window. 
+metadata_cache: enabled: true - llm_provider: openai # openai | gemini - # api_key: ${OPENAI_API_KEY} - min_score: 60 - max_selected: 25 + path: .paperweight_cache.json + ttl_hours: 4 logging: level: INFO # DEBUG | INFO | WARNING | ERROR - file: paperweight.log + # file: paperweight.log # Omit for stderr-only logging + +# Concurrency limits for parallelized pipeline stages. +concurrency: + content_fetch: 6 # 1-20, threads for paper content downloads + triage: 3 # 1-10, async LLM workers for triage + summary: 3 # 1-10, async LLM workers for summarization # Optional metadata for --delivery atom feed: @@ -71,3 +86,17 @@ db: # Optional artifact storage path storage: base_dir: data/artifacts + +# Named profiles — activate with --profile NAME or PAPERWEIGHT_PROFILE env var. +# Each profile is a partial config overlay that deep-merges on top of the base. +# profiles: +# fast: +# arxiv: +# max_results: 20 +# triage: +# max_selected: 10 +# deep: +# arxiv: +# max_results: 200 +# triage: +# max_selected: 50 diff --git a/docs/CLI.md b/docs/CLI.md index 0723fd8..23fe8ea 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -8,6 +8,10 @@ paperweight has three commands: `paperweight` is shorthand for `paperweight run`. 
+Global flags: + +- `--version` — print version and exit + ## run ```bash @@ -17,7 +21,9 @@ paperweight run \ [--delivery stdout|json|atom|email] \ [--output PATH] \ [--sort-order relevance|alphabetical|publication_time] \ - [--max-items N] + [--max-items N] \ + [--profile NAME] \ + [--quiet] ``` Behavior: @@ -26,6 +32,9 @@ Behavior: - runs triage on title + abstract - hydrates full text only for shortlisted papers - scores/summarizes and delivers digest +- `--max-items N` caps how many fetched papers enter processing (triage/hydration/summary); output may be fewer than `N` after filtering +- `--profile NAME` activates a named profile from the config's `profiles` section (or set `PAPERWEIGHT_PROFILE` env var) +- `--quiet` suppresses progress status lines on stderr Delivery modes: @@ -34,14 +43,24 @@ Delivery modes: - `atom`: Atom feed XML - `email`: SMTP send via `notifier.email` config -`json` fields: +`json` fields (always present): + +- `title` — paper title +- `arxiv_id` — arXiv identifier +- `authors` — list of author names +- `categories` — list of arXiv categories +- `published` — publication date (ISO format) +- `abstract` — paper abstract +- `link` — arXiv abstract URL +- `pdf_url` — direct PDF URL +- `score` — relevance score (float) +- `keywords_matched` — list of matched keywords -- `title` -- `date` -- `score` -- `why` -- `link` -- `summary` +`json` fields (conditional): + +- `triage_score` — present when triage is enabled +- `triage_rationale` — present when triage is enabled +- `summary` — present when summary differs from abstract (i.e. 
LLM summarization was used) ## init @@ -53,11 +72,12 @@ Behavior: - writes a minimal `config.yaml` template - refuses to overwrite unless `--force` is passed +- prints a clean error message (no traceback) if config already exists ## doctor ```bash -paperweight doctor [--config PATH] [--strict] +paperweight doctor [--config PATH] [--strict] [--profile NAME] ``` Checks: @@ -71,3 +91,4 @@ Exit codes: - `0`: healthy (or warnings present without `--strict`) - `1`: hard failure, or warning in strict mode + diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 81a62b2..7df2a44 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -48,6 +48,37 @@ This keeps runtime lower than downloading full text for every candidate. ## Optional sections +### `metadata_cache` + +```yaml +metadata_cache: + enabled: false + path: .paperweight_cache.json + ttl_hours: 4 +``` + +When enabled, paperweight caches arXiv metadata locally and reuses it within the +TTL window, skipping repeated API calls. `--force-refresh` bypasses the cache. + +### `profiles` + +```yaml +profiles: + fast: + arxiv: + max_results: 20 + triage: + max_selected: 10 + deep: + arxiv: + max_results: 200 + triage: + max_selected: 50 +``` + +Activate with `--profile fast` or `PAPERWEIGHT_PROFILE=fast`. Each profile is +a partial config overlay that deep-merges on top of the base config. + ### `notifier` (only for `--delivery email`) ```yaml @@ -113,10 +144,14 @@ export PAPERWEIGHT_MAX_RESULTS=100 ## Analyzer keys When `analyzer.type: summary`, API key is required. +If a summary call fails at runtime, paperweight falls back to that paper's abstract. When `triage.enabled: true`, an API key is strongly recommended. Without one, paperweight falls back to a lightweight keyword/abstract heuristic. +If triage LLM calls fail or time out at runtime, paperweight falls back to +heuristic triage for the entire batch to keep behavior consistent within a run. 
+ Provider keys: - `OPENAI_API_KEY` for OpenAI @@ -128,3 +163,4 @@ Provider keys: - `--delivery json` ignores `notifier`. - `--delivery atom` uses optional `feed` metadata. - `--delivery email` requires valid `notifier.email` settings. +- `--profile NAME` deep-merges the named profile on top of the base config before env overrides. diff --git a/docs/ROADMAP.md b/docs/ROADMAP.md index a829a4a..135a209 100644 --- a/docs/ROADMAP.md +++ b/docs/ROADMAP.md @@ -5,64 +5,68 @@ time saved, setup simplicity, and digest quality. ## Product definition +paperweight is a **fast, scriptable arXiv interface** — the himalaya of academic papers. +It fetches structured arXiv data, scores by keywords, and outputs rich JSON. +AI enrichment (triage, summarization) is available via config but imposes zero cost +on the default path. + paperweight should be better than "just checking arXiv" when the user wants: - a smaller daily reading queue -- deterministic output that can be automated -- relevance filtering that improves over time +- deterministic output that can be automated and piped +- keyword-scored relevance filtering out of the box +- structured metadata (authors, categories, PDF URLs) for scripting ## Core success metrics These metrics guide all releases: 1. **Time to first useful run** - - target: <= 5 minutes from install to first digest + - target: <= 2 minutes from install to first digest 2. **Daily digest size** - target: median 5-20 items after user tuning 3. **Runtime** - - target: <= 120 seconds for `3 categories x max_results=50` on default non-summary mode + - target: sub-second warm runs (metadata cached), <= 60s cold fetch for 3 categories x 50 papers 4. **CLI reliability** - target: >= 99% successful runs in local smoke workflows 5. **Signal quality (human-evaluated)** - target: >= 7/10 items marked "worth reading" in pilot usage -## v0.2 release gates (must pass) - -1. 
CLI contract stable: - - `run`, `init`, `doctor` - - `run` delivery: `stdout`, `json`, `atom`, optional `email` -2. Zero-key baseline works: - - `init` defaults to `analyzer.type: abstract` - - `run` works without LLM keys via triage fallback -3. Setup validation: - - `doctor --strict` returns non-zero on warnings/failures -4. Output ergonomics: - - deterministic text digest - - scriptable JSON - - Atom feed export -5. Quality checks: - - lint clean - - tests green (including small CLI integration suite) -6. Packaging: - - release workflow present and tag-driven - -## v0.3 focus (quality lift, not surface-area lift) - -1. **Speed** - - add metadata cache - - target: >= 40% runtime reduction on repeated daily runs -2. **Digest quality** +## v0.3 focus (config resilience, richer metadata, performance, CLI polish) + +1. **Config resilience** + - DEFAULT_CONFIG ensures partial/minimal configs never crash + - triage disabled by default (opt-in via config) + - log file optional (stderr-only by default) + - target: `paperweight run` works with only `arxiv.categories` set +2. **Richer metadata** + - capture authors, categories, PDF URL, arXiv ID from API + - track which keywords matched during scoring + - JSON output includes full structured data contract + - target: JSON schema always complete without AI +3. **Performance** + - lazy imports for heavy dependencies (psycopg, pollux, tiktoken, pypdf) + - parallel category fetching + - target: sub-second warm runs, ~3x cold-fetch speedup +4. **CLI polish & API surface** + - `--version` flag + - `init` prints clean error (not traceback) when config exists + - `__init__.py` exposes `__version__` and key public functions + - target: scriptable from `import paperweight` without submodule diving + +## v0.4 focus (typed data, AI enrichment, feedback loop) + +1. 
**Typed data structures** + - replace `Dict[str, Any]` pipeline with `Paper` dataclass/Pydantic model + - eliminate in-place mutation in `process_papers` (return new objects) + - target: zero `KeyError` risk from undocumented dict keys +2. **AI enrichment polish** - improve triage rationale quality and compactness - - target: rationale present on >= 95% of shortlisted items -3. **Workflow fit** - - add saved presets/profile switching - - target: switch profile in one command, no config edits - -## v0.4 focus (feedback loop) - -1. add local feedback capture (`relevant` / `irrelevant`) -2. incorporate feedback into ranking -3. target: +20% improvement in user-rated relevance from v0.2 baseline + - target: rationale present on >= 95% of shortlisted items when triage enabled +3. **Feedback loop** + - add local feedback capture (`relevant` / `irrelevant`) + - incorporate feedback into ranking + - target: +20% improvement in user-rated relevance from v0.2 baseline ## v1.0 criteria diff --git a/kick_tires.py b/kick_tires.py new file mode 100644 index 0000000..90399c4 --- /dev/null +++ b/kick_tires.py @@ -0,0 +1,64 @@ +import os +import subprocess + + +def run_command(command): + print(f"Running command: {command}") + try: + # Use subprocess.run to execute command + # split command string into args list + args = command.split() + result = subprocess.run(args, capture_output=True, text=True) + print("STDOUT:", result.stdout) + if result.stderr: + print("STDERR:", result.stderr) + return result + except Exception as e: + print(f"Error running command '{command}': {e}") + return None + +def check_import(): + print("\n--- Checking Import ---") + try: + import paperweight + print(f"Successfully imported paperweight. 
File: {paperweight.__file__}") + print("Dir(paperweight):", dir(paperweight)) + except ImportError as e: + print(f"Failed to import paperweight: {e}") + except Exception as e: + print(f"Error during import check: {e}") + +def main(): + print("--- Kicking the Tires of paperweight ---") + + # 1. Check if installed + check_import() + + # 2. Run CLI help + print("\n--- CLI Help ---") + run_command("paperweight --help") + + # 3. Init + print("\n--- Init ---") + if os.path.exists("config.yaml"): + print("config.yaml already exists. Backing it up to config.yaml.bak") + os.rename("config.yaml", "config.yaml.bak") + + run_command("paperweight init") + + if os.path.exists("config.yaml"): + print("config.yaml created successfully.") + else: + print("config.yaml was NOT created.") + + # 4. Doctor + print("\n--- Doctor ---") + run_command("paperweight doctor") + + # 5. Run (Dry run or real run?) + print("\n--- Run ---") + # README says: paperweight run --force-refresh + run_command("paperweight run --force-refresh") + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 54de9db..4c8db48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "academic-paperweight" -version = "0.2.0" +version = "0.3.0" description = "Automated retrieval, filtering, and LLM-powered summarization of arXiv papers based on your research interests." readme = "README.md" requires-python = ">=3.11, <3.14" diff --git a/scripts/verify_llm_openai.py b/scripts/verify_llm_openai.py new file mode 100644 index 0000000..c30106f --- /dev/null +++ b/scripts/verify_llm_openai.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +"""Smoke + latency checks for OpenAI-backed LLM features in paperweight. 
+ +This script validates: +1) direct Pollux->OpenAI connectivity for tiny prompts +2) paperweight summary path latency for a tiny synthetic paper +3) paperweight triage path latency for a tiny synthetic shortlist +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import statistics +import sys +import time +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from pollux import Config, RetryPolicy, run + +# Keep local script runnable without package install. +sys.path.append(str(Path(__file__).parent.parent / "src")) + +from paperweight.analyzer import summarize_paper, triage_papers + +logger = logging.getLogger("verify_llm_openai") + + +def _setup_logging(verbose: bool) -> None: + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + ) + if not verbose: + # Keep default output concise and focused on probe results. 
+ logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("openai").setLevel(logging.WARNING) + logging.getLogger("pollux").setLevel(logging.WARNING) + logging.getLogger("paperweight.analyzer").setLevel(logging.WARNING) + + +def _safe_mean(values: list[float]) -> float: + return statistics.mean(values) if values else 0.0 + + +def _safe_p95(values: list[float]) -> float: + if not values: + return 0.0 + ordered = sorted(values) + idx = max(0, min(len(ordered) - 1, int(round((len(ordered) - 1) * 0.95)))) + return ordered[idx] + + +def _run_direct_probe(model: str, api_key: str, repeats: int) -> dict[str, Any]: + durations: list[float] = [] + failures: list[str] = [] + answers: list[str] = [] + logger.info("Direct probe model=%s repeats=%s", model, repeats) + + for idx in range(1, repeats + 1): + config = Config( + provider="openai", + model=model, + api_key=api_key, + retry=RetryPolicy(max_attempts=1, max_elapsed_s=15.0), + ) + start = time.perf_counter() + try: + result = asyncio.run(run("Reply with exactly: pong", config=config)) + elapsed = time.perf_counter() - start + response = "" + if isinstance(result, dict): + answers_blob = result.get("answers") + if isinstance(answers_blob, list) and answers_blob: + response = str(answers_blob[0]).strip() + durations.append(elapsed) + answers.append(response) + logger.info( + " run=%s status=ok elapsed=%.2fs response=%r", + idx, + elapsed, + response[:60], + ) + except Exception as exc: # pragma: no cover - integration behavior + elapsed = time.perf_counter() - start + durations.append(elapsed) + err = f"{type(exc).__name__}: {exc}" + failures.append(err) + logger.error(" run=%s status=error elapsed=%.2fs error=%s", idx, elapsed, err) + + return { + "model": model, + "durations": durations, + "mean_s": _safe_mean(durations), + "p95_s": _safe_p95(durations), + "failures": failures, + "answers": answers, + } + + +def _run_summary_probe(model: str, api_key: str) -> dict[str, Any]: + paper = { + "title": "Toy 
study of efficient transformer routing", + "abstract": ( + "We propose a compact routing mechanism for transformer blocks and " + "evaluate quality and speed tradeoffs." + ), + "content": ( + "Introduction. We study efficient routing in transformer models. " + "Method. We prune low-value experts and share projections. " + "Results. We report quality parity with reduced compute. " + ) + * 60, + } + config = { + "type": "summary", + "llm_provider": "openai", + "api_key": api_key, + "model": model, + "max_input_tokens": 1500, + "max_input_chars": 6000, + } + + logger.info("Summary probe model=%s", model) + start = time.perf_counter() + try: + summary = summarize_paper(paper, config) + elapsed = time.perf_counter() - start + logger.info( + " status=ok elapsed=%.2fs summary_len=%s", + elapsed, + len(summary or ""), + ) + return {"model": model, "elapsed_s": elapsed, "error": None} + except Exception as exc: # pragma: no cover - defensive + elapsed = time.perf_counter() - start + err = f"{type(exc).__name__}: {exc}" + logger.error(" status=error elapsed=%.2fs error=%s", elapsed, err) + return {"model": model, "elapsed_s": elapsed, "error": err} + + +def _run_triage_probe(model: str, api_key: str) -> dict[str, Any]: + papers = [ + { + "title": "Transformer sparsification for long-context NLP", + "abstract": "We compress attention with sparse expert routing.", + }, + { + "title": "Cataloging graph invariants in chemistry", + "abstract": "This paper studies graph properties and molecules.", + }, + ] + full_config = { + "triage": { + "enabled": True, + "llm_provider": "openai", + "api_key": api_key, + "model": model, + "min_score": 50, + "max_selected": 10, + }, + "analyzer": {"llm_provider": "openai", "api_key": api_key}, + "processor": {"keywords": ["transformer", "nlp", "reasoning"]}, + } + + logger.info("Triage probe model=%s papers=%s", model, len(papers)) + start = time.perf_counter() + try: + shortlisted = triage_papers(papers, full_config) + elapsed = 
time.perf_counter() - start + logger.info( + " status=ok elapsed=%.2fs selected=%s", + elapsed, + len(shortlisted), + ) + return {"model": model, "elapsed_s": elapsed, "error": None} + except Exception as exc: # pragma: no cover - defensive + elapsed = time.perf_counter() - start + err = f"{type(exc).__name__}: {exc}" + logger.error(" status=error elapsed=%.2fs error=%s", elapsed, err) + return {"model": model, "elapsed_s": elapsed, "error": err} + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Verify OpenAI LLM integration and latency for paperweight.", + ) + parser.add_argument( + "--models", + nargs="+", + default=["gpt-5-nano", "gpt-5-mini"], + help="Model IDs to test.", + ) + parser.add_argument( + "--repeats", + type=int, + default=3, + help="Direct tiny-prompt probes per model.", + ) + parser.add_argument( + "--dotenv-path", + default=".env", + help="Dotenv file to load before checking OPENAI_API_KEY.", + ) + parser.add_argument( + "--api-key-env", + default="OPENAI_API_KEY", + help="Environment variable containing the OpenAI API key.", + ) + parser.add_argument( + "--warn-threshold-seconds", + type=float, + default=6.0, + help="Warn if direct-probe mean latency exceeds this value.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable debug logs.", + ) + return parser.parse_args() + + +def main() -> int: + args = _parse_args() + _setup_logging(args.verbose) + + dotenv_path = Path(args.dotenv_path) + if dotenv_path.exists(): + load_dotenv(dotenv_path=dotenv_path) + logger.info("Loaded dotenv file: %s", dotenv_path) + else: + logger.info("Dotenv file not found at %s; using current environment", dotenv_path) + + api_key = os.getenv(args.api_key_env, "") + if not api_key: + logger.error("Missing API key: %s", args.api_key_env) + return 2 + + logger.info("Starting OpenAI verification models=%s", ",".join(args.models)) + had_failure = False + + for model in args.models: + direct = 
_run_direct_probe(model, api_key, args.repeats) + summary = _run_summary_probe(model, api_key) + triage = _run_triage_probe(model, api_key) + + logger.info( + "RESULT model=%s direct_mean=%.2fs direct_p95=%.2fs summary=%.2fs triage=%.2fs", + model, + direct["mean_s"], + direct["p95_s"], + float(summary["elapsed_s"]), + float(triage["elapsed_s"]), + ) + + if direct["failures"] or summary["error"] or triage["error"]: + had_failure = True + + if float(direct["mean_s"]) > float(args.warn_threshold_seconds): + logger.warning( + "Direct probe mean latency %.2fs exceeded threshold %.2fs for model=%s", + direct["mean_s"], + args.warn_threshold_seconds, + model, + ) + + return 1 if had_failure else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/paperweight/__init__.py b/src/paperweight/__init__.py index e69de29..16c4e8a 100644 --- a/src/paperweight/__init__.py +++ b/src/paperweight/__init__.py @@ -0,0 +1,29 @@ +"""paperweight — an arXiv triage CLI.""" + +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as _pkg_version + +try: + __version__: str = _pkg_version("academic-paperweight") +except PackageNotFoundError: + __version__ = "unknown" + +# Public API re-exports — keep the surface small and intentional. 
+from paperweight.main import ( # noqa: E402 + process_and_summarize_papers, + score_papers, + setup_and_get_papers, + summarize_scored_papers, +) +from paperweight.scraper import get_recent_papers # noqa: E402 +from paperweight.utils import load_config # noqa: E402 + +__all__ = [ + "__version__", + "get_recent_papers", + "load_config", + "process_and_summarize_papers", + "score_papers", + "setup_and_get_papers", + "summarize_scored_papers", +] diff --git a/src/paperweight/analyzer.py b/src/paperweight/analyzer.py index a92bf38..3e29275 100644 --- a/src/paperweight/analyzer.py +++ b/src/paperweight/analyzer.py @@ -10,21 +10,50 @@ import os from typing import Any, Dict, List, Literal, cast -from pollux import Config, RetryPolicy, Source, run - from paperweight.utils import count_tokens ProviderName = Literal["gemini", "openai"] logger = logging.getLogger(__name__) +# Keep fanout modest for provider stability and local predictability. +SUMMARY_CONCURRENCY = 3 +TRIAGE_CONCURRENCY = 3 +LLM_TIMEOUT_S = 45.0 +RATIONALE_MAX_CHARS = 160 + + +def _int_setting(value: Any, default: int, *, minimum: int = 0) -> int: + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + return max(minimum, parsed) + + +def _float_setting(value: Any, default: float, *, minimum: float = 0.0) -> float: + try: + parsed = float(value) + except (TypeError, ValueError): + parsed = default + return max(minimum, parsed) -def get_abstracts(processed_papers, config): + +def _compact_rationale(text, *, max_chars=RATIONALE_MAX_CHARS): + """Whitespace-normalize and truncate a triage rationale.""" + text = " ".join((text or "").split()) + if len(text) > max_chars: + text = text[: max_chars - 1].rstrip() + "\u2026" + return text or "No rationale" + + +def get_abstracts(processed_papers, config, *, summary_concurrency=None): """Extract abstracts or summaries from processed papers based on configuration. Args: processed_papers: List of dictionaries containing paper data. 
config: Configuration dictionary specifying analysis type and parameters. + summary_concurrency: Optional override for the number of concurrent summary workers. Returns: List of strings containing either abstracts or summaries based on config type. @@ -36,10 +65,9 @@ def get_abstracts(processed_papers, config): if analysis_type == "abstract": return [paper["abstract"] for paper in processed_papers] - elif analysis_type == "summary": - return [summarize_paper(paper, config) for paper in processed_papers] - else: - raise ValueError(f"Unknown analysis type: {analysis_type}") + if analysis_type == "summary": + return summarize_papers(processed_papers, config, summary_concurrency=summary_concurrency) + raise ValueError(f"Unknown analysis type: {analysis_type}") def _truncate_for_prompt( @@ -98,8 +126,8 @@ def _resolve_triage_model_config( or os.getenv(f"{provider.upper()}_API_KEY") or "" ) - min_score = float(triage_cfg.get("min_score", 60.0)) - max_selected = int(triage_cfg.get("max_selected", 25)) + min_score = _float_setting(triage_cfg.get("min_score"), 60.0, minimum=0.0) + max_selected = _int_setting(triage_cfg.get("max_selected"), 25, minimum=1) return provider, model, api_key, min_score, max_selected @@ -114,37 +142,41 @@ def _heuristic_triage_score(paper: Dict[str, Any], profile_terms: List[str]) -> return min(100.0, 100.0 * (hits / len(profile_terms))) -def _triage_one_paper( - paper: Dict[str, Any], - pollux_config: Config, - profile: str, +def _heuristic_triage( + papers: List[Dict[str, Any]], + profile_terms: List[str], *, min_score: float, -) -> Dict[str, Any]: + max_selected: int, + rationale: str, +) -> List[Dict[str, Any]]: + shortlisted = [] + for paper in papers: + score = _heuristic_triage_score(paper, profile_terms) + paper["triage_score"] = score + paper["triage_rationale"] = rationale + if score >= min_score: + shortlisted.append(paper) + return shortlisted[:max_selected] + + +def _build_triage_prompt(paper: Dict[str, Any], profile: str) -> str: 
title = (paper.get("title") or "").strip() abstract = (paper.get("abstract") or "").strip() - - prompt = ( + return ( "You are triaging arXiv papers for relevance.\n" "Return JSON only with keys: include (boolean), score (0-100 number), rationale (string).\n" + "Rationale must be a compact one-liner (max 20 words).\n" "Be strict. Include only if likely useful to the profile.\n\n" f"Profile:\n{profile}\n\n" f"Title: {title}\n\n" f"Abstract:\n{abstract}\n" ) - result = asyncio.run(run(prompt, config=pollux_config)) - response = None - if isinstance(result, dict): - answers = result.get("answers") - if isinstance(answers, list) and answers: - response = answers[0] + +def _parse_triage_decision(response: Any, *, min_score: float) -> Dict[str, Any]: if not response: - return { - "include": False, - "score": 0.0, - "rationale": "No model response", - } + raise ValueError("No model response") raw = str(response).strip() start = raw.find("{") @@ -155,10 +187,49 @@ def _triage_one_paper( parsed = json.loads(raw) score = float(parsed.get("score", 0.0)) include = bool(parsed.get("include", score >= min_score)) - rationale = str(parsed.get("rationale", "")).strip() + rationale = _compact_rationale(str(parsed.get("rationale", ""))) return {"include": include, "score": score, "rationale": rationale} +async def _triage_one_paper_async(prompt, pollux_config, *, min_score): + """Call `run` for a single triage prompt with a timeout.""" + from pollux import run + + result = await asyncio.wait_for( + run(prompt, config=pollux_config), timeout=LLM_TIMEOUT_S + ) + answer = None + if isinstance(result, dict): + answers = result.get("answers") + if isinstance(answers, list) and answers: + answer = answers[0] + return _parse_triage_decision(answer, min_score=min_score) + + +async def _run_triage_async(prompts, pollux_config, *, min_score, concurrency=TRIAGE_CONCURRENCY): + """Run triage prompts concurrently with a semaphore, returning decisions in order.""" + semaphore = 
asyncio.Semaphore(concurrency) + total = len(prompts) + completed = 0 + + async def _worker(index, prompt): + nonlocal completed + async with semaphore: + decision = await _triage_one_paper_async( + prompt, pollux_config, min_score=min_score + ) + completed += 1 + logger.info("Triage: %d/%d", completed, total) + return index, decision + + tasks = [asyncio.create_task(_worker(i, p)) for i, p in enumerate(prompts)] + results = [None] * total + for coro in asyncio.as_completed(tasks): + index, decision = await coro + results[index] = decision + return results + + def triage_papers( papers: List[Dict[str, Any]], full_config: Dict[str, Any], @@ -168,9 +239,11 @@ def triage_papers( return [] triage_cfg = full_config.get("triage", {}) - if not triage_cfg.get("enabled", True): + if not triage_cfg.get("enabled", False): return papers + from pollux import Config, RetryPolicy + provider, model, api_key, min_score, max_selected = _resolve_triage_model_config( full_config ) @@ -181,41 +254,50 @@ def triage_papers( logger.warning( "AI triage is enabled but provider/key is unavailable; using heuristic triage." 
) - shortlisted = [] - for paper in papers: - score = _heuristic_triage_score(paper, profile_terms) - paper["triage_score"] = score - paper["triage_rationale"] = "Keyword/abstract heuristic fallback" - if score >= min_score: - shortlisted.append(paper) - return shortlisted[:max_selected] + return _heuristic_triage( + papers, + profile_terms, + min_score=min_score, + max_selected=max_selected, + rationale="Keyword/abstract heuristic fallback", + ) provider_name = cast(ProviderName, provider) pollux_config = Config( provider=provider_name, model=model, api_key=api_key, - retry=RetryPolicy(max_attempts=2, initial_delay_s=1.0, max_delay_s=5.0), + retry=RetryPolicy( + max_attempts=2, + initial_delay_s=1.0, + max_delay_s=5.0, + max_elapsed_s=20.0, + ), ) - shortlisted = [] - for paper in papers: - try: - decision = _triage_one_paper( - paper, - pollux_config, - profile_text, - min_score=min_score, - ) - except Exception as e: - logger.warning("AI triage failed for '%s': %s", paper.get("title", ""), e) - score = _heuristic_triage_score(paper, profile_terms) - decision = { - "include": score >= min_score, - "score": score, - "rationale": "LLM error; keyword/abstract heuristic fallback", - } + prompts = [_build_triage_prompt(paper, profile_text) for paper in papers] + + triage_concurrency = full_config.get("concurrency", {}).get("triage", TRIAGE_CONCURRENCY) + + try: + decisions = asyncio.run( + _run_triage_async(prompts, pollux_config, min_score=min_score, concurrency=triage_concurrency) + ) + except Exception as exc: + logger.warning( + "AI triage failed; using heuristic triage for entire batch: %s", + exc, + ) + return _heuristic_triage( + papers, + profile_terms, + min_score=min_score, + max_selected=max_selected, + rationale="LLM unavailable; keyword/abstract heuristic fallback", + ) + shortlisted = [] + for paper, decision in zip(papers, decisions): paper["triage_score"] = float(decision["score"]) paper["triage_rationale"] = decision["rationale"] if 
decision["include"] and float(decision["score"]) >= min_score: @@ -225,83 +307,163 @@ def triage_papers( return shortlisted[:max_selected] -def summarize_paper(paper: Dict[str, Any], config: Dict[str, Any]) -> str: - """Generate a summary of a paper using an LLM. +async def _summarize_one_paper_async( + paper: Dict[str, Any], + pollux_config: Any, + *, + max_input_tokens: int, + max_input_chars: int, +) -> str: + from pollux import Source, run - Uses Pollux for LLM interaction. Pollux handles retries internally - via RetryPolicy (exponential backoff with jitter). + title = (paper.get("title") or "").strip() + abstract = (paper.get("abstract") or "").strip() + content = paper.get("content") or "" - Args: - paper: Dictionary containing paper data including content and metadata. - config: Configuration dictionary containing LLM settings. + prompt = ( + "Summarize the paper for a busy researcher.\n" + "Constraints:\n" + "- Be accurate; do not invent results.\n" + "- 4-6 sentences.\n" + "- Include: problem, approach, key results/claims, and who should read it.\n\n" + f"Title: {title}\n\n" + f"Abstract:\n{abstract}\n\n" + ) - Returns: - A string containing the generated summary. 
- """ + content = _truncate_for_prompt( + str(content), + prompt, + max_input_tokens=max_input_tokens, + max_input_chars=max_input_chars, + ) + source = Source.from_text(content, identifier=title or "paper-content") + + input_tokens = count_tokens(prompt) + count_tokens(content) + logger.debug("Summary input tokens title=%r count=%s", title[:60], input_tokens) + + result = await asyncio.wait_for( + run(prompt, source=source, config=pollux_config), timeout=LLM_TIMEOUT_S + ) + response = None + if isinstance(result, dict): + answers = result.get("answers") + if isinstance(answers, list) and answers: + response = answers[0] + if not response: + raise RuntimeError(f"LLM returned no answers for '{title[:80]}'") + + output_tokens = count_tokens(response) + logger.debug("Summary output tokens title=%r count=%s", title[:60], output_tokens) + return str(response) + + +def _resolve_summary_model_config(config: Dict[str, Any]) -> tuple[ProviderName, str, str]: llm_provider = (config.get("llm_provider") or "openai").lower().strip() - api_key = config.get("api_key") + api_key = config.get("api_key") or os.getenv(f"{llm_provider.upper()}_API_KEY") or "" + if llm_provider not in ("openai", "gemini") or not api_key: + raise ValueError( + "Summary analyzer requires a valid llm_provider (openai|gemini) and api_key." 
+ ) + model_name = (config.get("model") or "").strip() or _default_model_for_provider( + llm_provider + ) + return cast(ProviderName, llm_provider), model_name, api_key + + +def summarize_papers( # noqa: C901 + papers: List[Dict[str, Any]], + config: Dict[str, Any], + *, + summary_concurrency: int | None = None, +) -> List[str]: + """Summarize papers with abstract fallback on runtime LLM errors.""" + if not papers: + return [] + + from pollux import Config, RetryPolicy - if llm_provider not in ["openai", "gemini"] or not api_key: + provider, model_name, api_key = _resolve_summary_model_config(config) + max_input_tokens = _int_setting(config.get("max_input_tokens"), 7000, minimum=500) + max_input_chars = _int_setting(config.get("max_input_chars"), 20_000, minimum=1000) + effective_concurrency = summary_concurrency if summary_concurrency is not None else SUMMARY_CONCURRENCY + + pollux_config = Config( + provider=provider, + model=model_name, + api_key=api_key, + retry=RetryPolicy( + max_attempts=3, + initial_delay_s=1.0, + max_delay_s=10.0, + max_elapsed_s=30.0, + ), + ) + + async def _run_summary_batch() -> tuple[List[str | None], List[tuple[int, BaseException]]]: + semaphore = asyncio.Semaphore(effective_concurrency) + results: List[str | None] = [None] * len(papers) + failures: List[tuple[int, BaseException]] = [] + completed = 0 + total = len(papers) + + async def _worker(index: int, paper: Dict[str, Any]): + async with semaphore: + try: + summary = await _summarize_one_paper_async( + paper, + pollux_config, + max_input_tokens=max_input_tokens, + max_input_chars=max_input_chars, + ) + return index, summary, None + except Exception as exc: + return index, None, exc + + tasks = [ + asyncio.create_task(_worker(index, paper)) + for index, paper in enumerate(papers) + ] + + for task in asyncio.as_completed(tasks): + index, summary, exc = await task + completed += 1 + logger.info("Summary: %d/%d", completed, total) + if exc is not None: + failures.append((index, exc)) 
+ continue + results[index] = summary + + return results, failures + + raw_summaries, failures = asyncio.run(_run_summary_batch()) + + summaries: List[str] = [] + for index, summary in enumerate(raw_summaries): + if summary is not None: + summaries.append(summary) + continue + fallback = str(papers[index].get("abstract") or "") + summaries.append(fallback) + + if failures: logger.warning( - f"No valid LLM provider or API key available for {llm_provider}. Falling back to abstract." + "Summary fallback used for %s/%s papers.", + len(failures), + len(papers), ) - return paper["abstract"] + for index, exc in failures: + logger.warning( + "Summary failed for '%s': %s", + papers[index].get("title", ""), + exc, + ) - try: - provider: ProviderName = llm_provider # type: ignore[assignment] # guarded above - model_name = (config.get("model") or "").strip() or _default_model_for_provider( - llm_provider - ) - pollux_config = Config( - provider=provider, - model=model_name, - api_key=api_key, - retry=RetryPolicy(max_attempts=3, initial_delay_s=1.0, max_delay_s=10.0), - ) + return summaries - title = (paper.get("title") or "").strip() - abstract = (paper.get("abstract") or "").strip() - content = paper.get("content") or "" - - # Guardrails. Defaults intentionally conservative. 
- max_input_tokens = int(config.get("max_input_tokens", 7000)) - max_input_chars = int(config.get("max_input_chars", 20_000)) - - prompt = ( - "Summarize the paper for a busy researcher.\n" - "Constraints:\n" - "- Be accurate; do not invent results.\n" - "- 4-6 sentences.\n" - "- Include: problem, approach, key results/claims, and who should read it.\n\n" - f"Title: {title}\n\n" - f"Abstract:\n{abstract}\n\n" - ) - content = _truncate_for_prompt( - str(content), - prompt, - max_input_tokens=max_input_tokens, - max_input_chars=max_input_chars, - ) - source = Source.from_text(content, identifier=title or "paper-content") - - input_tokens = count_tokens(prompt) + count_tokens(content) - logger.info(f"Input token count: {input_tokens}") - - result = asyncio.run(run(prompt, source=source, config=pollux_config)) - response = None - if isinstance(result, dict): - answers = result.get("answers") - if isinstance(answers, list) and answers: - response = answers[0] - if not response: - logger.warning("LLM returned no answers; falling back to abstract.") - return paper.get("abstract", "") - - output_tokens = count_tokens(response) - logger.info(f"Output token count: {output_tokens}") - - return response - except Exception as e: - logger.error(f"Error summarizing paper: {e}", exc_info=True) - return paper["abstract"] +def summarize_paper(paper: Dict[str, Any], config: Dict[str, Any]) -> str: + """Generate a summary of a single paper using the same batch engine.""" + summaries = summarize_papers([paper], config) + if summaries: + return summaries[0] + return str(paper.get("abstract") or "") diff --git a/src/paperweight/db.py b/src/paperweight/db.py index 41e7bac..3fd8ea2 100644 --- a/src/paperweight/db.py +++ b/src/paperweight/db.py @@ -3,9 +3,6 @@ from contextlib import contextmanager from typing import Any, Dict, Generator -import psycopg -from psycopg import Connection - class DatabaseConnectionError(RuntimeError): """Raised when a configured database is unreachable.""" @@ 
-19,7 +16,7 @@ def is_db_enabled(config: Dict[str, Any]) -> bool: @contextmanager def connect_db( db_config: Dict[str, Any], autocommit: bool = False -) -> Generator[Connection, None, None]: +) -> Generator: """Create a database connection. Args: @@ -30,6 +27,8 @@ def connect_db( Yields: A psycopg connection object. """ + import psycopg + conn = psycopg.connect( host=db_config["host"], port=db_config["port"], diff --git a/src/paperweight/logging_config.py b/src/paperweight/logging_config.py index 79d9ee1..5ade8f7 100644 --- a/src/paperweight/logging_config.py +++ b/src/paperweight/logging_config.py @@ -15,11 +15,11 @@ def setup_logging(logging_config): Args: logging_config: Dictionary containing logging configuration parameters including - 'level' and 'file' settings. + 'level' and optional 'file' settings. - The function configures both file and console handlers with the following features: + The function configures handlers with the following features: - Console handler with WARNING and above levels - - File handler with the configured level (defaults to INFO) + - File handler with the configured level (defaults to INFO) when 'file' is set - Standard format: timestamp - logger_name - level - message - Automatic creation of log directory if it doesn't exist """ @@ -28,12 +28,30 @@ def setup_logging(logging_config): if logging_level not in valid_levels: logging_level = "INFO" - log_file = logging_config["file"] - log_dir = os.path.dirname(log_file) - if log_dir and not os.path.exists(log_dir): - os.makedirs(log_dir, exist_ok=True) + log_file = logging_config.get("file") - logging_config = { + handlers = { + "console": { + "class": "logging.StreamHandler", + "formatter": "standard", + "level": "WARNING", + }, + } + active_handlers = ["console"] + + if log_file: + log_dir = os.path.dirname(log_file) + if log_dir and not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + handlers["file"] = { + "class": "logging.FileHandler", + "filename": log_file, + 
"formatter": "standard", + "level": logging_level, + } + active_handlers.append("file") + + dict_config = { "version": 1, "disable_existing_loggers": False, "formatters": { @@ -42,25 +60,13 @@ def setup_logging(logging_config): "datefmt": "%Y-%m-%d %H:%M:%S", }, }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "standard", - "level": "WARNING", - }, - "file": { - "class": "logging.FileHandler", - "filename": log_file, - "formatter": "standard", - "level": logging_level, - }, - }, + "handlers": handlers, "root": { - "handlers": ["console", "file"], + "handlers": active_handlers, "level": logging_level, }, } - logging.config.dictConfig(logging_config) + logging.config.dictConfig(dict_config) logging.getLogger().setLevel(logging_level) diff --git a/src/paperweight/main.py b/src/paperweight/main.py index b951452..7a70c22 100644 --- a/src/paperweight/main.py +++ b/src/paperweight/main.py @@ -26,6 +26,7 @@ write_output, ) from paperweight.processor import process_papers +from paperweight.progress import ProgressReporter from paperweight.scraper import get_recent_papers, hydrate_papers_with_content from paperweight.storage import ( create_run, @@ -45,12 +46,6 @@ - cs.CL max_results: 50 -triage: - enabled: true - llm_provider: openai - min_score: 60 - max_selected: 25 - processor: keywords: - transformer @@ -63,21 +58,29 @@ content_keyword_weight: 1 exclusion_keyword_penalty: 5 important_words_weight: 0.5 - min_score: 10 + min_score: 3 analyzer: type: abstract - llm_provider: openai max_input_tokens: 7000 max_input_chars: 20000 +metadata_cache: + enabled: true + path: .paperweight_cache.json + ttl_hours: 4 + +concurrency: + content_fetch: 6 + triage: 3 + summary: 3 + logging: level: INFO - file: paperweight.log """ -def setup_and_get_papers(force_refresh, include_content=True, config_path="config.yaml"): +def setup_and_get_papers(force_refresh, include_content=True, config_path="config.yaml", profile=None): """Set up the application and 
fetch papers. Args: @@ -88,7 +91,7 @@ def setup_and_get_papers(force_refresh, include_content=True, config_path="confi Tuple of (papers, config) where papers is a list of paper dictionaries and config is the loaded configuration dictionary. """ - config = load_config(config_path=config_path) + config = load_config(config_path=config_path, profile=profile) setup_logging(config["logging"]) logger.info("Configuration loaded successfully") @@ -121,28 +124,45 @@ def get_summary_model(config): return None -def process_and_summarize_papers(recent_papers, config): - """Process and analyze papers based on configured criteria. +def score_papers(papers, config): + """Score papers based on configured criteria (title + abstract keywords). Args: - recent_papers: List of paper dictionaries to process. + papers: List of paper dictionaries to score. config: Configuration dictionary containing processing parameters. Returns: - List of processed papers with relevance scores and summaries. + List of scored papers above the min_score threshold, or None if empty. """ - if not recent_papers: + if not papers: logger.info("No new papers to process. Exiting.") return None - processed_papers = process_papers(recent_papers, config["processor"]) - logger.info(f"Processed {len(processed_papers)} papers") + processed_papers = process_papers(papers, config["processor"]) + logger.info(f"Scored {len(processed_papers)} papers above threshold") if not processed_papers: logger.info("No papers met the relevance criteria. Exiting.") return None - summaries = get_abstracts(processed_papers, config["analyzer"]) + return processed_papers + + +def summarize_scored_papers(processed_papers, config): + """Attach summaries to already-scored papers. + + Args: + processed_papers: List of scored paper dictionaries. + config: Configuration dictionary. + + Returns: + The same list with ``summary`` field attached, or None if input is empty. 
+ """ + if not processed_papers: + return None + + summary_concurrency = config.get("concurrency", {}).get("summary") + summaries = get_abstracts(processed_papers, config["analyzer"], summary_concurrency=summary_concurrency) for paper, summary in zip(processed_papers, summaries): paper["summary"] = ( summary if summary else paper.get("abstract", "No summary available") @@ -151,6 +171,25 @@ def process_and_summarize_papers(recent_papers, config): return processed_papers +def process_and_summarize_papers(recent_papers, config): + """Process and analyze papers based on configured criteria. + + Convenience wrapper that scores then summarizes in one call. + Retained for backward compatibility. + + Args: + recent_papers: List of paper dictionaries to process. + config: Configuration dictionary containing processing parameters. + + Returns: + List of processed papers with relevance scores and summaries. + """ + scored = score_papers(recent_papers, config) + if scored is None: + return None + return summarize_scored_papers(scored, config) + + def _initialize_db_run(config, recent_papers): """Initialize a database run and persist paper metadata. 
@@ -246,9 +285,6 @@ def _handle_error(error, error_type): def _deliver_output(processed_papers, config, args): """Deliver processed papers via the requested adapter.""" - if args.max_items and args.max_items > 0: - processed_papers = processed_papers[: args.max_items] - if args.delivery == "stdout": digest = render_text_digest(processed_papers, sort_order=args.sort_order) write_output(digest, args.output) @@ -324,7 +360,18 @@ def _add_run_arguments(parser: argparse.ArgumentParser) -> None: "--max-items", type=int, default=0, - help="Optional cap on number of delivered papers (0 = no cap)", + help="Optional cap on papers to process and deliver (0 = no cap)", + ) + parser.add_argument( + "--profile", + type=str, + default=None, + help="Activate a named profile from the config's profiles section", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Suppress progress status lines on stderr", ) @@ -332,6 +379,11 @@ def _build_cli_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="paperweight: Fetch, triage, and summarize arXiv papers" ) + parser.add_argument( + "--version", + action="version", + version=f"paperweight {get_package_version()}", + ) subparsers = parser.add_subparsers(dest="command") run_parser = subparsers.add_parser("run", help="Run the paperweight pipeline") @@ -360,6 +412,12 @@ def _build_cli_parser() -> argparse.ArgumentParser: action="store_true", help="Return non-zero if any warnings are present", ) + doctor_parser.add_argument( + "--profile", + type=str, + default=None, + help="Activate a named profile from the config's profiles section", + ) return parser @@ -370,11 +428,14 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: # Backward-compatible default: `paperweight [run-args]` == `paperweight run [run-args]` known_commands = {"run", "init", "doctor"} - if args_list and args_list[0] in {"-h", "--help"}: + if args_list and args_list[0] in {"-h", "--help", "--version"}: return 
parser.parse_args(args_list) if args_list and args_list[0] in known_commands: return parser.parse_args(args_list) + # TODO: simplify — the fallback parser duplicates _build_cli_parser's run + # arguments. Consider using parser.parse_known_args() or inserting "run" + # into args_list when no known subcommand is found. run_parser = argparse.ArgumentParser( description="paperweight: Fetch, triage, and summarize arXiv papers" ) @@ -399,7 +460,7 @@ def _write_minimal_config(path: str, force: bool = False) -> None: print(f"Wrote config: {target}") -def _doctor(config_path: str, strict: bool = False) -> int: +def _doctor(config_path: str, strict: bool = False, profile: str | None = None) -> int: results: list[tuple[str, str, str]] = [] config_file = Path(config_path) @@ -411,15 +472,19 @@ def _doctor(config_path: str, strict: bool = False) -> int: return 1 try: - config = load_config(config_path=config_path) + config = load_config(config_path=config_path, profile=profile) results.append(("OK", "config parse", "Loaded and validated")) except Exception as e: results.append(("FAIL", "config parse", str(e))) _print_doctor(results) return 1 + active_profile = config.get("active_profile") + if active_profile: + results.append(("OK", "profile", active_profile)) + triage_cfg = config.get("triage", {}) - triage_enabled = triage_cfg.get("enabled", True) + triage_enabled = triage_cfg.get("enabled", False) triage_provider = ( triage_cfg.get("llm_provider") or config.get("analyzer", {}).get("llm_provider") @@ -458,7 +523,7 @@ def _print_doctor(results: list[tuple[str, str, str]]) -> None: print(f"[{status}] {check}: {detail}") -def _run_pipeline(args: argparse.Namespace) -> int: +def _run_pipeline(args: argparse.Namespace) -> int: # noqa: C901 config = None run_id = None paper_id_map = {} @@ -467,19 +532,73 @@ def _run_pipeline(args: argparse.Namespace) -> int: db_enabled = False had_error = False + progress = ProgressReporter(quiet=getattr(args, "quiet", False)) + try: + # 1. 
Metadata (cached by default) + progress.phase("fetching metadata...") recent_papers, config = setup_and_get_papers( args.force_refresh, include_content=False, config_path=args.config, + profile=getattr(args, "profile", None), + ) + if args.max_items and args.max_items > 0 and len(recent_papers) > args.max_items: + logger.info( + "Applying max-items compute cap: processing first %s of %s fetched papers", + args.max_items, + len(recent_papers), + ) + recent_papers = recent_papers[: args.max_items] + + categories = config.get("arxiv", {}).get("categories", []) + progress.phase_end( + "fetching metadata...", + f"{len(recent_papers)} papers ({len(categories)} categories)", ) - shortlisted_papers = _apply_triage_and_hydrate(recent_papers, config) + + # 2. Triage (title + abstract — no content needed) + progress.phase("triaging...") + triaged_papers = triage_papers(recent_papers, config) + if not triaged_papers: + logger.info("AI triage selected no papers. Exiting.") + triaged_papers = [] + progress.phase_end("triaging...", f"{len(triaged_papers)}/{len(recent_papers)} selected") + + # 3. Score (title + abstract keywords — no content needed) + progress.phase("scoring...") + scored_papers = score_papers(triaged_papers, config) + if not scored_papers and triaged_papers: + threshold = config.get("processor", {}).get("min_score", 0) + progress.phase_end( + "scoring...", + f"0/{len(triaged_papers)} above min_score ({threshold}) — " + "try adding keywords or lowering processor.min_score", + ) + else: + progress.phase_end( + "scoring...", + f"{len(scored_papers)} papers above threshold" if scored_papers else "0 papers above threshold", + ) + + # 4. 
Hydrate ONLY if analyzer needs full content (summary mode) + if scored_papers and config.get("analyzer", {}).get("type") == "summary": + progress.phase("fetching content...") + scored_papers = hydrate_papers_with_content(scored_papers, config) + progress.phase_end("fetching content...", f"{len(scored_papers)} hydrated") + db_enabled = is_db_enabled(config) - if db_enabled: - run_id, paper_id_map = _initialize_db_run(config, shortlisted_papers) + if db_enabled and scored_papers: + run_id, paper_id_map = _initialize_db_run(config, scored_papers) - processed_papers = process_and_summarize_papers(shortlisted_papers, config) + # 5. Summarize (abstract passthrough or LLM) + if scored_papers and config.get("analyzer", {}).get("type") == "summary": + progress.phase("summarizing...") + processed_papers = summarize_scored_papers(scored_papers, config) + if scored_papers and config.get("analyzer", {}).get("type") == "summary": + count = len(processed_papers) if processed_papers else 0 + progress.phase_end("summarizing...", f"{count}/{len(scored_papers)} done") if db_enabled and run_id and processed_papers: _persist_results(config, run_id, processed_papers, paper_id_map) @@ -487,6 +606,9 @@ def _run_pipeline(args: argparse.Namespace) -> int: if processed_papers: _deliver_output(processed_papers, config, args) + delivered = len(processed_papers) if processed_papers else 0 + progress.phase_end("done —", f"{delivered} papers delivered to stdout") + run_status = "success" except ( requests.RequestException, @@ -519,15 +641,24 @@ def main(argv: list[str] | None = None) -> int: args.output = getattr(args, "output", None) args.sort_order = getattr(args, "sort_order", "relevance") args.max_items = getattr(args, "max_items", 0) + args.profile = getattr(args, "profile", None) + args.quiet = getattr(args, "quiet", False) if args.command == "init": - _write_minimal_config(args.config, force=args.force) - return 0 + try: + _write_minimal_config(args.config, force=args.force) + return 0 + 
except ValueError as exc: + print(f"paperweight init: {exc}", file=sys.stderr) + return 1 if args.command == "doctor": - return _doctor(args.config, strict=getattr(args, "strict", False)) + return _doctor(args.config, strict=getattr(args, "strict", False), profile=getattr(args, "profile", None)) return _run_pipeline(args) if __name__ == "__main__": + # TODO: the broad except here is redundant with error handling inside + # main() / _run_pipeline(). Consider removing once all CLI paths + # return clean exit codes on error. try: sys.exit(main()) except Exception as e: diff --git a/src/paperweight/notifier.py b/src/paperweight/notifier.py index 1d9f962..fc62f32 100644 --- a/src/paperweight/notifier.py +++ b/src/paperweight/notifier.py @@ -50,12 +50,30 @@ def render_text_digest( for idx, paper in enumerate(ordered, start=1): score = paper.get("relevance_score", paper.get("triage_score", 0.0)) lines.append(f"{idx}. {paper.get('title', 'Untitled')}") + + authors = paper.get("authors", []) + if authors: + display = ", ".join(authors[:3]) + if len(authors) > 3: + display += f" +{len(authors) - 3} more" + lines.append(f" Authors: {display}") + lines.append(f" Date: {_format_paper_date(paper)}") lines.append(f" Score: {score:.2f}") + + matched = paper.get("keywords_matched", []) + if matched: + lines.append(f" Matched: {', '.join(matched)}") + if paper.get("triage_rationale"): lines.append(f" Why: {paper.get('triage_rationale')}") lines.append(f" Link: {paper.get('link', '')}") - lines.append(f" Summary: {(paper.get('summary') or '').strip()}") + + summary = (paper.get("summary") or "").strip() + if not summary: + summary = (paper.get("abstract") or "").strip() + if summary: + lines.append(f" Summary: {summary}") lines.append("") return "\n".join(lines).rstrip() + "\n" @@ -97,6 +115,14 @@ def render_atom_feed( ET.SubElement(entry, f"{{{ns}}}updated").text = updated if link: ET.SubElement(entry, f"{{{ns}}}link", {"href": link, "rel": "alternate"}) + + for author_name in 
paper.get("authors", []): + author_el = ET.SubElement(entry, f"{{{ns}}}author") + ET.SubElement(author_el, f"{{{ns}}}name").text = author_name + + for cat in paper.get("categories", []): + ET.SubElement(entry, f"{{{ns}}}category", {"term": cat}) + ET.SubElement(entry, f"{{{ns}}}summary").text = summary ET.SubElement(entry, f"{{{ns}}}content", {"type": "text"}).text = ( f"Score: {score:.2f}\nWhy: {rationale}\nLink: {link}\nSummary: {summary}" @@ -113,16 +139,24 @@ def render_json_digest( ordered = _sort_papers(papers, sort_order) payload = [] for paper in ordered: - payload.append( - { - "title": paper.get("title", "Untitled"), - "date": _format_paper_date(paper), - "score": paper.get("relevance_score", paper.get("triage_score", 0.0)), - "why": paper.get("triage_rationale", ""), - "link": paper.get("link", ""), - "summary": (paper.get("summary") or "").strip(), - } - ) + record = { + "title": paper.get("title", "Untitled"), + "arxiv_id": paper.get("id", ""), + "authors": paper.get("authors", []), + "categories": paper.get("categories", []), + "published": _format_paper_date(paper), + "abstract": paper.get("abstract", ""), + "link": paper.get("link", ""), + "pdf_url": paper.get("pdf_url", ""), + "score": paper.get("relevance_score", paper.get("triage_score", 0.0)), + "keywords_matched": paper.get("keywords_matched", []), + } + if "triage_score" in paper: + record["triage_score"] = paper["triage_score"] + record["triage_rationale"] = paper.get("triage_rationale", "") + if paper.get("summary") and paper.get("summary") != paper.get("abstract"): + record["summary"] = paper["summary"] + payload.append(record) return json.dumps(payload, indent=2, ensure_ascii=True) diff --git a/src/paperweight/processor.py b/src/paperweight/processor.py index d8c2aaf..d553a8f 100644 --- a/src/paperweight/processor.py +++ b/src/paperweight/processor.py @@ -33,6 +33,7 @@ def process_papers( if score >= processor_config["min_score"]: paper["relevance_score"] = score paper["score_breakdown"] = 
score_breakdown + paper["keywords_matched"] = score_breakdown.get("keywords_matched", []) processed_papers.append(paper) else: logger.debug( @@ -90,20 +91,20 @@ def calculate_paper_score(paper, config): abstract = paper.get("abstract", "") content = paper.get("content", "") - title_keywords = count_keywords(title, config["keywords"]) - abstract_keywords = count_keywords(abstract, config["keywords"]) - content_keywords = count_keywords(content, config["keywords"]) + title_kw_score, title_matched = count_keywords(title, config["keywords"]) + abstract_kw_score, abstract_matched = count_keywords(abstract, config["keywords"]) + content_kw_score, _ = count_keywords(content, config["keywords"]) max_title_score = 50 max_abstract_score = 50 max_content_score = 25 - title_score = min(title_keywords * config["title_keyword_weight"], max_title_score) + title_score = min(title_kw_score * config["title_keyword_weight"], max_title_score) abstract_score = min( - abstract_keywords * config["abstract_keyword_weight"], max_abstract_score + abstract_kw_score * config["abstract_keyword_weight"], max_abstract_score ) content_score = min( - content_keywords * config["content_keyword_weight"], max_content_score + content_kw_score * config["content_keyword_weight"], max_content_score ) score += title_score + abstract_score + content_score @@ -112,11 +113,12 @@ def calculate_paper_score(paper, config): "abstract": round(abstract_score, 2), "content": round(content_score, 2), } + score_breakdown["keywords_matched"] = sorted(set(title_matched + abstract_matched)) # Exclusion list - exclusion_count = count_keywords(content, config["exclusion_keywords"]) + exclusion_score_raw, _ = count_keywords(content, config["exclusion_keywords"]) exclusion_score = min( - exclusion_count * config["exclusion_keyword_penalty"], max_content_score + exclusion_score_raw * config["exclusion_keyword_penalty"], max_content_score ) score -= exclusion_score score_breakdown["exclusion_penalty"] = 
-round(exclusion_score, 2) @@ -140,11 +142,18 @@ def count_keywords(text, keywords): keywords: List of keywords to count. Returns: - Dictionary mapping keywords to their occurrence counts. + Tuple of (score, matched_list) where score is a float and matched_list + contains the keywords that were found. """ - return sum( - math.log(text.lower().count(keyword.lower()) + 1) for keyword in keywords - ) + text_lower = text.lower() + matched = [] + score = 0.0 + for keyword in keywords: + count = text_lower.count(keyword.lower()) + if count > 0: + matched.append(keyword) + score += math.log(count + 1) + return score, matched def count_important_words(text, important_words): diff --git a/src/paperweight/progress.py b/src/paperweight/progress.py new file mode 100644 index 0000000..84a0a06 --- /dev/null +++ b/src/paperweight/progress.py @@ -0,0 +1,35 @@ +"""Lightweight progress reporting to stderr for pipeline phases.""" + +import sys +import time + + +class ProgressReporter: + """Write ``paperweight: phase... detail`` lines to stderr. + + Suppressed entirely when *quiet* is ``True``. 
+ """ + + def __init__(self, *, quiet: bool = False): + self._quiet = quiet + self._phase_start: float | None = None + + def phase(self, label: str, detail: str = "") -> None: + """Begin a new phase, printing its status line immediately.""" + self._phase_start = time.monotonic() + self._emit(label, detail) + + def phase_end(self, label: str, detail: str = "") -> None: + """End the current phase, appending elapsed time.""" + elapsed = "" + if self._phase_start is not None: + secs = time.monotonic() - self._phase_start + elapsed = f" ({secs:.0f}s)" + self._emit(label, f"{detail}{elapsed}") + self._phase_start = None + + def _emit(self, label: str, detail: str) -> None: + if self._quiet: + return + suffix = f" {detail}" if detail else "" + print(f"paperweight: {label}{suffix}", file=sys.stderr, flush=True) diff --git a/src/paperweight/scraper.py b/src/paperweight/scraper.py index 0152e8f..3e168f1 100644 --- a/src/paperweight/scraper.py +++ b/src/paperweight/scraper.py @@ -8,16 +8,16 @@ import gzip import hashlib import io +import json import logging import os import tarfile -import time +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import date, datetime, timedelta from typing import Any, Dict, List, Optional import arxiv import requests -from pypdf import PdfReader from tenacity import ( retry, retry_if_exception_type, @@ -87,12 +87,17 @@ def fetch_arxiv_papers( ) break + arxiv_id, _ = split_arxiv_id(result.entry_id) papers.append( { "title": result.title, "link": result.entry_id, "date": submitted_date, "abstract": result.summary, + "authors": [a.name for a in result.authors], + "categories": list(result.categories), + "pdf_url": result.pdf_url, + "id": arxiv_id, } ) @@ -129,34 +134,46 @@ def fetch_recent_papers(config, start_days=1): logger.info(f"Fetching papers from {start_date} to {end_date}") - all_papers = [] - processed_ids = set() - - for category in categories: + def _fetch_category(category): logger.info(f"Processing 
category: {category}") try: - papers = fetch_arxiv_papers( + return category, fetch_arxiv_papers( category, start_date, max_results=max_results if max_results > 0 else None, ) - new_papers = [ - paper - for paper in papers - if paper["link"].split("/abs/")[-1] not in processed_ids - ] - processed_ids.update( - paper["link"].split("/abs/")[-1] for paper in new_papers - ) - - if max_results > 0: - new_papers = new_papers[:max_results] - - all_papers.extend(new_papers) - logger.debug(f"Added {len(new_papers)} new papers from category {category}") except ValueError as ve: logger.error(f"Error fetching papers for category {category}: {ve}") - continue + return category, [] + + all_papers = [] + processed_ids: set = set() + + workers = min(len(categories), 4) if categories else 1 + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = {executor.submit(_fetch_category, cat): cat for cat in categories} + # Collect in submission order for deterministic results + results_by_cat = {} + for future in as_completed(futures): + cat, papers = future.result() + results_by_cat[cat] = papers + + for category in categories: + papers = results_by_cat.get(category, []) + new_papers = [ + paper + for paper in papers + if paper["link"].split("/abs/")[-1] not in processed_ids + ] + processed_ids.update( + paper["link"].split("/abs/")[-1] for paper in new_papers + ) + + if max_results > 0: + new_papers = new_papers[:max_results] + + all_papers.extend(new_papers) + logger.debug(f"Added {len(new_papers)} new papers from category {category}") logger.info(f"Fetched a total of {len(all_papers)} papers") return all_papers @@ -218,6 +235,8 @@ def extract_text_from_pdf(pdf_content): Returns: Extracted text as a string. 
""" + from pypdf import PdfReader + pdf_file = io.BytesIO(pdf_content) pdf_reader = PdfReader(pdf_file) text = "" @@ -271,37 +290,42 @@ def extract_text_from_source(content, method): return decompressed.decode("utf-8", errors="ignore") -def fetch_paper_contents(paper_ids): +def fetch_paper_contents(paper_ids, max_workers=6): """Fetch contents for multiple papers in parallel. Args: paper_ids: List of arXiv paper IDs to fetch. + max_workers: Maximum number of concurrent download threads. Returns: - Dictionary mapping paper IDs to their content. + List of (paper_id, content, method) tuples, in the same order as *paper_ids*. """ - contents = [] total_papers = len(paper_ids) - logger.info(f"Fetching content for {total_papers} papers") - for i, paper_id in enumerate(paper_ids): + logger.info(f"Fetching content for {total_papers} papers (workers={max_workers})") + + results: List[Any] = [None] * total_papers + index_by_id = {pid: i for i, pid in enumerate(paper_ids)} + + def _fetch(paper_id): try: content, method = fetch_paper_content(paper_id) - contents.append((paper_id, content, method)) + return paper_id, content, method except Exception as e: logger.error(f"Error fetching content for paper ID {paper_id}: {e}") - contents.append((paper_id, None, None)) - - if (i + 1) % 4 == 0: - time.sleep(1) - logger.debug( - f"Processed {i + 1}/{total_papers} papers. Waiting 1 second..." 
- ) - - if (i + 1) % 20 == 0: - logger.info(f"Processed {i + 1}/{total_papers} papers") + return paper_id, None, None + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_fetch, pid): pid for pid in paper_ids} + completed = 0 + for future in as_completed(futures): + paper_id, content, method = future.result() + results[index_by_id[paper_id]] = (paper_id, content, method) + completed += 1 + if completed % 20 == 0: + logger.info(f"Fetched {completed}/{total_papers} papers") logger.info(f"Finished fetching content for all {total_papers} papers") - return contents + return results def _hydrate_papers_with_content(papers, config, db_enabled): @@ -309,8 +333,9 @@ def _hydrate_papers_with_content(papers, config, db_enabled): if not papers: return [] + max_workers = config.get("concurrency", {}).get("content_fetch", 6) paper_ids = [paper["link"].split("/abs/")[-1] for paper in papers] - contents = fetch_paper_contents(paper_ids) + contents = fetch_paper_contents(paper_ids, max_workers=max_workers) papers_with_content = [] storage_base = config.get("storage", {}).get("base_dir", "data/artifacts") @@ -344,7 +369,85 @@ def hydrate_papers_with_content(papers, config): return _hydrate_papers_with_content(papers, config, db_enabled) -def get_recent_papers(config, force_refresh=False, include_content=True): +def _int_setting(value, default, *, minimum=0): + try: + parsed = int(value) + except (TypeError, ValueError): + parsed = default + return max(minimum, parsed) + + +def _metadata_cache_options(config): + """Return (enabled, path, ttl_hours) from config['metadata_cache'].""" + mc = config.get("metadata_cache", {}) + enabled = mc.get("enabled", True) + path = mc.get("path", ".paperweight_cache.json") + ttl_hours = _int_setting(mc.get("ttl_hours"), 4, minimum=0) + return enabled, path, ttl_hours + + +def _metadata_cache_key(config): + """Build a stable key from the parameters that affect which papers are fetched.""" + cats = 
sorted(config.get("arxiv", {}).get("categories", [])) + max_r = config.get("arxiv", {}).get("max_results", 0) + today = datetime.now().date().isoformat() + raw = f"{cats}|{max_r}|{today}" + return hashlib.sha256(raw.encode()).hexdigest()[:16] + + +def _serialize_metadata_papers(papers): + """Convert paper list to JSON-safe form (dates become ISO strings).""" + out = [] + for p in papers: + rec = dict(p) + if isinstance(rec.get("date"), date): + rec["date"] = rec["date"].isoformat() + out.append(rec) + return out + + +def _deserialize_metadata_papers(records): + """Restore paper list from JSON-safe form.""" + out = [] + for rec in records: + rec = dict(rec) + if isinstance(rec.get("date"), str): + rec["date"] = datetime.strptime(rec["date"], "%Y-%m-%d").date() + out.append(rec) + return out + + +def _load_metadata_cache(cache_path, expected_key, ttl_hours): + """Return cached papers or None if cache is missing/stale/corrupt.""" + try: + with open(cache_path, "r", encoding="utf-8") as f: + data = json.load(f) + if data.get("key") != expected_key: + return None + written = datetime.fromisoformat(data["written_at"]) + if (datetime.now() - written).total_seconds() > ttl_hours * 3600: + return None + return _deserialize_metadata_papers(data["papers"]) + except (OSError, json.JSONDecodeError, KeyError, ValueError): + return None + + +def _write_metadata_cache(cache_path, key, papers): + """Write paper metadata to the cache file.""" + payload = { + "key": key, + "written_at": datetime.now().isoformat(), + "papers": _serialize_metadata_papers(papers), + } + try: + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(payload, f) + logger.debug("Wrote metadata cache to %s (%d papers)", cache_path, len(papers)) + except OSError as e: + logger.warning("Could not write metadata cache: %s", e) + + +def get_recent_papers(config, force_refresh=False, include_content=True): # noqa: C901 """Get recent papers, either from cache or by fetching new ones. 
Args: @@ -370,24 +473,37 @@ def get_recent_papers(config, force_refresh=False, include_content=True): current_date = datetime.now().date() logger.info(f"Current date: {current_date}") - if last_processed_date is None or force_refresh: - # If never run before, fetch papers from the last 7 days - days = 7 - logger.info("First run detected. Fetching papers from the last 7 days.") - else: - days = (current_date - last_processed_date).days - if days == 0: - logger.info("Already processed papers for today. No new papers to fetch.") - return [] - elif days > 7: - # If more than a week has passed, limit to 7 days to avoid overload + # Metadata cache: check before computing days so same-day runs can hit cache + cache_enabled, cache_path, cache_ttl = _metadata_cache_options(config) + cache_key = _metadata_cache_key(config) + recent_papers = None + if cache_enabled and not force_refresh: + recent_papers = _load_metadata_cache(cache_path, cache_key, cache_ttl) + if recent_papers is not None: + logger.info("Loaded %d papers from metadata cache", len(recent_papers)) + + if recent_papers is None: + if last_processed_date is None or force_refresh: + # First run or forced refresh: fetch papers from the last 7 days days = 7 - logger.warning( - f"More than a week since last run. Limiting fetch to last {days} days." - ) + logger.info("First run detected. Fetching papers from the last 7 days.") + else: + days = (current_date - last_processed_date).days + if days == 0: + logger.info("Already processed papers for today. No new papers to fetch.") + return [] + elif days > 7: + # If more than a week has passed, limit to 7 days to avoid overload + days = 7 + logger.warning( + f"More than a week since last run. Limiting fetch to last {days} days." 
+ ) + + logger.info(f"Fetching papers for the last {days} days") + recent_papers = fetch_recent_papers(config, days) + if cache_enabled: + _write_metadata_cache(cache_path, cache_key, recent_papers) - logger.info(f"Fetching papers for the last {days} days") - recent_papers = fetch_recent_papers(config, days) logger.info(f"Fetched {len(recent_papers)} recent papers") papers_result = recent_papers @@ -396,11 +512,11 @@ def get_recent_papers(config, force_refresh=False, include_content=True): else: papers_result = [] for paper in recent_papers: - paper_id = paper["link"].split("/abs/")[-1] paper_without_content = dict(paper) + if "id" not in paper_without_content: + paper_without_content["id"] = paper["link"].split("/abs/")[-1] paper_without_content.update( { - "id": paper_id, "content": "", "content_type": None, "artifacts": [], diff --git a/src/paperweight/utils.py b/src/paperweight/utils.py index b1e5e3c..bc02dee 100644 --- a/src/paperweight/utils.py +++ b/src/paperweight/utils.py @@ -11,17 +11,37 @@ import logging import os import re +from copy import deepcopy from datetime import datetime from importlib.metadata import PackageNotFoundError from importlib.metadata import version as pkg_version -import tiktoken import yaml from dotenv import load_dotenv LAST_PROCESSED_DATE_FILE = "last_processed_date.txt" DEFAULT_ARXIV_VERSION = "v0" +DEFAULT_CONFIG = { + "arxiv": {"categories": [], "max_results": 50}, + "processor": { + "keywords": [], + "exclusion_keywords": [], + "important_words": [], + "title_keyword_weight": 3, + "abstract_keyword_weight": 2, + "content_keyword_weight": 1, + "exclusion_keyword_penalty": 5, + "important_words_weight": 0.5, + "min_score": 3, + }, + "analyzer": {"type": "abstract", "max_input_tokens": 7000, "max_input_chars": 20000}, + "triage": {"enabled": False}, + "logging": {"level": "INFO"}, + "metadata_cache": {"enabled": True, "path": ".paperweight_cache.json", "ttl_hours": 4}, + "concurrency": {"content_fetch": 6, "triage": 3, "summary": 
3}, +} + logger = logging.getLogger(__name__) @@ -82,7 +102,29 @@ def _coerce(env_value: str, current_value): return config -def load_config(config_path="config.yaml"): +def _deep_merge_dicts(base, override): + """Recursively merge *override* into a deep copy of *base*.""" + merged = deepcopy(base) + for key, value in override.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = _deep_merge_dicts(merged[key], value) + else: + merged[key] = deepcopy(value) + return merged + + +def apply_profile(config, profile_name): + """Apply a named profile on top of *config* and return the merged result.""" + profiles = config.get("profiles", {}) + if profile_name not in profiles: + raise ValueError(f"Unknown profile: '{profile_name}'") + overlay = profiles[profile_name] + merged = _deep_merge_dicts(config, overlay) + merged["active_profile"] = profile_name + return merged + + +def load_config(config_path="config.yaml", profile=None): # noqa: C901 """Load and validate the application configuration. 
Args: @@ -100,11 +142,20 @@ def load_config(config_path="config.yaml"): load_dotenv() with open(config_path, "r") as config_file: - config = yaml.safe_load(config_file) - if config is None: + raw_config = yaml.safe_load(config_file) + if raw_config is None: raise ValueError("Empty configuration file") + # Merge user YAML over DEFAULT_CONFIG so every key has a safe default + config = _deep_merge_dicts(DEFAULT_CONFIG, raw_config) + config = expand_env_vars(config) + + # Profile switching: CLI flag > env var > none + profile_name = profile or os.environ.get("PAPERWEIGHT_PROFILE") + if profile_name: + config = apply_profile(config, profile_name) + config = override_with_env(config) # Handle API keys @@ -120,8 +171,6 @@ def load_config(config_path="config.yaml"): config["analyzer"]["api_key"] = api_key else: raise ValueError(f"Missing API key for {llm_provider}") - else: - pass if "arxiv" in config and "max_results" in config["arxiv"]: config["arxiv"]["max_results"] = int(config["arxiv"]["max_results"]) @@ -171,6 +220,12 @@ def check_config(config): _check_db_section(config["db"]) if "storage" in config: _check_storage_section(config["storage"]) + if "metadata_cache" in config: + _check_metadata_cache_section(config["metadata_cache"]) + if "concurrency" in config: + _check_concurrency_section(config["concurrency"]) + if "profiles" in config: + _check_profiles_section(config["profiles"]) except KeyError as e: raise ValueError(f"Missing required section or key: {e}") @@ -320,6 +375,47 @@ def _check_storage_section(storage): raise ValueError("Missing required storage field: 'base_dir'") +def _check_metadata_cache_section(mc): + """Validate the metadata_cache section of the configuration.""" + if not isinstance(mc, dict): + raise ValueError("'metadata_cache' must be a mapping") + if "ttl_hours" in mc: + try: + val = int(mc["ttl_hours"]) + except (TypeError, ValueError): + raise ValueError("'ttl_hours' in 'metadata_cache' must be a valid integer") + if val < 0: + raise 
ValueError("'ttl_hours' in 'metadata_cache' must be non-negative") + + +def _check_concurrency_section(concurrency): + """Validate the concurrency section of the configuration.""" + if not isinstance(concurrency, dict): + raise ValueError("'concurrency' must be a mapping") + limits = { + "content_fetch": (1, 20), + "triage": (1, 10), + "summary": (1, 10), + } + for key, (lo, hi) in limits.items(): + if key in concurrency: + try: + val = int(concurrency[key]) + except (TypeError, ValueError): + raise ValueError(f"'{key}' in 'concurrency' must be a valid integer") + if val < lo or val > hi: + raise ValueError(f"'{key}' in 'concurrency' must be between {lo} and {hi}") + + +def _check_profiles_section(profiles): + """Validate the profiles section of the configuration.""" + if not isinstance(profiles, dict): + raise ValueError("'profiles' must be a mapping") + for name, overlay in profiles.items(): + if not isinstance(overlay, dict): + raise ValueError(f"Profile '{name}' must be a mapping") + + def is_valid_arxiv_category(category): """Check if an arXiv category string is valid. @@ -373,6 +469,8 @@ def count_tokens(text): Returns: int: Number of tokens in the text. """ + import tiktoken + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") return len(encoding.encode(text, allowed_special={"<|endoftext|>"})) diff --git a/tests/conftest.py b/tests/conftest.py index cf737a7..72cdd52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,9 +102,11 @@ def base_test_config(tmp_path: Path) -> Dict[str, Any]: "use_auth": False, } }, + "triage": { + "enabled": False, + }, "logging": { "level": "DEBUG", - "file": str(tmp_path / "test_paperweight.log"), }, "db": { "enabled": False, diff --git a/tests/test_analyzer.py b/tests/test_analyzer.py index 6e37b58..8f6dce5 100644 --- a/tests/test_analyzer.py +++ b/tests/test_analyzer.py @@ -1,9 +1,11 @@ """Tests for the paper analyzer/summarization module. 
This file tests the LLM boundary: how paperweight interacts with -external LLM providers to generate summaries, including fallback behavior. +external LLM providers to generate summaries. """ +from unittest.mock import AsyncMock + import pytest from paperweight.analyzer import get_abstracts, summarize_paper, triage_papers @@ -12,21 +14,41 @@ class TestSummarizePaper: """Tests for paper summarization with LLM providers.""" + def test_summarize_success(self, mocker): + """Summarization returns model output with valid provider/key.""" + # Mock Pollux's async run() function + mock_result = {"answers": ["This is a summary of the paper."], "status": "ok"} + mocker.patch("pollux.run", new=AsyncMock(return_value=mock_result)) + + paper = { + "title": "Test Paper", + "abstract": "This is the abstract.", + "content": "This is the full content of the paper.", + } + config = { + "type": "summary", + "llm_provider": "openai", + "api_key": "fake_api_key", + } + + result = summarize_paper(paper, config) + assert result == "This is a summary of the paper." 
+ @pytest.mark.parametrize( - "llm_provider, api_key, expected_result", + "llm_provider, api_key", [ - ("openai", "fake_api_key", "This is a summary of the paper."), - ("openai", None, "This is the abstract."), - ("invalid_provider", "fake_api_key", "This is the abstract."), + ("openai", None), + ("invalid_provider", "fake_api_key"), ], ) - def test_summarize_with_fallback( - self, llm_provider, api_key, expected_result, mocker + def test_summarize_requires_valid_provider_and_key( + self, llm_provider, api_key, mocker, monkeypatch ): - """Summarization falls back to abstract when LLM unavailable.""" - # Mock Pollux's async run() function + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("GEMINI_API_KEY", raising=False) + mock_result = {"answers": ["This is a summary of the paper."], "status": "ok"} - mocker.patch("paperweight.analyzer.run", return_value=mock_result) + mocker.patch("pollux.run", new=AsyncMock(return_value=mock_result)) paper = { "title": "Test Paper", @@ -39,8 +61,25 @@ def test_summarize_with_fallback( "api_key": api_key, } + with pytest.raises(ValueError, match="Summary analyzer requires"): + summarize_paper(paper, config) + + def test_summarize_falls_back_to_abstract_when_model_returns_no_answers(self, mocker): + mocker.patch("pollux.run", new=AsyncMock(return_value={"answers": []})) + + paper = { + "title": "Test Paper", + "abstract": "This is the abstract.", + "content": "This is the full content of the paper.", + } + config = { + "type": "summary", + "llm_provider": "openai", + "api_key": "fake_api_key", + } + result = summarize_paper(paper, config) - assert result == expected_result + assert result == "This is the abstract." 
class TestGetAbstracts: @@ -58,12 +97,14 @@ class TestTriagePapers: def test_triage_uses_llm_decision(self, mocker): mocker.patch( - "paperweight.analyzer.run", - return_value={ - "answers": [ - '{"include": true, "score": 92, "rationale": "Strong profile match"}' - ] - }, + "pollux.run", + new=AsyncMock( + return_value={ + "answers": [ + '{"include": true, "score": 92, "rationale": "Strong profile match"}' + ] + } + ), ) papers = [ { @@ -82,6 +123,35 @@ def test_triage_uses_llm_decision(self, mocker): assert shortlisted[0]["triage_score"] == 92 assert "Strong profile match" in shortlisted[0]["triage_rationale"] + def test_triage_falls_back_for_entire_batch_when_llm_errors(self, mocker): + mocker.patch( + "pollux.run", + new=AsyncMock(side_effect=RuntimeError("provider unavailable")), + ) + papers = [ + { + "title": "Transformers for Agents", + "abstract": "A paper about language agents and planning.", + "link": "http://arxiv.org/abs/2401.12345", + }, + { + "title": "Graph Theory for Chemistry", + "abstract": "A paper about graph properties in molecules.", + "link": "http://arxiv.org/abs/2401.67890", + }, + ] + config = { + "triage": {"enabled": True, "llm_provider": "openai", "api_key": "key"}, + "processor": {"keywords": ["agents", "planning"]}, + "analyzer": {}, + } + shortlisted = triage_papers(papers, config) + assert len(shortlisted) == 1 + assert shortlisted[0]["title"] == "Transformers for Agents" + assert all( + "heuristic fallback" in paper["triage_rationale"].lower() for paper in papers + ) + def test_triage_falls_back_without_api_key(self): papers = [ { @@ -98,3 +168,24 @@ def test_triage_falls_back_without_api_key(self): shortlisted = triage_papers(papers, config) assert len(shortlisted) == 1 assert shortlisted[0]["triage_score"] >= 10 + + def test_triage_invalid_threshold_values_do_not_crash(self): + papers = [ + { + "title": "Transformers for Agents", + "abstract": "A paper about language agents and planning.", + "link": 
"http://arxiv.org/abs/2401.12345", + } + ] + config = { + "triage": { + "enabled": True, + "llm_provider": "openai", + "min_score": "invalid", + "max_selected": "invalid", + }, + "processor": {"keywords": ["agents"]}, + "analyzer": {"type": "abstract"}, + } + shortlisted = triage_papers(papers, config) + assert len(shortlisted) == 1 diff --git a/tests/test_cli.py b/tests/test_cli.py index 06eb29f..ffbf84f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,13 +12,36 @@ def test_init_writes_minimal_config(tmp_path, monkeypatch): assert "arxiv:" in config_path.read_text(encoding="utf-8") -def test_init_does_not_overwrite_without_force(tmp_path, monkeypatch): +def test_init_does_not_overwrite_without_force(tmp_path, monkeypatch, capsys): monkeypatch.chdir(tmp_path) config_path = tmp_path / "config.yaml" config_path.write_text("existing: true\n", encoding="utf-8") - with pytest.raises(ValueError, match="already exists"): - main(["init"]) + exit_code = main(["init"]) + assert exit_code == 1 + stderr = capsys.readouterr().err + assert "already exists" in stderr + assert "paperweight init:" in stderr + + +def test_version_flag(): + with pytest.raises(SystemExit) as exc_info: + main(["--version"]) + assert exc_info.value.code == 0 + + +def test_import_public_api(): + import paperweight + + assert hasattr(paperweight, "__version__") + assert isinstance(paperweight.__version__, str) + assert paperweight.__version__ != "" + assert callable(paperweight.load_config) + assert callable(paperweight.score_papers) + assert callable(paperweight.get_recent_papers) + assert callable(paperweight.setup_and_get_papers) + assert callable(paperweight.process_and_summarize_papers) + assert callable(paperweight.summarize_scored_papers) def test_doctor_reports_missing_config(tmp_path): @@ -33,7 +56,7 @@ def test_doctor_success_with_loaded_config(tmp_path, monkeypatch): monkeypatch.setattr( "paperweight.main.load_config", - lambda config_path: { + lambda config_path, profile=None: { 
"arxiv": {"categories": ["cs.AI"]}, "processor": {"keywords": ["agents"]}, "analyzer": {"type": "abstract"}, @@ -54,7 +77,7 @@ def test_doctor_strict_fails_on_warning(tmp_path, monkeypatch): monkeypatch.setattr( "paperweight.main.load_config", - lambda config_path: { + lambda config_path, profile=None: { "arxiv": {"categories": ["cs.AI"]}, "processor": {"keywords": ["agents"]}, "analyzer": {"type": "abstract", "llm_provider": "openai"}, @@ -65,3 +88,27 @@ def test_doctor_strict_fails_on_warning(tmp_path, monkeypatch): exit_code = main(["doctor", "--config", str(config_path), "--strict"]) assert exit_code == 1 + + +def test_doctor_passes_profile_to_config_loader(tmp_path, monkeypatch): + config_path = tmp_path / "config.yaml" + config_path.write_text("placeholder: true\n", encoding="utf-8") + + observed = {} + + def fake_load(config_path, profile=None): + observed["profile"] = profile + return { + "arxiv": {"categories": ["cs.AI"]}, + "processor": {"keywords": ["agents"]}, + "analyzer": {"type": "abstract"}, + "triage": {"enabled": False}, + "logging": {"level": "INFO"}, + "active_profile": profile, + } + + monkeypatch.setattr("paperweight.main.load_config", fake_load) + + exit_code = main(["doctor", "--config", str(config_path), "--profile", "fast"]) + assert exit_code == 0 + assert observed["profile"] == "fast" diff --git a/tests/test_cli_integration.py b/tests/test_cli_integration.py index 7d21ebb..a6a57c8 100644 --- a/tests/test_cli_integration.py +++ b/tests/test_cli_integration.py @@ -32,7 +32,7 @@ def _write_config(tmp_path, *, triage_enabled=False): "min_score": 0, }, "analyzer": {"type": "abstract"}, - "logging": {"level": "INFO", "file": str(tmp_path / "paperweight.log")}, + "logging": {"level": "INFO"}, } config_path = tmp_path / "config.yaml" config_path.write_text(yaml.safe_dump(config), encoding="utf-8") @@ -46,6 +46,10 @@ def _stub_scraper(monkeypatch): "link": "http://arxiv.org/abs/2401.12345", "date": date(2024, 1, 15), "abstract": "A paper about 
transformer-based agents.", + "authors": ["Alice Smith", "Bob Jones"], + "categories": ["cs.AI", "cs.CL"], + "pdf_url": "https://arxiv.org/pdf/2401.12345", + "id": "2401.12345", } ] @@ -70,12 +74,20 @@ def _stub_scraper_two_papers(monkeypatch): "link": "http://arxiv.org/abs/2401.12345", "date": date(2024, 1, 15), "abstract": "A paper about transformer-based agents.", + "authors": ["Alice Smith", "Bob Jones"], + "categories": ["cs.AI", "cs.CL"], + "pdf_url": "https://arxiv.org/pdf/2401.12345", + "id": "2401.12345", }, { "title": "Reasoning Models", "link": "http://arxiv.org/abs/2401.67890", "date": date(2024, 1, 14), "abstract": "A paper about reasoning models.", + "authors": ["Carol White"], + "categories": ["cs.AI"], + "pdf_url": "https://arxiv.org/pdf/2401.67890", + "id": "2401.67890", }, ] monkeypatch.setattr( @@ -106,6 +118,8 @@ def test_run_stdout_mode_smoke(tmp_path, monkeypatch, capsys): assert "paperweight digest" in out assert "Transformer Agents" in out assert "http://arxiv.org/abs/2401.12345" in out + assert "Authors: Alice Smith, Bob Jones" in out + assert "Matched: " in out def test_run_atom_output_smoke(tmp_path, monkeypatch): @@ -167,3 +181,105 @@ def test_run_json_respects_max_items(tmp_path, monkeypatch): assert exit_code == 0 payload = json_path.read_text(encoding="utf-8") assert payload.count('"title"') == 1 + + +def test_run_max_items_caps_metadata_before_triage(tmp_path, monkeypatch, capsys): + """--max-items caps papers before triage/scoring; abstract mode skips hydration.""" + config_path = _write_config(tmp_path, triage_enabled=False) + _stub_scraper_two_papers(monkeypatch) + + hydrate_called = {"called": False} + + def fake_hydrate(papers, _config): + hydrate_called["called"] = True + return papers + + monkeypatch.setattr("paperweight.main.hydrate_papers_with_content", fake_hydrate) + + exit_code = main( + [ + "run", + "--config", + str(config_path), + "--force-refresh", + "--delivery", + "stdout", + "--max-items", + "1", + ] + ) + + assert 
exit_code == 0 + # Abstract mode should never hydrate + assert not hydrate_called["called"] + # Only 1 paper should be delivered + out = capsys.readouterr().out + assert out.count("Score:") == 1 + + +def test_run_json_includes_rich_fields(tmp_path, monkeypatch): + """JSON output includes arxiv_id, authors, categories, abstract, pdf_url, keywords_matched.""" + config_path = _write_config(tmp_path, triage_enabled=False) + json_path = tmp_path / "digest.json" + _stub_scraper(monkeypatch) + + exit_code = main( + [ + "run", + "--config", + str(config_path), + "--force-refresh", + "--delivery", + "json", + "--output", + str(json_path), + ] + ) + assert exit_code == 0 + + import json + + payload = json.loads(json_path.read_text(encoding="utf-8")) + assert len(payload) >= 1 + record = payload[0] + assert "arxiv_id" in record + assert "authors" in record + assert "categories" in record + assert "abstract" in record + assert "pdf_url" in record + assert "keywords_matched" in record + assert record["authors"] == ["Alice Smith", "Bob Jones"] + assert record["categories"] == ["cs.AI", "cs.CL"] + + +def test_zero_state_hint(tmp_path, monkeypatch, capsys): + """When 0 papers pass scoring, stderr shows a helpful hint.""" + config = { + "arxiv": {"categories": ["cs.AI"], "max_results": 5}, + "triage": {"enabled": False}, + "processor": { + "keywords": ["zzz_nonexistent_keyword_zzz"], + "exclusion_keywords": [], + "important_words": [], + "title_keyword_weight": 3, + "abstract_keyword_weight": 2, + "content_keyword_weight": 1, + "exclusion_keyword_penalty": 5, + "important_words_weight": 0.5, + "min_score": 9999, + }, + "analyzer": {"type": "abstract"}, + "logging": {"level": "INFO"}, + } + config_path = tmp_path / "config.yaml" + + import yaml + + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + _stub_scraper(monkeypatch) + + exit_code = main(["run", "--config", str(config_path), "--force-refresh"]) + + stderr = capsys.readouterr().err + assert "min_score" in 
stderr + assert exit_code == 0 diff --git a/tests/test_config.py b/tests/test_config.py index 8c0d764..29ea3d1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,9 @@ import yaml from paperweight.utils import ( + DEFAULT_CONFIG, _check_arxiv_section, + apply_profile, check_config, expand_env_vars, load_config, @@ -321,3 +323,88 @@ def test_negative_max_results_raises(self): """Negative max_results raises ValueError.""" with pytest.raises(ValueError, match="'max_results' in 'arxiv' section must be a non-negative integer"): _check_arxiv_section({'categories': ['cs.AI'], 'max_results': -1}) + + +# --------------------------------------------------------------------------- +# Profile Tests +# --------------------------------------------------------------------------- + +class TestProfiles: + """Tests for profile switching.""" + + def test_apply_profile_deep_merges(self): + """Profile overlay deep-merges into base config.""" + config = { + "arxiv": {"categories": ["cs.AI"], "max_results": 50}, + "profiles": { + "fast": {"arxiv": {"max_results": 20}}, + }, + } + merged = apply_profile(config, "fast") + assert merged["arxiv"]["max_results"] == 20 + assert merged["arxiv"]["categories"] == ["cs.AI"] + assert merged["active_profile"] == "fast" + + def test_apply_profile_unknown_name_raises(self): + """Unknown profile name raises ValueError.""" + config = {"profiles": {"fast": {"arxiv": {"max_results": 20}}}} + with pytest.raises(ValueError, match="Unknown profile: 'nope'"): + apply_profile(config, "nope") + + def test_load_config_with_profile(self, tmp_path): + """load_config applies profile when given.""" + cfg = { + "arxiv": {"categories": ["cs.AI"], "max_results": 50}, + "processor": {"keywords": ["AI"]}, + "analyzer": {"type": "abstract"}, + "logging": {"level": "INFO"}, + "profiles": { + "fast": {"arxiv": {"max_results": 10}}, + }, + } + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.dump(cfg), encoding="utf-8") + with 
patch.dict(os.environ, {}, clear=False): + result = load_config(config_path=str(config_path), profile="fast") + assert result["arxiv"]["max_results"] == 10 + assert result["active_profile"] == "fast" + + +# --------------------------------------------------------------------------- +# DEFAULT_CONFIG Merge Tests +# --------------------------------------------------------------------------- + +class TestDefaultConfigMerge: + """Tests for DEFAULT_CONFIG merge behavior in load_config.""" + + def test_minimal_config_loads_without_crash(self, tmp_path): + """A config with only arxiv.categories loads successfully via DEFAULT_CONFIG merge.""" + cfg = {"arxiv": {"categories": ["cs.AI"]}} + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.dump(cfg), encoding="utf-8") + with patch.dict(os.environ, {}, clear=False): + result = load_config(config_path=str(config_path)) + assert result["analyzer"]["type"] == "abstract" + assert result["processor"]["min_score"] == 3 + assert result["triage"]["enabled"] is False + assert result["logging"]["level"] == "INFO" + assert "file" not in result["logging"] + + def test_default_config_has_triage_disabled(self): + """DEFAULT_CONFIG has triage.enabled set to False.""" + assert DEFAULT_CONFIG["triage"]["enabled"] is False + + def test_user_config_overrides_defaults(self, tmp_path): + """User config values override DEFAULT_CONFIG.""" + cfg = { + "arxiv": {"categories": ["cs.AI"]}, + "processor": {"min_score": 10, "keywords": ["test"]}, + } + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.dump(cfg), encoding="utf-8") + with patch.dict(os.environ, {}, clear=False): + result = load_config(config_path=str(config_path)) + assert result["processor"]["min_score"] == 10 + assert result["processor"]["keywords"] == ["test"] + # Defaults still fill in missing keys + assert result["processor"]["title_keyword_weight"] == 3 diff --git a/tests/test_db.py b/tests/test_db.py index 2741604..8b4ce68 100644 --- 
a/tests/test_db.py +++ b/tests/test_db.py @@ -10,7 +10,7 @@ class TestConnectDb: """Tests for connect_db context manager.""" - @patch("paperweight.db.psycopg.connect") + @patch("psycopg.connect") def test_connect_db_basic(self, mock_connect): """Test basic database connection.""" mock_conn = MagicMock() @@ -39,7 +39,7 @@ def test_connect_db_basic(self, mock_connect): ) mock_conn.close.assert_called_once() - @patch("paperweight.db.psycopg.connect") + @patch("psycopg.connect") def test_connect_db_autocommit(self, mock_connect): """Test database connection with autocommit enabled.""" mock_conn = MagicMock() @@ -60,7 +60,7 @@ def test_connect_db_autocommit(self, mock_connect): call_kwargs = mock_connect.call_args[1] assert call_kwargs["autocommit"] is True - @patch("paperweight.db.psycopg.connect") + @patch("psycopg.connect") def test_connect_db_default_sslmode(self, mock_connect): """Test that sslmode defaults to 'prefer' when not specified.""" mock_conn = MagicMock() @@ -80,7 +80,7 @@ def test_connect_db_default_sslmode(self, mock_connect): call_kwargs = mock_connect.call_args[1] assert call_kwargs["sslmode"] == "prefer" - @patch("paperweight.db.psycopg.connect") + @patch("psycopg.connect") def test_connect_db_closes_on_exception(self, mock_connect): """Test that connection is closed even when an exception occurs.""" mock_conn = MagicMock() diff --git a/tests/test_notifier.py b/tests/test_notifier.py index 6782250..3a5a163 100644 --- a/tests/test_notifier.py +++ b/tests/test_notifier.py @@ -57,6 +57,8 @@ def test_render_text_digest_deterministic(): "link": "http://arxiv.org/abs/2", "relevance_score": 2.0, "triage_rationale": "Matched transformer + planning", + "authors": ["Alice", "Bob", "Carol", "Dave"], + "keywords_matched": ["transformer"], }, { "title": "A Paper", @@ -65,12 +67,18 @@ def test_render_text_digest_deterministic(): "link": "http://arxiv.org/abs/1", "relevance_score": 1.0, "triage_rationale": "Matched profile keywords", + "authors": ["Eve"], + 
"keywords_matched": ["agent"], }, ] digest = render_text_digest(papers, sort_order="alphabetical") assert "1. A Paper" in digest assert "2. B Paper" in digest assert "Why: Matched profile keywords" in digest + assert "Authors: Alice, Bob, Carol +1 more" in digest + assert "Authors: Eve" in digest + assert "Matched: transformer" in digest + assert "Matched: agent" in digest def test_render_atom_feed_contains_required_elements(): @@ -81,6 +89,8 @@ def test_render_atom_feed_contains_required_elements(): "summary": "Summary text", "link": "http://arxiv.org/abs/2401.12345", "relevance_score": 5.5, + "authors": ["Alice Smith", "Bob Jones"], + "categories": ["cs.AI", "cs.CL"], } ] feed = render_atom_feed(papers) @@ -89,6 +99,10 @@ def test_render_atom_feed_contains_required_elements(): assert "" in feed assert "Test Paper" in feed assert "http://arxiv.org/abs/2401.12345" in feed + assert "Alice Smith" in feed + assert "Bob Jones" in feed + assert 'term="cs.AI"' in feed + assert 'term="cs.CL"' in feed def test_write_output_to_file(tmp_path): @@ -102,13 +116,32 @@ def test_render_json_digest_contains_expected_fields(): { "title": "Test Paper", "date": date(2024, 1, 2), + "abstract": "An abstract about transformers.", "summary": "Summary text", "link": "http://arxiv.org/abs/2401.12345", "relevance_score": 5.5, + "triage_score": 85.0, "triage_rationale": "Matched core interests", + "id": "2401.12345", + "authors": ["Alice Smith"], + "categories": ["cs.AI"], + "pdf_url": "https://arxiv.org/pdf/2401.12345", + "keywords_matched": ["transformer"], } ] + import json + payload = render_json_digest(papers) - assert '"title": "Test Paper"' in payload - assert '"why": "Matched core interests"' in payload - assert '"score": 5.5' in payload + data = json.loads(payload) + record = data[0] + assert record["title"] == "Test Paper" + assert record["arxiv_id"] == "2401.12345" + assert record["authors"] == ["Alice Smith"] + assert record["categories"] == ["cs.AI"] + assert record["pdf_url"] 
== "https://arxiv.org/pdf/2401.12345" + assert record["keywords_matched"] == ["transformer"] + assert record["score"] == 5.5 + assert record["triage_score"] == 85.0 + assert record["triage_rationale"] == "Matched core interests" + # summary differs from abstract, so it should appear + assert record["summary"] == "Summary text" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 4634c5d..64c7f7d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -300,7 +300,7 @@ def test_pipeline_end_to_end_stubbed(monkeypatch, tmp_path): def fake_fetch_recent_papers(_config, _days): return list(fake_papers) - def fake_fetch_paper_contents(paper_ids): + def fake_fetch_paper_contents(paper_ids, max_workers=6): return [(paper_id, b"stub content", "pdf") for paper_id in paper_ids] monkeypatch.setattr( @@ -384,13 +384,15 @@ def test_no_papers_found(self, mock_main_dependencies): ) def test_default_delivery_writes_digest(self, mock_main_dependencies): - """Default mode renders and writes stdout digest.""" + """Default mode renders and writes stdout digest; abstract mode skips hydration.""" main() mock_main_dependencies['get_recent_papers'].assert_called_once_with( mock_main_dependencies['load_config'].return_value, include_content=False ) mock_main_dependencies['triage_papers'].assert_called_once() - mock_main_dependencies['hydrate_papers_with_content'].assert_called_once() + mock_main_dependencies['process_papers'].assert_called_once() + # Abstract mode skips content hydration entirely + mock_main_dependencies['hydrate_papers_with_content'].assert_not_called() mock_main_dependencies['render_text_digest'].assert_called_once() mock_main_dependencies['write_output'].assert_called_once() mock_main_dependencies['notifications'].assert_not_called() diff --git a/tests/test_processor.py b/tests/test_processor.py index 69015d8..fd6fddf 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -8,6 +8,7 @@ from paperweight.processor import ( 
calculate_paper_score, + count_keywords, normalize_scores, process_papers, ) @@ -80,6 +81,40 @@ def test_empty_input_returns_empty(self, processor_config): assert result == [] +class TestCountKeywords: + """Tests for the count_keywords function.""" + + def test_returns_tuple(self): + """count_keywords returns (score, matched_list) tuple.""" + score, matched = count_keywords("AI in healthcare", ["AI", "healthcare"]) + assert score > 0 + assert set(matched) == {"AI", "healthcare"} + + def test_no_matches_returns_empty(self): + """count_keywords with no matches returns zero score and empty list.""" + score, matched = count_keywords("nothing relevant here", ["quantum"]) + assert score == 0.0 + assert matched == [] + + +class TestKeywordsMatched: + """Tests for keywords_matched propagation in process_papers.""" + + def test_keywords_matched_in_scored_papers(self, processor_config): + """Scored papers include keywords_matched field.""" + papers = [ + { + "title": "AI in Healthcare", + "abstract": "This paper discusses AI applications in healthcare.", + "content": "", + } + ] + processed = process_papers(papers, processor_config) + assert len(processed) >= 1 + assert "keywords_matched" in processed[0] + assert "AI" in processed[0]["keywords_matched"] + + class TestNormalizeScores: """Tests for the normalize_scores function.""" diff --git a/tests/test_scraper.py b/tests/test_scraper.py index 3d585ce..87863d5 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -5,6 +5,7 @@ from paperweight.db import DatabaseConnectionError from paperweight.scraper import ( + _write_metadata_cache, extract_text_from_source, fetch_arxiv_papers, get_recent_papers, @@ -142,7 +143,7 @@ def test_hydrate_papers_with_content(monkeypatch): monkeypatch.setattr( "paperweight.scraper.fetch_paper_contents", - lambda _ids: [("2401.12345", b"pdf-bytes", "pdf")], + lambda _ids, max_workers=6: [("2401.12345", b"pdf-bytes", "pdf")], ) monkeypatch.setattr( 
"paperweight.scraper.extract_text_from_source", lambda _content, _method: "text" @@ -178,3 +179,46 @@ def test_get_recent_papers_without_content(monkeypatch): assert len(papers) == 1 assert papers[0]["content"] == "" fetch_content.assert_not_called() + + +def test_get_recent_papers_uses_metadata_cache(tmp_path, monkeypatch): + """When metadata cache is enabled and fresh, skip arXiv API calls.""" + cache_path = str(tmp_path / "cache.json") + config = { + "arxiv": {"categories": ["cs.AI"], "max_results": 2}, + "db": {"enabled": False}, + "metadata_cache": {"enabled": True, "path": cache_path, "ttl_hours": 4}, + } + cached_papers = [ + { + "title": "Cached Paper", + "link": "http://arxiv.org/abs/2401.99999", + "date": datetime(2024, 1, 15).date(), + "abstract": "Cached abstract", + } + ] + + # Pre-populate the cache + from paperweight.scraper import _metadata_cache_key + + key = _metadata_cache_key(config) + _write_metadata_cache(cache_path, key, cached_papers) + + monkeypatch.setattr("paperweight.scraper.get_last_processed_date", lambda: None) + monkeypatch.setattr("paperweight.scraper.save_last_processed_date", lambda _d: None) + + fetch_called = {"called": False} + + def fake_fetch(_config, _days): + fetch_called["called"] = True + return [] + + monkeypatch.setattr("paperweight.scraper.fetch_recent_papers", fake_fetch) + fetch_content = MagicMock() + monkeypatch.setattr("paperweight.scraper.fetch_paper_contents", fetch_content) + + papers = get_recent_papers(config, force_refresh=False, include_content=False) + assert len(papers) == 1 + assert papers[0]["title"] == "Cached Paper" + assert not fetch_called["called"] + fetch_content.assert_not_called() diff --git a/uv.lock b/uv.lock index 33326af..973ae9f 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ [[package]] name = "academic-paperweight" -version = "0.2.0" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "arxiv" },