Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 26 user defined params #44

Merged
merged 7 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Unreleased

- Allow user-defined parameters to be added using the field ["Parameterisation"]["User-defined"] ([#44](https://github.com/pybamm-team/BPX/pull/44))
- Added validation based on models: SPM, SPMe, DFN ([#34](https://github.com/pybamm-team/BPX/pull/34)). A warning will be produced if the user-defined model type does not match the parameter set (e.g., if the model is `SPM`, but the full DFN model parameters are provided).
- Added support for well-mixed, blended electrodes that contain more than one active material ([#33](https://github.com/pybamm-team/BPX/pull/33))

Expand Down
37 changes: 33 additions & 4 deletions bpx/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from typing import List, Literal, Union, Dict

from typing import List, Literal, Union, Dict, get_args
from pydantic import BaseModel, Field, Extra, root_validator

from bpx import Function, InterpolatedTable

from warnings import warn

FloatFunctionTable = Union[float, Function, InterpolatedTable]
Expand Down Expand Up @@ -272,6 +269,30 @@ class ElectrodeBlendedSPM(ContactBase):
particle: Dict[str, Particle] = Field(alias="Particle")


class UserDefined(BaseModel):
class Config:
extra = Extra.allow

def __init__(self, **data):
"""
Overwrite the default __init__ to convert strings to Function objects and
dicts to InterpolatedTable objects
"""
for k, v in data.items():
if isinstance(v, str):
data[k] = Function(v)
elif isinstance(v, dict):
data[k] = InterpolatedTable(**v)
super().__init__(**data)

@root_validator(pre=True)
def validate_extra_fields(cls, values):
for k, v in values.items():
if not isinstance(v, get_args(FloatFunctionTable)):
raise TypeError(f"{k} must be of type 'FloatFunctionTable'")
return values


class Experiment(ExtraBaseModel):
time: List[float] = Field(
alias="Time [s]",
Expand Down Expand Up @@ -312,6 +333,10 @@ class Parameterisation(ExtraBaseModel):
separator: Contact = Field(
alias="Separator",
)
user_defined: UserDefined = Field(
None,
alias="User-defined",
)


class ParameterisationSPM(ExtraBaseModel):
Expand All @@ -324,6 +349,10 @@ class ParameterisationSPM(ExtraBaseModel):
positive_electrode: Union[ElectrodeSingleSPM, ElectrodeBlendedSPM] = Field(
alias="Positive electrode",
)
user_defined: UserDefined = Field(
None,
alias="User-defined",
)


class BPX(ExtraBaseModel):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,42 @@ def test_validation_data(self):
},
}

def test_user_defined(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": 1.0,
"b": 2.0,
"c": 3.0,
}
obj = parse_obj_as(BPX, test)
self.assertEqual(obj.parameterisation.user_defined.a, 1)
self.assertEqual(obj.parameterisation.user_defined.b, 2)
self.assertEqual(obj.parameterisation.user_defined.c, 3)

def test_user_defined_table(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": {
"x": [1.0, 2.0],
"y": [2.3, 4.5],
},
}
parse_obj_as(BPX, test)

def test_user_defined_function(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {"a": "2.0 * x"}
parse_obj_as(BPX, test)

def test_bad_user_defined(self):
test = copy.copy(self.base)
# bool not allowed type
test["Parameterisation"]["User-defined"] = {
"bad": True,
}
with self.assertRaises(ValidationError):
parse_obj_as(BPX, test)


if __name__ == "__main__":
unittest.main()
Loading