Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert citation to formatted_citation usage where necessary #666

Merged
merged 7 commits into from
Nov 5, 2024
2 changes: 1 addition & 1 deletion paperqa/agents/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def table_formatter(
try:
display_name = cast(Docs, obj).texts[0].doc.title # type: ignore[attr-defined]
except AttributeError:
display_name = cast(Docs, obj).texts[0].doc.citation
display_name = cast(Docs, obj).texts[0].doc.formatted_citation
table.add_row(display_name[:max_chars_per_column], filename)
return table
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion paperqa/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def map_fxn_summary(
# needed empties for failures/skips
llm_result = LLMResult(model="", date="")
extras: dict[str, Any] = {}
citation = text.name + ": " + text.doc.citation
citation = text.name + ": " + text.doc.formatted_citation
success = False

if prompt_runner:
Expand Down
6 changes: 3 additions & 3 deletions paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ async def aget_evidence(
prompt_runner=prompt_runner,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
"citation": f"{m.name}: {m.doc.citation}",
"citation": f"{m.name}: {m.doc.formatted_citation}",
},
parser=llm_parse_json if prompt_config.use_json else None,
callbacks=callbacks,
Expand Down Expand Up @@ -715,7 +715,7 @@ async def aquery( # noqa: PLR0912
context_inner_prompt.format(
name=c.text.name,
text=c.context,
citation=c.text.doc.citation,
citation=c.text.doc.formatted_citation,
**(c.model_extra or {}),
)
for c in filtered_contexts
Expand Down Expand Up @@ -756,7 +756,7 @@ async def aquery( # noqa: PLR0912
answer_text = answer_text.replace(prompt_config.EXAMPLE_CITATION, "")
for c in filtered_contexts:
name = c.text.name
citation = c.text.doc.citation
citation = c.text.doc.formatted_citation
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
if name_in_text(name, answer_text):
bib[name] = citation
Expand Down
20 changes: 10 additions & 10 deletions paperqa/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ class Doc(Embeddable):
def __hash__(self) -> int:
return hash((self.docname, self.dockey))

@computed_field # type: ignore[prop-decorator]
@property
def formatted_citation(self) -> str:
return self.citation


class Text(Embeddable):
text: str
Expand Down Expand Up @@ -607,8 +612,9 @@ def __getitem__(self, item: str):
except AttributeError:
return self.other[item]

@computed_field # type: ignore[prop-decorator]
@property
def formatted_citation(self) -> str:
def formatted_citation(self) -> str | None: # type: ignore[override]

if self.is_retracted:
base_message = "**RETRACTED ARTICLE**"
Expand All @@ -620,15 +626,9 @@ def formatted_citation(self) -> str:
)
return f"{base_message} {citation_message} {retract_info}"

if (
self.citation is None # type: ignore[redundant-expr]
or self.citation_count is None
or self.source_quality is None
):
raise ValueError(
"Citation, citationCount, and sourceQuality are not set -- do you need"
" to call `hydrate`?"
)
if self.citation_count is None or self.source_quality is None:
logger.warning("citation_count and source_quality are not set.")
return self.citation

if self.source_quality_message:
return (
Expand Down
Loading