Skip to content

Commit

Permalink
Update optimize api (#103)
Browse files Browse the repository at this point in the history
* updating converter to correctly type.

* remove initial_guess_interventions

* update test.

* adding start_time_param_value

* updating version of pyciemss
  • Loading branch information
Tom-Szendrey authored Jul 17, 2024
1 parent 08bf80d commit c1b18fa
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ httpx = "^0.24.1"


[tool.poe.tasks]
install-pyciemss = "pip install --no-cache-dir git+https://github.com/ciemss/pyciemss.git@7967dfaede3136dfec4b8e08dd11b272a6190677 --use-pep517"
install-pyciemss = "pip install --no-cache-dir git+https://github.com/ciemss/pyciemss.git@672201f44752e40ed3a92a335ada136bc30f776e --use-pep517"


[tool.pytest.ini_options]
Expand Down
10 changes: 7 additions & 3 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
start_time_param_value_objective,
)

from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi
Expand Down Expand Up @@ -90,7 +91,6 @@ class Optimize(OperationRequest):
step_size: float = 1.0
qoi: QOI
risk_bound: float
initial_guess_interventions: List[float]
bounds_interventions: List[List[float]]
extra: OptimizeExtra = Field(
None,
Expand All @@ -117,7 +117,7 @@ def gen_pyciemss_args(self, job_id):
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
else:
if intervention_type == "start_time":
assert self.optimize_interventions.param_values is not None
param_value = [
torch.tensor(value)
Expand All @@ -127,6 +127,10 @@ def gen_pyciemss_args(self, job_id):
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
if intervention_type == "start_time_param_value":
optimize_interventions = start_time_param_value_objective(
param_name=self.optimize_interventions.param_names
)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
Expand All @@ -146,7 +150,7 @@ def gen_pyciemss_args(self, job_id):
),
"qoi": self.qoi.gen_call(),
"risk_bound": self.risk_bound,
"initial_guess_interventions": self.initial_guess_interventions,
"initial_guess_interventions": self.optimize_interventions.initial_guess,
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": optimize_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
Expand Down
3 changes: 1 addition & 2 deletions tests/examples/optimize/input/request.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
"method": "day_average"
},
"risk_bound": 10.0,
"initial_guess_interventions": [1.0],
"bounds_interventions": [[0.0], [3.0]],
"bounds_interventions": [[0.0], [3.0]],
"extra": {
"num_samples": 4,
"is_minimized": true
Expand Down

0 comments on commit c1b18fa

Please sign in to comment.