
Commit 31e1e8b

Fix mypy errors (#277)
* Fix mypy errors
* Fix run_locally
* Fix noqa
1 parent 3f86001 commit 31e1e8b


11 files changed: +32 -22 lines


prompt2model/dataset_processor/base.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 """A base class for dataset processor."""
 
+from __future__ import annotations # noqa FI58
+
 from abc import ABC, abstractmethod
 from functools import partial
 
@@ -9,7 +11,7 @@
 class BaseProcessor(ABC):
     """A base class for post-processing datasets."""
 
-    def __init__(self, has_encoder: bool, eos_token: str) -> None:
+    def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
         """Initialize the `BaseProcessor`.
 
         Args:
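The new `from __future__ import annotations` line is what lets `str | None` appear in the signature without requiring Python 3.10. A minimal sketch of the effect (not part of the commit; standard library only):

# PEP 563 postponed evaluation stores annotations as strings, so the
# PEP 604 union `str | None` is never evaluated at import time and
# therefore also works on Python 3.7-3.9.
from __future__ import annotations  # must come before all other imports


class BaseProcessor:
    def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
        # eos_token now defaults to None, so processors that need no EOS
        # token can simply omit the argument.
        self.has_encoder = has_encoder
        self.eos_token = eos_token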

prompt2model/dataset_processor/mock.py

Lines changed: 9 additions & 1 deletion
@@ -23,7 +23,15 @@ def process_dataset_dict(
         _ = instruction
         return dataset_dicts
 
-    def post_process_example(example: dict, instruction: str, task_id: int) -> dict:
+    @staticmethod
+    def post_process_example(
+        example: dict,
+        instruction: str,
+        task_id: int,
+        has_encoder: bool,
+        dataset_split: str,
+        eos_token: str,
+    ) -> dict:
         """A mock function that modifies a given example dictionary.
 
         Args:
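Because the mock is now a @staticmethod matching the base class hook, it can be called without an instance. A hypothetical call, assuming the mock class keeps its MockProcessor name and using made-up example fields:

# Hypothetical usage sketch; the dict keys below are illustrative only.
example = {"input_col": "great movie", "output_col": "positive"}
processed = MockProcessor.post_process_example(
    example,
    instruction="Classify the sentiment of the review.",
    task_id=0,
    has_encoder=True,
    dataset_split="train",
    eos_token="</s>",
)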

prompt2model/dataset_retriever/hf_dataset_retriever.py

Lines changed: 3 additions & 3 deletions
@@ -284,11 +284,11 @@ def retrieve_dataset_dict(
             self.dataset_infos[dataset_idx].score = dataset_score
             top_dataset_infos.append(self.dataset_infos[dataset_idx])
 
-        ranked_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[
+        sorted_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[
            : self.max_search_depth
        ]
-        assert len(ranked_list) > 0, "No datasets retrieved from search index."
-        top_dataset_name = self.choose_dataset(ranked_list)
+        assert len(sorted_list) > 0, "No datasets retrieved from search index."
+        top_dataset_name = self.choose_dataset(sorted_list)
         if top_dataset_name is None:
             return None
         return self.canonicalize_dataset(top_dataset_name)

prompt2model/dataset_retriever/run_dataset_retriever.py

Lines changed: 1 addition & 3 deletions
@@ -8,6 +8,4 @@
 prompt_spec._instruction = prompt
 
 retriever = DescriptionDatasetRetriever()
-retriever.retrieve_dataset_dict(
-    prompt_spec, blocklist=["squad", "stanford question answering"]
-)
+retriever.retrieve_dataset_dict(prompt_spec)

prompt2model/model_evaluator/base.py

Lines changed: 4 additions & 3 deletions
@@ -7,9 +7,9 @@
 from typing import Any
 
 import datasets
+import evaluate
 
 from prompt2model.model_executor import ModelOutput
-from prompt2model.prompt_parser import PromptSpec
 
 
 class ModelEvaluator(ABC):
@@ -21,8 +21,9 @@ def evaluate_model(
         dataset: datasets.Dataset,
         gt_column: str,
         predictions: list[ModelOutput],
-        metrics: list[datasets.Metric] | None = None,
-        prompt_spec: PromptSpec | None = None,
+        model_input_column: str | None = None,
+        metrics: list[evaluate.Metric] | None = None,
+        encoder_model_name: str = "xlm-roberta-base",
     ) -> dict[str, Any]:
         """Evaluate a model on a test set..
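The metric type moves from datasets.Metric, which newer releases of the datasets library deprecate, to evaluate.Metric. A minimal sketch of building a metrics list that satisfies the new hint (metric names are examples, not something this commit requires):

import evaluate

# Objects returned by evaluate.load subclass evaluate.Metric, so this list
# matches the new `list[evaluate.Metric] | None` parameter; scores are later
# obtained via metric.compute(predictions=..., references=...).
metrics = [evaluate.load("exact_match"), evaluate.load("chrf")]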

prompt2model/model_evaluator/mock.py

Lines changed: 4 additions & 3 deletions
@@ -4,10 +4,10 @@
 from typing import Any
 
 import datasets
+import evaluate
 
 from prompt2model.model_evaluator.base import ModelEvaluator
 from prompt2model.model_executor import ModelOutput
-from prompt2model.prompt_parser import PromptSpec
 
 
 class MockEvaluator(ModelEvaluator):
@@ -21,8 +21,9 @@ def evaluate_model(
         dataset: datasets.Dataset,
         gt_column: str,
         predictions: list[ModelOutput],
-        metrics: list[datasets.Metric] | None = None,
-        prompt_spec: PromptSpec | None = None,
+        model_input_column: str | None = None,
+        metrics: list[evaluate.Metric] | None = None,
+        encoder_model_name: str = "xlm-roberta-base",
     ) -> dict[str, Any]:
         """Return empty metrics dictionary.

prompt2model/model_retriever/base.py

Lines changed: 2 additions & 2 deletions
@@ -13,12 +13,12 @@ class ModelRetriever(ABC):
     def retrieve(
         self,
         prompt: PromptSpec,
-    ) -> str:
+    ) -> list[str]:
         """Retrieve relevant models from HuggingFace.
 
         Args:
             prompt: A prompt to use to select relevant models.
 
         Return:
-            A relevant model's HuggingFace name.
+            A list of relevant models' HuggingFace names.
         """

prompt2model/model_retriever/mock.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ def __init__(self, fixed_model_name: str):
     def retrieve(
         self,
         prompt: PromptSpec,
-    ) -> str:
+    ) -> list[str]:
         """Select an arbitrary, fixed model from HuggingFace.
 
         Args:
@@ -23,4 +23,4 @@ def retrieve(
         Return:
             A relevant model's HuggingFace name.
         """
-        return self.fixed_model_name
+        return [self.fixed_model_name]
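Since retrieve() now returns a ranked list instead of a single name, callers index into it, as the run_locally.py change below shows. A hypothetical usage sketch (`prompt_spec` stands in for a PromptSpec built elsewhere):

# Hypothetical sketch mirroring the caller-side change in run_locally.py.
retriever = MockModelRetriever("cardiffnlp/twitter-roberta-base-sentiment")
candidate_models = retriever.retrieve(prompt_spec)  # now a list[str]
top_model_name = candidate_models[0]  # best-ranked candidate first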

prompt2model/model_trainer/generate.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def __init__(
 
         # self.validation_callback is used for evaluate the model on
         # the validation dataset after each epoch.
-        self.validation_callback = None
+        self.validation_callback: ValidationCallback | None = None
         self.training_seed = seed_generator.get_seed()
 
     def get_left_padding_length(cls, input_list, padding_token_id):

prompt2model/param_selector/mock.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def select_from_hyperparameters(
             A model and tokenizer (trained using default hyperparameters).
         """
         single_model = self.trainer.train_model(
-            training_sets, self._example_hyperparameter_choices()
+            self._example_hyperparameter_choices(), training_sets
         )
         return single_model
 
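The swap puts the hyperparameter choices ahead of the training datasets, matching the order the trainer's train_model expects. A sketch of the assumed call shape (the hyperparameter key is illustrative, not taken from this commit):

# Assumed call shape after the swap; the real signature lives in
# prompt2model/model_trainer and is not reproduced here.
single_model = trainer.train_model(
    {"num_train_epochs": 1},  # hyperparameter choices first
    training_sets,            # then the training datasets
)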

prompt2model/run_locally.py

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None:
     model_retriever = MockModelRetriever("cardiffnlp/twitter-roberta-base-sentiment")
     retrieved_model_name = model_retriever.retrieve(prompt_spec)
 
-    trainer = MockTrainer(retrieved_model_name)
+    trainer = MockTrainer(retrieved_model_name[0])
     selector = MockParamSelector(trainer)
     model, tokenizer = selector.select_from_hyperparameters(
         all_training, validation, {}
@@ -105,7 +105,7 @@ def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None:
 
     evaluator = MockEvaluator()
     metrics_dict = evaluator.evaluate_model(
-        testing, "output_col", predictions, [], prompt_spec
+        testing, "output_col", predictions, "input_col", []
     )
     evaluator.write_metrics(metrics_dict, metrics_output_path)
     mock_gradio_create(model_executor, prompt_spec)
