Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynamic module loading #461

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions pyterrier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,6 @@
cast = deprecated(version='0.11.0', reason="use pt.java.cast(...) instead")(java.cast)


# Additional setup performed in a function to avoid polluting the namespace with other imports like platform
def _():
# check python version
import platform
from packaging.version import Version
if Version(platform.python_version()) < Version('3.7.0'):
raise RuntimeError("From PyTerrier 0.8, Python 3.7 minimum is required, you currently have %s" % platform.python_version())

# apply is an object, not a module, as it also has __getattr__() implemented
from pyterrier.apply import _apply
globals()['apply'] = _apply()

utils.set_tqdm()
_()

__all__ = [
'java', 'terrier', 'anserini', 'cache', 'debug', 'io', 'measures', 'model', 'new', 'ltr', 'parallel', 'pipelines',
'text', 'transformer', 'datasets', 'get_dataset', 'find_datasets', 'list_datasets', 'Experiment', 'GridScan',
Expand All @@ -71,9 +56,44 @@ def _():
'BatchRetrieve', 'TerrierRetrieve', 'FeaturesBatchRetrieve', 'IndexFactory',
'run', 'rewrite', 'index', 'FilesIndexer', 'TRECCollectionIndexer', 'DFIndexer', 'DFIndexUtils', 'IterDictIndexer',
'IndexingType', 'TerrierStemmer', 'TerrierStopwords', 'TerrierTokeniser',
'IndexRef', 'ApplicationSetup', 'properties', 'apply',
'IndexRef', 'ApplicationSetup', 'properties',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should `apply` be removed here?


# Deprecated:
'init', 'started', 'logging', 'version', 'check_version', 'extend_classpath', 'set_tqdm', 'set_property', 'set_properties',
'redirect_stdouterr', 'autoclass', 'cast',

# Entry point modules (appended below as they are loaded):
]


# Additional setup performed in a function to avoid polluting the namespace with other imports like platform
def _():
from warnings import warn
import platform
from packaging.version import Version

# check python version
if Version(platform.python_version()) < Version('3.7.0'):
raise RuntimeError("From PyTerrier 0.8, Python 3.7 minimum is required, you currently have %s" % platform.python_version())

globs = globals()

# Load the _apply object as pt.apply so that the dynamic __getattr__ methods work
from pyterrier.apply import _apply
globs['apply'] = _apply()
__all__.append('apply')
cmacdonald marked this conversation as resolved.
Show resolved Hide resolved

# load modules defined as package entry points into the global pyterrier namespace
for entry_point in utils.entry_points('pyterrier.modules'):
if entry_point.name in globs:
warn(f'skipping loading {entry_point} because a module with this name is already loaded.')
continue
module = entry_point.load()
if callable(module): # if the entry point refers to a function/class, call it to get the module
module = module()
globs[entry_point.name] = module
__all__.append(entry_point.name)

# guess the environment and set an appropriate tqdm as pt.tqdm
utils.set_tqdm()
_()
7 changes: 3 additions & 4 deletions pyterrier/ltr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

import pyterrier as pt
from . import Transformer, Estimator
from .apply import doc_score, doc_features
from .model import add_ranks
from typing import Sequence, Union, Tuple
import numpy as np, pandas as pd
Expand Down Expand Up @@ -244,7 +243,7 @@ def feature_to_score(fid : int) -> Transformer:
Args:
fid: a single feature id that should be kept
"""
return doc_score(lambda row : row["features"][fid])
return pt.apply.doc_score(lambda row : row["features"][fid])

def apply_learned_model(learner, form : str = 'regression', **kwargs) -> Transformer:
"""
Expand Down Expand Up @@ -284,4 +283,4 @@ def score_to_feature() -> Transformer:
three_features = cands >> (bm25f ** pl2f ** pt.ltr.score_to_feature())

"""
return doc_features(lambda row : np.array(row["score"]))
return pt.apply.doc_features(lambda row : np.array(row["score"]))