Skip to content

Commit

Permalink
Convert gravity parameter to default dtype in KinDynComputationsBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
GiulioRomualdi committed Mar 2, 2025
1 parent 7def711 commit 1a1c7f9
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/adam/pytorch/computation_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,18 @@ def __init__(
joints_name_list (list): list of the actuated joints
root_link (str, optional): Deprecated. The root link is automatically chosen as the link with no parent in the URDF. Defaults to None.
"""

def to_default_dtype(tensor):
"""Converts a JAX tensor to the default floating-point type (float32 or float64)."""
default_dtype = jnp.array(0.0).dtype # Get the default floating-point dtype
return tensor.astype(default_dtype)

math = SpatialMath()
factory = URDFModelFactory(path=urdfstring, math=math)
model = Model.build(factory=factory, joints_name_list=joints_name_list)
self.rbdalgos = RBDAlgorithms(model=model, math=math)
self.NDoF = self.rbdalgos.NDoF
self.g = gravity
self.g = to_default_dtype(gravity)
self.funcs = {}
if root_link is not None:
warnings.warn(
Expand Down

0 comments on commit 1a1c7f9

Please sign in to comment.