Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More types #484

Merged
merged 30 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
cba7764
lots of easy type fixes
cmacdonald Sep 19, 2024
b761b7a
remove invalid state mypy error
cmacdonald Sep 19, 2024
7b14b38
more type fixes
cmacdonald Sep 19, 2024
ff4baf5
type fix for numpy array
cmacdonald Sep 19, 2024
d44182a
one less type warning
cmacdonald Sep 19, 2024
bbd09f4
lots of type fixes, mostly for tqdm
cmacdonald Sep 19, 2024
09073af
ignore false positive PT100 error
seanmacavaney Sep 23, 2024
5abd1c5
various types updates
cmacdonald Nov 28, 2024
0efaa51
_ir_measures_to_dict now has types
cmacdonald Nov 28, 2024
9487500
woops. syntax
cmacdonald Nov 28, 2024
f0c175e
more explicit
cmacdonald Nov 28, 2024
26af998
ir_measures improvements
cmacdonald Nov 28, 2024
b21b355
fix
cmacdonald Nov 28, 2024
1b0c9aa
minor type fixes
cmacdonald Nov 28, 2024
5ed253a
mypy
cmacdonald Nov 28, 2024
20bf087
Merge branch 'master' into more_types
seanmacavaney Nov 29, 2024
ed8b57a
mypy style check github action
seanmacavaney Nov 29, 2024
832b5e2
daft assertion for mypy
cmacdonald Dec 4, 2024
4d414a3
mypy
cmacdonald Dec 4, 2024
90e76ff
mypy suggestion
cmacdonald Dec 4, 2024
6719d71
Merge branch 'more_types' of github.com:terrier-org/pyterrier into mo…
cmacdonald Dec 4, 2024
6e638e5
Merge branch 'master' into more_types
cmacdonald Dec 5, 2024
8870def
dont type check Jnius files
cmacdonald Dec 5, 2024
7de75a4
eliminate mypy warning
cmacdonald Dec 5, 2024
58d8025
mypy
cmacdonald Dec 5, 2024
8b1d9a5
mypy
cmacdonald Dec 5, 2024
6604563
dont make mypy a failure
cmacdonald Dec 5, 2024
182a02e
allow alternative formats for save_dir files (#502)
cmacdonald Dec 5, 2024
e6a6384
note the presence of types
cmacdonald Dec 5, 2024
3670d5a
Merge branch 'more_types' of github.com:terrier-org/pyterrier into mo…
cmacdonald Dec 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Code Style Checks
name: style

on:
push:
Expand All @@ -7,7 +7,7 @@ on:
branches: [ master ]

jobs:
build:
flake8:
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
Expand All @@ -24,3 +24,21 @@ jobs:
- name: pt.java.required checks
run: |
flake8 ./pyterrier --select=PT --show-source --statistics --count

mypy:
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install
run: |
pip install mypy --upgrade -r requirements.txt -r requirements-test.txt
pip install -e .

- name: MyPy
run: 'mypy --disable-error-code=import-untyped pyterrier || true'
2 changes: 2 additions & 0 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

# will be set in terrier.terrier.java once java is loaded
IndexRef = None
# will be set by utils.set_tqdm() once _() runs
tqdm = None


# deprecated functions explored to the main namespace, which will be removed in a future version
Expand Down
7 changes: 4 additions & 3 deletions pyterrier/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __len__(self) -> int:

def _flatten(transformers: Iterable[Transformer], cls: type) -> Tuple[Transformer]:
return list(chain.from_iterable(
(t._transformers if isinstance(t, cls) else [t])
(t._transformers if isinstance(t, cls) else [t]) # type: ignore
for t in transformers
))

Expand Down Expand Up @@ -193,6 +193,7 @@ def fuse_left(self, left: Transformer) -> Optional[Transformer]:
# If the preceding component supports a native rank cutoff (via fuse_rank_cutoff), apply it.
if isinstance(left, SupportsFuseRankCutoff):
return left.fuse_rank_cutoff(self.k)
return None

class FeatureUnion(NAryTransformerBase):
"""
Expand Down Expand Up @@ -295,7 +296,7 @@ def compile(self) -> Transformer:
"""
Returns a new transformer that fuses feature unions where possible.
"""
out = deque()
out : deque = deque()
inp = deque([t.compile() for t in self._transformers])
while inp:
right = inp.popleft()
Expand Down Expand Up @@ -382,7 +383,7 @@ def compile(self, verbose: bool = False) -> Transformer:
"""Returns a new transformer that iteratively fuses adjacent transformers to form a more efficient pipeline."""
# compile constituent transformers (flatten allows compile() to return Compose pipelines)
inp = deque(_flatten((t.compile() for t in self._transformers), Compose))
out = deque()
out : deque = deque()
counter = 1
while inp:
if verbose:
Expand Down
28 changes: 14 additions & 14 deletions pyterrier/apply_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Any, Union, Optional, Iterable
import itertools
import more_itertools
import numpy as np
import numpy.typing as npt
import pandas as pd
import pyterrier as pt

Expand Down Expand Up @@ -92,7 +92,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
return pd.concat([self._apply_df(chunk_df) for chunk_df in iterator])

def _apply_df(self, inp: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -148,7 +148,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
it = res.groupby("qid")
lastqid = None
if self.verbose:
it = pt.tqdm(it, unit='query')
it = pt.tqdm(it, unit='query') # type: ignore
try:
if self.batch_size is None:
query_dfs = []
Expand All @@ -163,7 +163,7 @@ def transform(self, res: pd.DataFrame) -> pd.DataFrame:
iterator = pt.model.split_df(group, batch_size=self.batch_size)
query_dfs.append( pd.concat([self.fn(chunk_df) for chunk_df in iterator]) )
except Exception as a:
raise Exception("Problem applying %s for qid %s" % (self.fn, lastqid)) from a
raise Exception("Problem applying %r for qid %s" % (self.fn, lastqid)) from a # %r because it's a function with a bytes representation (mypy)

if self.add_ranks:
try:
Expand Down Expand Up @@ -253,7 +253,7 @@ def __repr__(self):
def _transform_rowwise(self, outputRes):
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_score", unit="d")
outputRes["score"] = outputRes.progress_apply(self.fn, axis=1).astype('float64')
outputRes["score"] = outputRes.progress_apply(self.fn, axis=1).astype('float64') # type: ignore
else:
outputRes["score"] = outputRes.apply(self.fn, axis=1).astype('float64')
outputRes = pt.model.add_ranks(outputRes)
Expand All @@ -275,7 +275,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

iterator = pt.model.split_df(outputRes, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self._transform_batchwise(chunk_df) for chunk_df in iterator])
rtr = pt.model.add_ranks(rtr)
return rtr
Expand All @@ -294,7 +294,7 @@ def _feature_fn(row):
pipe = pt.terrier.Retriever(index) >> pt.apply.doc_features(_feature_fn) >> pt.LTRpipeline(xgBoost())
"""
def __init__(self,
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], np.array],
fn: Callable[[Union[pd.Series, pt.model.IterDictRecord]], npt.NDArray],
*,
verbose: bool = False
):
Expand All @@ -313,7 +313,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.doc_features")
inp = pt.tqdm(inp, desc="pt.apply.doc_features") # type: ignore
for row in inp:
row["features"] = self.fn(row)
yield row
Expand All @@ -322,8 +322,8 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
fn = self.fn
outputRes = inp.copy()
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d")
outputRes["features"] = outputRes.progress_apply(fn, axis=1)
pt.tqdm.pandas(desc="pt.apply.doc_features", unit="d") # type: ignore
outputRes["features"] = outputRes.progress_apply(fn, axis=1) # type: ignore
else:
outputRes["features"] = outputRes.apply(fn, axis=1)
return outputRes
Expand Down Expand Up @@ -368,7 +368,7 @@ def transform_iter(self, inp: pt.model.IterDict) -> pt.model.IterDict:
# we assume that the function can take a dictionary as well as a pandas.Series. As long as [""] notation is used
# to access fields, both should work
if self.verbose:
inp = pt.tqdm(inp, desc="pt.apply.query")
inp = pt.tqdm(inp, desc="pt.apply.query") # type: ignore
for row in inp:
row = row.copy()
if "query" in row:
Expand All @@ -384,8 +384,8 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
outputRes = inp.copy()
try:
if self.verbose:
pt.tqdm.pandas(desc="pt.apply.query", unit="d")
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1)
pt.tqdm.pandas(desc="pt.apply.query", unit="d") # type: ignore
outputRes["query"] = outputRes.progress_apply(self.fn, axis=1) # type: ignore
else:
outputRes["query"] = outputRes.apply(self.fn, axis=1)
except ValueError as ve:
Expand Down Expand Up @@ -444,7 +444,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
# batching
iterator = pt.model.split_df(inp, batch_size=self.batch_size)
if self.verbose:
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row')
iterator = pt.tqdm(iterator, desc="pt.apply", unit='row') # type: ignore
rtr = pd.concat([self.fn(chunk_df) for chunk_df in iterator])
return rtr

Expand Down
20 changes: 12 additions & 8 deletions pyterrier/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import json
import pandas as pd
from .transformer import is_lambda
from abc import abstractmethod
import types
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal
from typing import Union, Tuple, Iterator, Dict, Any, List, Literal, Optional
from warnings import warn
import requests
from .io import autoopen, touch
Expand Down Expand Up @@ -54,12 +55,13 @@ def get_corpus(self):
"""
pass

@abstractmethod
def get_corpus_iter(self, verbose=True) -> pt.model.IterDict:
"""
Returns an iter of dicts for this collection. If verbose=True, a tqdm pbar shows the progress over this iterator.
"""
pass

def get_corpus_lang(self) -> Union[str,None]:
"""
Returns the ISO 639-1 language code for the corpus, or None for multiple/other/unknown
Expand All @@ -72,6 +74,7 @@ def get_index(self, variant=None, **kwargs):
"""
pass

@abstractmethod
def get_topics(self, variant=None) -> pd.DataFrame:
"""
Returns the topics, as a dataframe, ready for retrieval.
Expand All @@ -84,6 +87,7 @@ def get_topics_lang(self) -> Union[str,None]:
"""
return None

@abstractmethod
def get_qrels(self, variant=None) -> pd.DataFrame:
"""
Returns the qrels, as a dataframe, ready for evaluation.
Expand All @@ -109,7 +113,7 @@ def get_results(self, variant=None) -> pd.DataFrame:
"""
Returns a standard result set provided by the dataset. This is useful for re-ranking experiments.
"""
pass
return None

class RemoteDataset(Dataset):

Expand Down Expand Up @@ -139,7 +143,7 @@ def download(URLs : Union[str,List[str]], filename : str, **kwargs):
r = requests.get(url, allow_redirects=True, stream=True, **kwargs)
r.raise_for_status()
total = int(r.headers.get('content-length', 0))
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm(
with pt.io.finalized_open(filename, 'b') as file, pt.tqdm( # type: ignore
desc=basename,
total=total,
unit='iB',
Expand Down Expand Up @@ -507,7 +511,7 @@ def get_results(self, variant=None) -> pd.DataFrame:
result.sort_values(by=['qid', 'score', 'docno'], ascending=[True, False, True], inplace=True) # ensure data is sorted by qid, -score, did
# result doesn't yet contain queries (only qids) so load and merge them in
topics = self.get_topics(variant)
result = pd.merge(result, topics, how='left', on='qid', copy=False)
result = pd.merge(result, topics, how='left', on='qid')
return result

def _describe_component(self, component):
Expand Down Expand Up @@ -610,7 +614,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
set_docnos = set(docnos)
it = (tuple(getattr(doc, f) for f in fields) for doc in docstore.get_many_iter(set_docnos))
if self.verbose:
it = pd.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader')
it = pt.tqdm(it, unit='d', total=len(set_docnos), desc='IRDSTextLoader') # type: ignore
metadata = pd.DataFrame(list(it), columns=fields).set_index('doc_id')
metadata_frame = metadata.loc[docnos].reset_index(drop=True)

Expand Down Expand Up @@ -1104,7 +1108,7 @@ def _merge_years(self, component, variant):
"corpus_iter" : lambda dataset, **kwargs : pt.index.treccollection2textgen(dataset.get_corpus(), num_docs=11429, verbose=kwargs.get("verbose", False))
}

DATASET_MAP = {
DATASET_MAP : Dict[str, Dataset] = {
# used for UGlasgow teaching
"50pct" : RemoteDataset("50pct", FIFTY_PCT_FILES),
# umass antique corpus - see http://ciir.cs.umass.edu/downloads/Antique/
Expand Down Expand Up @@ -1222,7 +1226,7 @@ def list_datasets(en_only=True):
def transformer_from_dataset(
dataset : Union[str, Dataset],
clz,
variant: str = None,
variant: Optional[str] = None,
version: str = 'latest',
**kwargs) -> pt.Transformer:
"""Returns a Transformer instance of type ``clz`` for the provided index of variant ``variant``."""
Expand Down
8 changes: 4 additions & 4 deletions pyterrier/debug.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from . import Transformer
from typing import List
from typing import List, Optional

def print_columns(by_query : bool = False, message : str = None) -> Transformer:
def print_columns(by_query : Optional[bool] = False, message : Optional[str] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print the column names of the dataframe
at this stage in the pipeline:
Expand Down Expand Up @@ -82,8 +82,8 @@ def print_rows(
by_query : bool = True,
jupyter: bool = True,
head : int = 2,
message : str = None,
columns : List[str] = None) -> Transformer:
message : Optional[str] = None,
columns : Optional[List[str]] = None) -> Transformer:
"""
Returns a transformer that can be inserted into pipelines that can print some of the dataframe
at this stage in the pipeline:
Expand Down
3 changes: 2 additions & 1 deletion pyterrier/java/_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import os
from pyterrier.java import required_raise, required, before_init, started, mavenresolver, JavaClasses, JavaInitializer, register_config
from typing import Optional
Expand Down Expand Up @@ -153,7 +154,7 @@ def add_jar(jar_path):


@before_init
def add_package(org_name: str = None, package_name: str = None, version: str = None, file_type='jar'):
def add_package(org_name : str, package_name : str, version : Optional[str] = None, file_type : str = 'jar'):
if version is None or version == 'snapshot':
version = mavenresolver.latest_version_num(org_name, package_name)
file_name = mavenresolver.get_package_jar(org_name, package_name, version, artifact=file_type)
Expand Down
3 changes: 2 additions & 1 deletion pyterrier/java/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import sys
import warnings
from functools import wraps
Expand Down Expand Up @@ -387,7 +388,7 @@ def register_config(name, config: Dict[str, Any]):
class JavaClasses:
def __init__(self, **mapping: Union[str, Callable[[], str]]):
self._mapping = mapping
self._cache = {}
self._cache : Dict[str, Callable]= {}

def __dir__(self):
return list(self._mapping.keys())
Expand Down
1 change: 1 addition & 0 deletions pyterrier/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def split_df(df : pd.DataFrame, N: Optional[int] = None, *, batch_size: Optional
assert (N is None) != (batch_size is None), "Either N or batch_size should be provided (and not both)"

if N is None:
assert batch_size is not None
N = math.ceil(len(df) / batch_size)

type = None
Expand Down
12 changes: 6 additions & 6 deletions pyterrier/new.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Sequence, Union
from typing import Sequence, Union, Optional, cast, Iterable
import pandas as pd
from .model import add_ranks

Expand All @@ -9,7 +9,7 @@ def empty_Q() -> pd.DataFrame:
"""
return pd.DataFrame(columns=["qid", "query"])

def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]] = None, **others) -> pd.DataFrame:
def queries(queries : Union[str, Sequence[str]], qid : Optional[Union[str, Iterable[str]]] = None, **others) -> pd.DataFrame:
"""
Creates a new queries dataframe. Will return a dataframe with the columns `["qid", "query"]`.
Any further lists in others will also be added.
Expand Down Expand Up @@ -40,7 +40,7 @@ def queries(queries : Union[str, Sequence[str]], qid : Union[str, Sequence[str]]
assert type(qid) == str
return pd.DataFrame({"qid" : [qid], "query" : [queries], **others})
if qid is None:
qid = map(str, range(1, len(queries)+1))
qid = cast(Iterable[str], map(str, range(1, len(queries)+1))) # noqa: PT100 (this is typing.cast, not jnius.cast)
return pd.DataFrame({"qid" : qid, "query" : queries, **others})

Q = queries
Expand All @@ -53,8 +53,8 @@ def empty_R() -> pd.DataFrame:

def ranked_documents(
scores : Sequence[Sequence[float]],
qid : Sequence[str] = None,
docno=None,
qid : Optional[Sequence[str]] = None,
docno : Optional[Sequence[Sequence[str]]] = None,
**others) -> pd.DataFrame:
"""
Creates a new ranked documents dataframe. Will return a dataframe with the columns `["qid", "docno", "score", "rank"]`.
Expand Down Expand Up @@ -120,4 +120,4 @@ def ranked_documents(
raise ValueError("We assume multiple documents, for now")
return add_ranks(rtr)

R = ranked_documents
R = ranked_documents
Loading
Loading