diff --git a/festim/__init__.py b/festim/__init__.py index e2c4ee4c8..b22e350eb 100644 --- a/festim/__init__.py +++ b/festim/__init__.py @@ -84,7 +84,16 @@ from .exports.derived_quantities.minimum_volume import MinimumVolume from .exports.derived_quantities.minimum_surface import MinimumSurface from .exports.derived_quantities.maximum_surface import MaximumSurface -from .exports.derived_quantities.total_surface import TotalSurface +from .exports.derived_quantities.total_surface import ( + TotalSurface, + TotalSurfaceCylindrical, + TotalSurfaceSpherical, +) +from .exports.derived_quantities.total_volume import ( + TotalVolume, + TotalVolumeCylindrical, + TotalVolumeSpherical, +) from .exports.derived_quantities.total_volume import TotalVolume from .exports.derived_quantities.average_surface import ( AverageSurface, diff --git a/festim/exports/derived_quantities/total_surface.py b/festim/exports/derived_quantities/total_surface.py index d9870ef42..db8c1c784 100644 --- a/festim/exports/derived_quantities/total_surface.py +++ b/festim/exports/derived_quantities/total_surface.py @@ -1,5 +1,6 @@ from festim import SurfaceQuantity import fenics as f +import numpy as np class TotalSurface(SurfaceQuantity): @@ -58,3 +59,166 @@ def title(self): def compute(self): return f.assemble(self.function * self.ds(self.surface)) + + +class TotalSurfaceCylindrical(TotalSurface): + """ + Computes the total value of a field on a given surface + int(f ds) + ds is the surface measure in cylindrical coordinates. + ds = r dr dtheta + + Args: + field (str, int): the field ("solute", 0, 1, "T", "retention") + surface (int): the surface id + azimuth_range (tuple, optional): Range of the azimuthal angle + (theta) needs to be between 0 and 2 pi. Defaults to (0, 2 * np.pi) + + Attributes: + field (str, int): the field ("solute", 0, 1, "T", "retention") + surface (int): the surface id + title (str): the title of the derived quantity + show_units (bool): show the units in the title in the derived quantities + file + function (dolfin.function.function.Function): the solution function of + the field + r (ufl.indexed.Indexed): the radius of the cylinder + + .. note:: + Units are in H/m in 1D, H in 2D domains for hydrogen concentration + and K m in 1D, K m2 in 2D domains for temperature + """ + + def __init__(self, field, surface, azimuth_range=(0, 2 * np.pi)) -> None: + super().__init__(field=field, surface=surface) + self.r = None + self.azimuth_range = azimuth_range + + @property + def export_unit(self): + # obtain domain dimension + try: + dim = self.function.function_space().mesh().topology().dim() + except AttributeError: + dim = self.dx._domain._topological_dimension + # TODO we could simply do that all the time + # return unit depending on field and dimension of domain + if self.field == "T": + return f"K m{dim}".replace(" m1", " m") + else: + return f"H m{dim-2}".replace(" m0", "") + + @property + def azimuth_range(self): + return self._azimuth_range + + @azimuth_range.setter + def azimuth_range(self, value): + if value[0] < 0 or value[1] > 2 * np.pi: + raise ValueError("Azimuthal range must be between 0 and pi") + self._azimuth_range = value + + @property + def allowed_meshes(self): + return ["cylindrical"] + + def compute(self): + + if self.r is None: + mesh = ( + self.function.function_space().mesh() + ) # get the mesh from the function + rthetaz = f.SpatialCoordinate(mesh) # get the coordinates from the mesh + self.r = rthetaz[0] # only care about r here + + tot_surf = f.assemble(self.function * self.r * self.ds(self.surface)) + tot_surf *= self.azimuth_range[1] - self.azimuth_range[0] + + return tot_surf + + +class TotalSurfaceSpherical(TotalSurface): + """ + Computes the total value of a field on a given surface + int(f ds) + ds is the surface measure in spherical coordinates. + ds = r**2 sin(theta) dtheta dphi + + Args: + field (str, int): the field ("solute", 0, 1, "T", "retention") + surface (int): the surface id + azimuth_range (tuple, optional): Range of the azimuthal angle + (phi) needs to be between 0 and 2 pi. Defaults to (0, 2 * np.pi) + polar_range (tuple, optional): Range of the polar angle + (theta) needs to be between 0 and pi. Defaults to (0, np.pi). + + Attributes: + field (str, int): the field ("solute", 0, 1, "T", "retention") + surface (int): the surface id + title (str): the title of the derived quantity + show_units (bool): show the units in the title in the derived quantities + file + function (dolfin.function.function.Function): the solution function of + the field + r (ufl.indexed.Indexed): the radius of the cylinder + + .. note:: + Units are in H for hydrogen concentration + and K in 1D, K m in 2D domains for temperature + """ + + def __init__( + self, field, surface, azimuth_range=(0, 2 * np.pi), polar_range=(0, np.pi) + ) -> None: + super().__init__(field=field, surface=surface) + self.r = None + self.azimuth_range = azimuth_range + self.polar_range = polar_range + + @property + def export_unit(self): + if self.field == "T": + return f"K m2" + else: + return "H" + + @property + def azimuth_range(self): + return self._azimuth_range + + @azimuth_range.setter + def azimuth_range(self, value): + if value[0] < 0 or value[1] > 2 * np.pi: + raise ValueError("Azimuthal range must be between 0 and 2 pi") + self._azimuth_range = value + + @property + def polar_range(self): + return self._polar_range + + @polar_range.setter + def polar_range(self, value): + if value[0] < 0 or value[1] > np.pi: + raise ValueError("Polar range must be between 0 and pi") + self._polar_range = value + + @property + def allowed_meshes(self): + return ["spherical"] + + def compute(self): + + if self.r is None: + mesh = ( + self.function.function_space().mesh() + ) # get the mesh from the function + rthetaphi = f.SpatialCoordinate(mesh) # get the coordinates from the mesh + self.r = rthetaphi[0] # only care about r here + + tot_surf = f.assemble(self.function * self.r**2 * self.ds(self.surface)) + + tot_surf *= (self.azimuth_range[1] - self.azimuth_range[0]) * ( + np.cos(self.polar_range[0]) - np.cos(self.polar_range[1]) + ) + + return tot_surf diff --git a/festim/exports/derived_quantities/total_volume.py b/festim/exports/derived_quantities/total_volume.py index fd770dde1..b197ca32d 100644 --- a/festim/exports/derived_quantities/total_volume.py +++ b/festim/exports/derived_quantities/total_volume.py @@ -1,5 +1,6 @@ from festim import VolumeQuantity import fenics as f +import numpy as np class TotalVolume(VolumeQuantity): @@ -58,3 +59,163 @@ def title(self): def compute(self): return f.assemble(self.function * self.dx(self.volume)) + + +class TotalVolumeCylindrical(TotalVolume): + """Computes the total value of a field for a given volume + int(f dx) + dx is the volume measure in cylindrical coordinates. + dx = r dr dtheta dz + + Args: + field (str, int): the field ("solute", 0, 1, "T", "retention") + volume (int): the volume id + azimuth_range (tuple, optional): Range of the azimuthal angle + (theta) needs to be between 0 and 2 pi. Defaults to (0, 2 * np.pi) + + Attributes: + field (str, int): the field ("solute", 0, 1, "T", "retention") + volume (int): the volume id + title (str): the title of the derived quantity + show_units (bool): show the units in the title in the derived quantities + file + function (dolfin.function.function.Function): the solution function of + the field + r (ufl.indexed.Indexed): the radius of the cylinder + + .. note:: + Units are in H/m in 1D and H in 2D for hydrogen concentration + and K m2 in 1D, K m3 in 2D domains for temperature + """ + + def __init__(self, field, volume, azimuth_range=(0, 2 * np.pi)) -> None: + super().__init__(field=field, volume=volume) + self.r = None + self.azimuth_range = azimuth_range + + @property + def export_unit(self): + # obtain domain dimension + try: + dim = self.function.function_space().mesh().topology().dim() + except AttributeError: + dim = self.dx._domain._topological_dimension + # TODO we could simply do that all the time + # return unit depending on field and dimension of domain + if self.field == "T": + return f"K m{dim+1}" + else: + return f"H m{dim-2}".replace(" m0", "") + + @property + def azimuth_range(self): + return self._azimuth_range + + @azimuth_range.setter + def azimuth_range(self, value): + if value[0] < 0 or value[1] > 2 * np.pi: + raise ValueError("Azimuthal range must be between 0 and 2 pi") + self._azimuth_range = value + + @property + def allowed_meshes(self): + return ["cylindrical"] + + def compute(self): + + if self.r is None: + mesh = ( + self.function.function_space().mesh() + ) # get the mesh from the function + rthetaz = f.SpatialCoordinate(mesh) # get the coordinates from the mesh + self.r = rthetaz[0] # only care about r here + + tot_vol = f.assemble(self.function * self.r * self.dx(self.volume)) + tot_vol *= self.azimuth_range[1] - self.azimuth_range[0] + + return tot_vol + + +class TotalVolumeSpherical(TotalVolume): + """Computes the total value of a field for a given volume + int(f dx) + dx is the volume measure in cylindrical coordinates. + dx = r**2 sin(theta) dtheta dphi dr + + Args: + field (str, int): the field ("solute", 0, 1, "T", "retention") + volume (int): the volume id + azimuth_range (tuple, optional): Range of the azimuthal angle + (phi) needs to be between 0 and 2 pi. Defaults to (0, 2 * np.pi) + polar_range (tuple, optional): Range of the polar angle + (theta) needs to be between 0 and pi. Defaults to (0, np.pi). + + Attributes: + field (str, int): the field ("solute", 0, 1, "T", "retention") + volume (int): the volume id + title (str): the title of the derived quantity + show_units (bool): show the units in the title in the derived quantities + file + function (dolfin.function.function.Function): the solution function of + the field + r (ufl.indexed.Indexed): the radius of the cylinder + + .. note:: + Units are in H for hydrogen concentration and K m2 for temperature + """ + + def __init__( + self, field, volume, azimuth_range=(0, 2 * np.pi), polar_range=(0, np.pi) + ) -> None: + super().__init__(field=field, volume=volume) + self.r = None + self.azimuth_range = azimuth_range + self.polar_range = polar_range + + @property + def export_unit(self): + if self.field == "T": + return f"K m3" + else: + return f"H" + + @property + def azimuth_range(self): + return self._azimuth_range + + @azimuth_range.setter + def azimuth_range(self, value): + if value[0] < 0 or value[1] > 2 * np.pi: + raise ValueError("Azimuthal range must be between 0 and pi") + self._azimuth_range = value + + @property + def polar_range(self): + return self._polar_range + + @polar_range.setter + def polar_range(self, value): + if value[0] < 0 or value[1] > np.pi: + raise ValueError("Polar range must be between 0 and pi") + self._polar_range = value + + @property + def allowed_meshes(self): + return ["spherical"] + + def compute(self): + + if self.r is None: + mesh = ( + self.function.function_space().mesh() + ) # get the mesh from the function + rthetaphi = f.SpatialCoordinate(mesh) # get the coordinates from the mesh + self.r = rthetaphi[0] # only care about r here + + tot_vol = f.assemble(self.function * self.r**2 * self.dx(self.volume)) + + tot_vol *= (self.azimuth_range[1] - self.azimuth_range[0]) * ( + np.cos(self.polar_range[0]) - np.cos(self.polar_range[1]) + ) + + return tot_vol diff --git a/test/unit/test_exports/test_derived_quantities/test_total_surface.py b/test/unit/test_exports/test_derived_quantities/test_total_surface.py index e11a2f131..de4f0319d 100644 --- a/test/unit/test_exports/test_derived_quantities/test_total_surface.py +++ b/test/unit/test_exports/test_derived_quantities/test_total_surface.py @@ -1,8 +1,10 @@ -from festim import TotalSurface +from festim import x, y, TotalSurface, TotalSurfaceCylindrical, TotalSurfaceSpherical import fenics as f import pytest -from .tools import c_1D, c_2D, c_3D +from .tools import c_1D, c_2D, c_3D, mesh_1D, mesh_2D import pytest +from sympy.printing import ccode +import numpy as np @pytest.mark.parametrize("field,surface", [("solute", 1), ("T", 2)]) @@ -64,3 +66,201 @@ def test_title_with_units(function, field, expected_title): my_export.show_units = True assert my_export.title == expected_title + + +@pytest.mark.parametrize("radius", [2, 3]) +@pytest.mark.parametrize("r0", [0, 2]) +@pytest.mark.parametrize("height", [2, 3]) +def test_compute_cylindrical(r0, radius, height): + """ + Test that TotalSurfaceCylindrical computes the total value of a function + correctly on a hollow cylinder + + Args: + r0 (float): internal radius + radius (float): cylinder radius + height (float): cylinder height + """ + # creating a mesh with FEniCS + r1 = r0 + radius + z0, z1 = 0, height + + mesh_fenics = f.RectangleMesh(f.Point(r0, z0), f.Point(r1, z1), 10, 10) + + outer_surface = f.CompiledSubDomain( + f"on_boundary && near(x[0], {r1}, tol)", tol=1e-14 + ) + + surface_markers = f.MeshFunction( + "size_t", mesh_fenics, mesh_fenics.topology().dim() - 1 + ) + surface_markers.set_all(0) + ds = f.Measure("ds", domain=mesh_fenics, subdomain_data=surface_markers) + outer_id = 1 + outer_surface.mark(surface_markers, outer_id) + + my_exp = TotalSurfaceCylindrical("solute", outer_id) + V = f.FunctionSpace(mesh_fenics, "P", 1) + c_fun = lambda r, z: r**2 + z + expr = f.Expression( + ccode(c_fun(x, y)), + degree=1, + ) + my_exp.function = f.interpolate(expr, V) + my_exp.ds = ds + + expected_value = 2 * np.pi * r1**3 * z1 + np.pi * r1 * z1**2 + + computed_value = my_exp.compute() + + assert np.isclose(computed_value, expected_value) + + +@pytest.mark.parametrize("r0", [0, 1.5]) +@pytest.mark.parametrize("radius", [3, 4]) +def test_compute_spherical(r0, radius): + """ + Test that TotalSurfaceSpherical computes the total value of a function + correctly on a hollow sphere + + Args: + r0 (float): internal radius + radius (float): sphere radius + """ + # creating a mesh with FEniCS + r1 = r0 + radius + mesh_fenics = f.IntervalMesh(10, r0, r1) + + # marking physical groups (volumes and surfaces) + outer_surface = f.CompiledSubDomain( + f"on_boundary && near(x[0], {r1}, tol)", tol=1e-14 + ) + surface_markers = f.MeshFunction( + "size_t", mesh_fenics, mesh_fenics.topology().dim() - 1 + ) + surface_markers.set_all(0) + outer_id = 1 + outer_surface.mark(surface_markers, outer_id) + ds = f.Measure("ds", domain=mesh_fenics, subdomain_data=surface_markers) + + my_tot = TotalSurfaceSpherical("solute", outer_id) + V = f.FunctionSpace(mesh_fenics, "P", 1) + c_fun = lambda r: r**2 + r + expr = f.Expression( + ccode(c_fun(x)), + degree=1, + ) + my_tot.function = f.interpolate(expr, V) + my_tot.ds = ds + + expected_value = 4 * np.pi * r1**3 * (1 + r1) + + computed_value = my_tot.compute() + + assert np.isclose(computed_value, expected_value) + + +@pytest.mark.parametrize( + "azimuth_range", [(-1, np.pi), (0, 3 * np.pi), (-1, 3 * np.pi)] +) +def test_azimuthal_range_cylindrical(azimuth_range): + """ + Tests that an error is raised when the azimuthal range is out of bounds + """ + with pytest.raises(ValueError): + TotalSurfaceCylindrical("solute", 1, azimuth_range=azimuth_range) + + +@pytest.mark.parametrize( + "azimuth_range", [(-1, np.pi), (0, 3 * np.pi), (-1, 3 * np.pi)] +) +def test_azimuthal_range_spherical(azimuth_range): + """ + Tests that an error is raised when the azimuthal range is out of bounds + """ + with pytest.raises(ValueError): + TotalSurfaceSpherical("solute", 1, azimuth_range=azimuth_range) + + +@pytest.mark.parametrize( + "polar_range", [(0, 2 * np.pi), (-np.pi, 0), (-2 * np.pi, 3 * np.pi)] +) +def test_polar_range_spherical(polar_range): + """ + Tests that an error is raised when the polar range is out of bounds + """ + with pytest.raises(ValueError): + TotalSurfaceSpherical("solute", 1, polar_range=polar_range) + + +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "Total solute surface 3 (H m-1)"), + (c_1D, "T", "Total T surface 3 (K m)"), + (c_2D, "solute", "Total solute surface 3 (H)"), + (c_2D, "T", "Total T surface 3 (K m2)"), + ], +) +def test_TotalSurfaceCylindrical_title_with_units(function, field, expected_title): + my_exp = TotalSurfaceCylindrical(field=field, surface=3) + my_exp.function = function + my_exp.show_units = True + + assert my_exp.title == expected_title + + +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "Total solute surface 4 (H)"), + (c_1D, "T", "Total T surface 4 (K m2)"), + ], +) +def test_TotalSurfaceSpherical_title_with_units(function, field, expected_title): + my_exp = TotalSurfaceSpherical(field=field, surface=4) + my_exp.function = function + my_exp.show_units = True + + assert my_exp.title == expected_title + + +def test_tot_surf_cylindrical_allow_meshes(): + """A simple test to check cylindrical meshes are the only + meshes allowed when using TotalSurfaceCylindrical""" + + my_export = TotalSurfaceCylindrical("solute", 2) + + assert my_export.allowed_meshes == ["cylindrical"] + + +def test_tot_surf_spherical_allow_meshes(): + """A simple test to check spherical meshes are the only + meshes allowed when using TotalSurfaceSpherical""" + + my_export = TotalSurfaceSpherical("solute", 1) + + assert my_export.allowed_meshes == ["spherical"] + + +@pytest.mark.parametrize( + "mesh, field, expected_title", + [ + (mesh_1D, "solute", "Total solute surface 1 (H m-1)"), + (mesh_1D, "T", "Total T surface 1 (K m)"), + (mesh_2D, "solute", "Total solute surface 1 (H)"), + (mesh_2D, "T", "Total T surface 1 (K m2)"), + ], +) +def test_tot_surf_cyl_get_dimension_from_mesh(mesh, field, expected_title): + """A test to ensure the dimension required for the units can be taken + from a mesh and produces the expected title""" + + my_export = TotalSurfaceCylindrical(field, 1) + + vm = f.MeshFunction("size_t", mesh, mesh.topology().dim()) + dx = f.Measure("dx", domain=mesh, subdomain_data=vm) + + my_export.dx = dx + + assert my_export.title == expected_title diff --git a/test/unit/test_exports/test_derived_quantities/test_total_volume.py b/test/unit/test_exports/test_derived_quantities/test_total_volume.py index 1f6929977..6a654db7e 100644 --- a/test/unit/test_exports/test_derived_quantities/test_total_volume.py +++ b/test/unit/test_exports/test_derived_quantities/test_total_volume.py @@ -1,8 +1,9 @@ -from festim import TotalVolume +from festim import x, y, TotalVolume, TotalVolumeCylindrical, TotalVolumeSpherical import fenics as f import pytest -from .tools import c_1D, c_2D, c_3D -import pytest +from .tools import c_1D, c_2D, c_3D, mesh_1D, mesh_2D +from sympy.printing import ccode +import numpy as np @pytest.mark.parametrize("field,volume", [("solute", 1), ("T", 2)]) @@ -64,3 +65,190 @@ def test_title_with_units(function, field, expected_title): my_export.show_units = True assert my_export.title == expected_title + + +@pytest.mark.parametrize("radius", [2, 4]) +@pytest.mark.parametrize("r0", [0, 1.5]) +@pytest.mark.parametrize("height", [2, 3]) +def test_compute_cylindrical(r0, radius, height): + """ + Test that TotalVolumeCylindrical computes the total value of a function + correctly on a hollow cylinder + + Args: + r0 (float): internal radius + radius (float): cylinder radius + height (float): cylinder height + """ + # creating a mesh with FEniCS + r1 = r0 + radius + z0, z1 = 0, height + + mesh_fenics = f.RectangleMesh(f.Point(r0, z0), f.Point(r1, z1), 10, 10) + + volume_id = 3 + volume_markers = f.MeshFunction("size_t", mesh_fenics, mesh_fenics.topology().dim()) + volume_markers.set_all(volume_id) + dx = f.Measure("dx", domain=mesh_fenics, subdomain_data=volume_markers) + + my_exp = TotalVolumeCylindrical("solute", volume_id) + V = f.FunctionSpace(mesh_fenics, "P", 1) + c_fun = lambda r, z: r + z + expr = f.Expression( + ccode(c_fun(x, y)), + degree=1, + ) + my_exp.function = f.interpolate(expr, V) + my_exp.dx = dx + + expected_value = ((np.pi * z1) / 6) * ( + (-4 * r0**3) - (3 * r0**2 * z1) + (r1**2 * (4 * r1 + 3 * z1)) + ) + + computed_value = my_exp.compute() + + assert np.isclose(computed_value, expected_value) + + +@pytest.mark.parametrize("radius", [1.5, 2.5]) +@pytest.mark.parametrize("r0", [0, 1]) +def test_compute_spherical(r0, radius): + """ + Test that TotalVolumeSpherical computes the total value of a function + correctly on a hollow sphere + + Args: + r0 (float): internal radius + radius (float): sphere radius + """ + # creating a mesh with FEniCS + r1 = r0 + radius + + mesh_fenics = f.IntervalMesh(100, r0, r1) + + volume_id = 2 + volume_markers = f.MeshFunction("size_t", mesh_fenics, mesh_fenics.topology().dim()) + volume_markers.set_all(volume_id) + dx = f.Measure("dx", domain=mesh_fenics, subdomain_data=volume_markers) + + my_exp = TotalVolumeSpherical("solute", volume_id) + V = f.FunctionSpace(mesh_fenics, "P", 1) + c_fun = lambda r: r**2 + expr = f.Expression( + ccode(c_fun(x)), + degree=1, + ) + my_exp.function = f.interpolate(expr, V) + my_exp.dx = dx + + expected_value = (4 * np.pi / 5) * (r1**5 - r0**5) + + computed_value = my_exp.compute() + + assert np.isclose(computed_value, expected_value, rtol=1e-04) + + +@pytest.mark.parametrize( + "azimuth_range", [(-1, np.pi), (0, 3 * np.pi), (-1, 3 * np.pi)] +) +def test_azimuthal_range_cylindrical(azimuth_range): + """ + Tests that an error is raised when the azimuthal range is out of bounds + """ + with pytest.raises(ValueError): + TotalVolumeCylindrical("solute", 1, azimuth_range=azimuth_range) + + +@pytest.mark.parametrize( + "azimuth_range", [(-1, np.pi), (0, 3 * np.pi), (-1, 3 * np.pi)] +) +def test_azimuthal_range_spherical(azimuth_range): + """ + Tests that an error is raised when the azimuthal range is out of bounds + """ + with pytest.raises(ValueError): + TotalVolumeSpherical("solute", 1, azimuth_range=azimuth_range) + + +@pytest.mark.parametrize( + "polar_range", [(0, 2 * np.pi), (-np.pi, 0), (-2 * np.pi, 3 * np.pi)] +) +def test_polar_range_spherical(polar_range): + """ + Tests that an error is raised when the polar range is out of bounds + """ + with pytest.raises(ValueError): + TotalVolumeSpherical("solute", 1, polar_range=polar_range) + + +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "Total solute volume 3 (H m-1)"), + (c_1D, "T", "Total T volume 3 (K m2)"), + (c_2D, "solute", "Total solute volume 3 (H)"), + (c_2D, "T", "Total T volume 3 (K m3)"), + ], +) +def test_TotalVolumeCylindrical_title_with_units(function, field, expected_title): + my_exp = TotalVolumeCylindrical(field=field, volume=3) + my_exp.function = function + my_exp.show_units = True + + assert my_exp.title == expected_title + + +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "Total solute volume 4 (H)"), + (c_1D, "T", "Total T volume 4 (K m3)"), + ], +) +def test_TotalVolumeSpherical_title_with_units(function, field, expected_title): + my_exp = TotalVolumeSpherical(field=field, volume=4) + my_exp.function = function + my_exp.show_units = True + + assert my_exp.title == expected_title + + +def test_tot_vol_cylindrical_allow_meshes(): + """A simple test to check cylindrical meshes are the only + meshes allowed when using TotalVolumeCylindrical""" + + my_export = TotalVolumeCylindrical("solute", 2) + + assert my_export.allowed_meshes == ["cylindrical"] + + +def test_tot_vol_spherical_allow_meshes(): + """A simple test to check spherical meshes are the only + meshes allowed when using TotalVolumeSpherical""" + + my_export = TotalVolumeSpherical("solute", 1) + + assert my_export.allowed_meshes == ["spherical"] + + +@pytest.mark.parametrize( + "mesh, field, expected_title", + [ + (mesh_1D, "solute", "Total solute volume 1 (H m-1)"), + (mesh_1D, "T", "Total T volume 1 (K m2)"), + (mesh_2D, "solute", "Total solute volume 1 (H)"), + (mesh_2D, "T", "Total T volume 1 (K m3)"), + ], +) +def test_tot_vol_cyl_get_dimension_from_mesh(mesh, field, expected_title): + """A test to ensure the dimension required for the units can be taken + from a mesh and produces the expected title""" + + my_export = TotalVolumeCylindrical(field, 1) + + vm = f.MeshFunction("size_t", mesh, mesh.topology().dim()) + dx = f.Measure("dx", domain=mesh, subdomain_data=vm) + + my_export.dx = dx + + assert my_export.title == expected_title