Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More types #484

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

# will be set in terrier.terrier.java once java is loaded
IndexRef = None
# will be set by utils.set_tqdm() once _() runs
tqdm = None


# deprecated functions explored to the main namespace, which will be removed in a future version
Expand Down
20 changes: 10 additions & 10 deletions pyterrier/apply_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Any, Union, Optional, Iterable
import itertools
import more_itertools
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyterrier as pt

Expand Down Expand Up @@ -92,7 +92,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
return pd.concat([self._apply_df(chunk_df) for chunk_df in iterator])

def _apply_df(self, inp: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -148,7 +148,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
it = res.groupby("qid")
lastqid = None
if self.verbose:
it = pt.tqdm(it, unit='query')
it = pt.tqdm(it, unit='query') # type: ignore
try:
if self.batch_size is None:
query_dfs = []
Expand Down Expand Up @@ -275,7 +275,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

iterator = pt.model.split_df(outputRes, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self._transform_batchwise(chunk_df) for chunk_df in iterator])
rtr = pt.model.add_ranks(rtr)
return rtr
Expand All @@ -294,7 +294,7 @@ def _feature_fn(row):
pipe = pt.terrier.Retriever(index) >> pt.apply.doc_features(_feature_fn) >> pt.LTRpipeline(xgBoost())
"""
def __init__(self,
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], np.array],
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], npt.NDArray],
*,
verbose: bool = False
):
Expand All @@ -313,7 +313,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.doc_features")
inp = pt.tqdm(inp, desc="pt.apply.doc_features") # type: ignore
for row in inp:
row["features"] = self.fn(row)
yield row
Expand All @@ -322,7 +322,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
fn = self.fn
outputRes = inp.copy()
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d")
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d") # type: ignore
outputRes["features"] = outputRes.progress_apply(fn, axis=1)
else:
outputRes["features"] = outputRes.apply(fn, axis=1)
Expand Down Expand Up @@ -368,7 +368,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.query")
inp = pt.tqdm(inp, desc="pt.apply.query") # type: ignore
for row in inp:
row = row.copy()
if "query" in row:
Expand All @@ -384,7 +384,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
outputRes = inp.copy()
try:
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.query", unit="d")
pt.tqdm.pandas(desc="pt.apply.query", unit="d") # type: ignore
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1)
else:
outputRes["query"] = outputRes.apply(self.fn, axis=1)
Expand Down Expand Up @@ -444,7 +444,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self.fn(chunk_df) for chunk_df in iterator])
return rtr

Expand Down
12 changes: 6 additions & 6 deletions pyterrier/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from .transformer import is_lambda
import types
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal, Optional
from warnings import warn
import requests
from .io import autoopen, touch
Expand Down Expand Up @@ -139,7 +139,7 @@ def download(URLs : Union[str,List[str]], filename : str, **kwargs):
r = requests.get(url, allow_redirects=True, stream=True, **kwargs)
r.raise_for_status()
total = int(r.headers.get('content-length', 0))
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm(
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm( # type: ignore
desc=basename,
total=total,
unit='iB',
Expand Down Expand Up @@ -507,7 +507,7 @@ def get_results(self, variant=None) -> pd.DataFrame:
result.sort_values(by=['qid', 'score', 'docno'], ascending=[True, False, True], inplace=True) # ensure data is sorted by qid, -score, did
# result doesn't yet contain queries (only qids) so load and merge them in
topics = self.get_topics(variant)
result = pd.merge(result, topics, how='left', on='qid', copy=False)
result = pd.merge(result, topics, how='left', on='qid')
return result

def _describe_component(self, component):
Expand Down Expand Up @@ -610,7 +610,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
set_docnos = set(docnos)
it = (tuple(getattr(doc, f) for f in fields) for doc in docstore.get_many_iter(set_docnos))
if self.verbose:
it = pd.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader')
it = pt.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader') # type: ignore
metadata = pd.DataFrame(list(it), columns=fields).set_index('doc_id')
metadata_frame = metadata.loc[docnos].reset_index(drop=True)

Expand Down Expand Up @@ -1104,7 +1104,7 @@ def _merge_years(self, component, variant):
"corpus_iter" : lambda dataset, **kwargs : pt.index.treccollection2textgen(dataset.get_corpus(), num_docs=11429, verbose=kwargs.get("verbose", False))
}

DATASET_MAP = {
DATASET_MAP : Dict[str, Dataset] = {
# used for UGlasgow teaching
"50pct" : RemoteDataset("50pct", FIFTY_PCT_FILES),
# umass antique corpus - see http://ciir.cs.umass.edu/downloads/Antique/
Expand Down Expand Up @@ -1210,7 +1210,7 @@ def list_datasets(en_only=True):
def transformer_from_dataset(
dataset : Union[str, Dataset],
clz,
variant: str = None,
variant: Optional[str] = None,
version: str = 'latest',
**kwargs) -> pt.Transformer:
"""Returns a Transformer instance of type ``clz`` for the provided index of variant ``variant``."""
Expand Down
8 changes: 4 additions & 4 deletions pyterrier/debug.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import Transformer
from typing import List
from typing import List, Optional

def print_columns(by_query : bool = False, message : str = None) -> Transformer:
def print_columns(by_query : Optional[bool] = False, message : Optional[str] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print the column names of the dataframe
at this stage in the pipeline:
Expand Down Expand Up @@ -82,8 +82,8 @@ def print_rows(
by_query : bool = True,
jupyter: bool = True,
head : int = 2,
message : str = None,
columns : List[str] = None) -> Transformer:
message : Optional[str] = None,
columns : Optional[List[str]] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print some of the dataframe
at this stage in the pipeline:
Expand Down
2 changes: 1 addition & 1 deletion pyterrier/java/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def add_jar(jar_path):


@before_init
def add_package(org_name: str = None, package_name: str = None, version: str = None, file_type='jar'):
def add_package(org_name : str, package_name : str, version : Optional[str] = None, file_type : str = 'jar'):
if version is None or version == 'snapshot':
version = mavenresolver.latest_version_num(org_name, package_name)
file_name = mavenresolver.get_package_jar(org_name, package_name, version, artifact=file_type)
Expand Down
2 changes: 1 addition & 1 deletion pyterrier/java/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def register_config(name, config: Dict[str, Any]):
class JavaClasses:
def __init__(self, **mapping: Union[str, Callable[[], str]]):
self._mapping = mapping
self._cache = {}
self._cache : Dict[str, Callable]= {}

def __dir__(self):
return list(self._mapping.keys())
Expand Down
12 changes: 6 additions & 6 deletions pyterrier/new.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Sequence, Union
from typing import Sequence, Union, Optional, cast, Iterable
import pandas as pd
from .model import add_ranks

Expand All @@ -9,7 +9,7 @@ def empty_Q() -> pd.DataFrame:
"""
return pd.DataFrame(columns=["qid", "query"])

def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]] = None, **others) -> pd.DataFrame:
def queries(queries : Union[str, Sequence[str]], qid : Optional[Union[str, Iterable[str]]] = None, **others) -> pd.DataFrame:
"""
Creates a new queries dataframe. Will return a dataframe with the columns `["qid", "query"]`.
Any further lists in others will also be added.
Expand Down Expand Up @@ -40,7 +40,7 @@ def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]]
assert type(qid) == str
return pd.DataFrame({"qid" : [qid], "query" : [queries], **others})
if qid is None:
qid = map(str, range(1, len(queries)+1))
qid = cast(Iterable[str], map(str, range(1, len(queries)+1))) # noqa: PT100 (this is typing.cast, not jnius.cast)
return pd.DataFrame({"qid" : qid, "query" : queries, **others})

Q = queries
Expand All @@ -53,8 +53,8 @@ def empty_R() -> pd.DataFrame:

def ranked_documents(
scores : Sequence[Sequence[float]],
qid : Sequence[str] = None,
docno=None,
qid : Optional[Sequence[str]] = None,
docno : Optional[Sequence[Sequence[str]]] = None,
**others) -> pd.DataFrame:
"""
Creates a new ranked documents dataframe. Will return a dataframe with the columns `["qid", "docno", "score", "rank"]`.
Expand Down Expand Up @@ -120,4 +120,4 @@ def ranked_documents(
raise ValueError("We assume multiple documents, for now")
return add_ranks(rtr)

R = ranked_documents
R = ranked_documents
Loading
Loading