diff --git a/elastica/surface/__init__.py b/elastica/surface/__init__.py index ce755872e..efe78d06c 100644 --- a/elastica/surface/__init__.py +++ b/elastica/surface/__init__.py @@ -1,2 +1,3 @@ __doc__ = """Surface classes""" from elastica.surface.surface_base import SurfaceBase +from elastica.surface.plane import Plane diff --git a/elastica/surface/plane.py b/elastica/surface/plane.py new file mode 100644 index 000000000..ce2d96f3c --- /dev/null +++ b/elastica/surface/plane.py @@ -0,0 +1,31 @@ +__doc__ = """plane surface class""" + +from elastica.surface.surface_base import SurfaceBase +import numpy as np +from numpy.testing import assert_allclose +from elastica.utils import Tolerance + + +class Plane(SurfaceBase): + def __init__(self, plane_origin: np.ndarray, plane_normal: np.ndarray): + """ + Plane surface initializer. + + Parameters + ---------- + plane_origin: np.ndarray + Origin of the plane. + Expect (3,1)-shaped array. + plane_normal: np.ndarray + The normal vector of the plane, must be normalized. + Expect (3,1)-shaped array. + """ + + assert_allclose( + np.linalg.norm(plane_normal), + 1, + atol=Tolerance.atol(), + err_msg="plane normal is not a unit vector", + ) + self.normal = np.asarray(plane_normal).reshape(3) + self.origin = np.asarray(plane_origin).reshape(3, 1) diff --git a/tests/test_surface/test_plane.py b/tests/test_surface/test_plane.py new file mode 100644 index 000000000..f82d6a0ce --- /dev/null +++ b/tests/test_surface/test_plane.py @@ -0,0 +1,41 @@ +__doc__ = """Tests for plane surface class""" +import numpy as np +from numpy.testing import assert_allclose +from elastica.utils import Tolerance +from elastica.surface import Plane +import pytest + + +# tests Initialisation of plane +def test_plane_initialization(): + """ + This test case is for testing initialization of rigid sphere and it checks the + validity of the members of sphere class. + + Returns + ------- + + """ + # setting up test params + plane_origin = np.random.rand(3).reshape(3, 1) + plane_normal_direction = np.random.rand(3).reshape(3) + plane_normal = plane_normal_direction / np.linalg.norm(plane_normal_direction) + + test_plane = Plane(plane_origin, plane_normal) + # checking plane origin + assert_allclose( + test_plane.origin, + plane_origin, + atol=Tolerance.atol(), + ) + + # checking plane normal + assert_allclose(test_plane.normal, plane_normal, atol=Tolerance.atol()) + + # check unnormalized error message + invalid_plane_normal = np.zeros( + 3, + ) + with pytest.raises(AssertionError) as excinfo: + test_plane = Plane(plane_origin, invalid_plane_normal) + assert "plane normal is not a unit vector" in str(excinfo.value)