Skip to content

Commit

Permalink
fix transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
a-kore committed Jul 16, 2024
1 parent e0077bb commit f062494
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 24 deletions.
60 changes: 41 additions & 19 deletions cyclops/monitor/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,33 +717,45 @@ def __init__(
batch_size=batch_size,
)
self.base_model.initialize()
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
)
self.model_transforms = partial(
apply_transforms,
transforms=model_transforms,
self.base_model.save_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
if transforms:
self.transforms = partial(apply_transforms, transforms=transforms)
model_transforms = transforms
model_transforms.transforms = model_transforms.transforms + (

Check warning on line 726 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L724-L726

Added lines #L724 - L726 were not covered by tests
Lambdad(
keys=("mask", "labels"),
func=lambda x: np.array(x),
allow_missing_keys=True,
),
)
self.model_transforms = partial(

Check warning on line 733 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L733

Added line #L733 was not covered by tests
apply_transforms,
transforms=model_transforms,
)
else:
self.transforms = None
self.model_transforms = None
elif is_sklearn_model(base_model):

Check warning on line 740 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L740

Added line #L740 was not covered by tests
self.base_model = wrap_model(base_model)
self.base_model.save_model(

Check warning on line 742 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L742

Added line #L742 was not covered by tests
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, (PTModel, SKModel)):
elif isinstance(base_model, SKModel):
self.base_model = base_model
self.base_model.save_model(

Check warning on line 749 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L745-L749

Added lines #L745 - L749 were not covered by tests
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
self.transforms = transforms
self.model_transforms = transforms
elif isinstance(base_model, PTModel):
self.base_model = base_model
self.base_model.save_model(

Check warning on line 756 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L754-L756

Added lines #L754 - L756 were not covered by tests
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
else:
raise ValueError("base_model must be a PyTorch or sklearn model.")

Check warning on line 760 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L760

Added line #L760 was not covered by tests

Expand Down Expand Up @@ -774,9 +786,14 @@ def fit(self, X_s: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble of for split 'p*'
for e in range(1, self.ensemble_size + 1):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(

Check warning on line 794 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L793-L794

Added lines #L793 - L794 were not covered by tests
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_s) * self.sample_size + 1)
if is_pytorch_model(self.base_model.model):
model = wrap_model(
Expand Down Expand Up @@ -896,9 +913,14 @@ def predict(self, X_t: Union[Dataset, DatasetDict, np.ndarray, TorchDataset]):
for seed in range(self.num_runs):
# train ensemble of for split 'p*'
for e in range(1, self.ensemble_size + 1):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
if is_pytorch_model(self.base_model.model):
self.base_model.load_model(
"saved_models/DetectronModule/pretrained_model.pt", log=False
)
elif is_sklearn_model(self.base_model.model):
self.base_model.load_model(

Check warning on line 921 in cyclops/monitor/tester.py

View check run for this annotation

Codecov / codecov/patch

cyclops/monitor/tester.py#L920-L921

Added lines #L920 - L921 were not covered by tests
"saved_models/DetectronModule/pretrained_model.pkl", log=False
)
alpha = 1 / (len(X_t) * self.sample_size + 1)
if is_pytorch_model(self.base_model.model):
model = wrap_model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@
" base_model=readmission_prediction_task.models[\"xgb_classifier\"],\n",
" feature_column=features_list,\n",
" transforms=preprocessor,\n",
" splits_mapping={\"train\": \"train\", \"test\": \"test\"},\n",
" splits_mapping={\"train\": \"train\", \"test\": \"validation\"},\n",
" sample_size=250,\n",
" num_runs=5,\n",
" ensemble_size=5,\n",
Expand All @@ -537,9 +537,9 @@
"outputs": [],
"source": [
"results = tester.predict(\n",
" X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": dataset[\"test\"]})\n",
" X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"validation\": dataset[\"test\"]})\n",
")\n",
"print(results[\"model_health\"])"
"print(results[\"data\"][\"model_health\"])"
]
},
{
Expand Down Expand Up @@ -575,9 +575,9 @@
"model_health = []\n",
"for data in test_data_list:\n",
" results = tester.predict(\n",
" X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"test\": data})\n",
" X_t=DatasetDict({\"train\": dataset[\"train\"][\"train\"], \"validation\": data})\n",
" )\n",
" model_health.append(results[\"model_health\"])"
" model_health.append(results[\"data\"][\"model_health\"])"
]
},
{
Expand Down

0 comments on commit f062494

Please sign in to comment.