Skip to content

Commit

Permalink
Merge pull request #44 from FaradayInstitution/issue-26-user-defined-…
Browse files Browse the repository at this point in the history
…params

Issue 26 user defined params
  • Loading branch information
rtimms committed Oct 19, 2023
2 parents cd00a67 + b797fea commit 79ecbc2
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Unreleased

- Allow user-defined parameters to be added using the field ["Parameterisation"]["User-defined"] ([#44](https://github.com/pybamm-team/BPX/pull/44))
- Added validation based on models: SPM, SPMe, DFN ([#34](https://github.com/pybamm-team/BPX/pull/34)). A warning will be produced if the user-defined model type does not match the parameter set (e.g., if the model is `SPM`, but the full DFN model parameters are provided).
- Added support for well-mixed, blended electrodes that contain more than one active material ([#33](https://github.com/pybamm-team/BPX/pull/33))
- Added five parametrisation examples (two DFN parametrisation examples from About:Energy open-source release, blended electrode definition, user-defined 0th-order hysteresis, and SPM parametrisation).
Expand Down
37 changes: 33 additions & 4 deletions bpx/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from typing import List, Literal, Union, Dict

from typing import List, Literal, Union, Dict, get_args
from pydantic import BaseModel, Field, Extra, root_validator

from bpx import Function, InterpolatedTable

from warnings import warn

FloatFunctionTable = Union[float, Function, InterpolatedTable]
Expand Down Expand Up @@ -272,6 +269,30 @@ class ElectrodeBlendedSPM(ContactBase):
particle: Dict[str, Particle] = Field(alias="Particle")


class UserDefined(BaseModel):
class Config:
extra = Extra.allow

def __init__(self, **data):
"""
Overwrite the default __init__ to convert strings to Function objects and
dicts to InterpolatedTable objects
"""
for k, v in data.items():
if isinstance(v, str):
data[k] = Function(v)
elif isinstance(v, dict):
data[k] = InterpolatedTable(**v)
super().__init__(**data)

@root_validator(pre=True)
def validate_extra_fields(cls, values):
for k, v in values.items():
if not isinstance(v, get_args(FloatFunctionTable)):
raise TypeError(f"{k} must be of type 'FloatFunctionTable'")
return values


class Experiment(ExtraBaseModel):
time: List[float] = Field(
alias="Time [s]",
Expand Down Expand Up @@ -312,6 +333,10 @@ class Parameterisation(ExtraBaseModel):
separator: Contact = Field(
alias="Separator",
)
user_defined: UserDefined = Field(
None,
alias="User-defined",
)


class ParameterisationSPM(ExtraBaseModel):
Expand All @@ -324,6 +349,10 @@ class ParameterisationSPM(ExtraBaseModel):
positive_electrode: Union[ElectrodeSingleSPM, ElectrodeBlendedSPM] = Field(
alias="Positive electrode",
)
user_defined: UserDefined = Field(
None,
alias="User-defined",
)


class BPX(ExtraBaseModel):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,42 @@ def test_validation_data(self):
},
}

def test_user_defined(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": 1.0,
"b": 2.0,
"c": 3.0,
}
obj = parse_obj_as(BPX, test)
self.assertEqual(obj.parameterisation.user_defined.a, 1)
self.assertEqual(obj.parameterisation.user_defined.b, 2)
self.assertEqual(obj.parameterisation.user_defined.c, 3)

def test_user_defined_table(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": {
"x": [1.0, 2.0],
"y": [2.3, 4.5],
},
}
parse_obj_as(BPX, test)

def test_user_defined_function(self):
test = copy.copy(self.base)
test["Parameterisation"]["User-defined"] = {"a": "2.0 * x"}
parse_obj_as(BPX, test)

def test_bad_user_defined(self):
test = copy.copy(self.base)
# bool not allowed type
test["Parameterisation"]["User-defined"] = {
"bad": True,
}
with self.assertRaises(ValidationError):
parse_obj_as(BPX, test)


if __name__ == "__main__":
unittest.main()

0 comments on commit 79ecbc2

Please sign in to comment.