Skip to content

Commit

Permalink
Merge pull request #393 from HiroIshida/grouping_sdf_implemented
Browse files Browse the repository at this point in the history
refactor: sdf-converatable class inherits from SDFCapable
  • Loading branch information
iory authored Sep 3, 2024
2 parents f5752d5 + ee4d1d8 commit 06cef95
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions skrobot/model/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
from skrobot.sdf import trimesh2sdf


class SDFImplemented:
@property
def sdf(self):
if self._sdf is None:
msg = "This link does not have SDF. Please set with_sdf=True"
raise AttributeError(msg)
return self._sdf


class Axis(Link):

def __init__(self,
Expand Down Expand Up @@ -46,7 +55,7 @@ def from_cascoords(cls, cascoords, **kwargs):
return link


class Box(Link):
class Box(Link, SDFImplemented):

def __init__(self, extents, vertex_colors=None, face_colors=None,
pos=(0, 0, 0), rot=np.eye(3), name=None, with_sdf=False):
Expand All @@ -66,7 +75,9 @@ def __init__(self, extents, vertex_colors=None, face_colors=None,
if with_sdf:
sdf = BoxSDF(extents)
self.assoc(sdf, relative_coords="local")
self.sdf = sdf
self._sdf = sdf
else:
self._sdf = None


class CameraMarker(Link):
Expand Down Expand Up @@ -118,7 +129,7 @@ def __init__(self, radius, height,
visual_mesh=mesh)


class Cylinder(Link):
class Cylinder(Link, SDFImplemented):

def __init__(self, radius, height,
sections=32,
Expand All @@ -142,10 +153,12 @@ def __init__(self, radius, height,
if with_sdf:
sdf = CylinderSDF(height, radius)
self.assoc(sdf, relative_coords="local")
self.sdf = sdf
self._sdf = sdf
else:
self._sdf = None


class Sphere(Link):
class Sphere(Link, SDFImplemented):

def __init__(self, radius, subdivisions=3, color=None,
pos=(0, 0, 0), rot=np.eye(3), name=None, with_sdf=False):
Expand All @@ -165,7 +178,9 @@ def __init__(self, radius, subdivisions=3, color=None,
if with_sdf:
sdf = SphereSDF(radius)
self.assoc(sdf, relative_coords="local")
self.sdf = sdf
self._sdf = sdf
else:
self._sdf = None


class Annulus(Link):
Expand Down Expand Up @@ -224,7 +239,7 @@ def __init__(self,
visual_mesh=mesh)


class MeshLink(Link):
class MeshLink(Link, SDFImplemented):

def __init__(self,
visual_mesh=None,
Expand All @@ -245,7 +260,9 @@ def __init__(self,
if with_sdf:
sdf = trimesh2sdf(self._collision_mesh, **gridsdf_kwargs)
self.assoc(sdf, relative_coords="local")
self.sdf = sdf
self._sdf = sdf
else:
self._sdf = None


class PointCloudLink(Link):
Expand Down

0 comments on commit 06cef95

Please sign in to comment.