diff --git a/thermo_nerf/thermal_nerf/thermal_field.py b/thermo_nerf/thermal_nerf/thermal_field.py index 20bc6e1..88dac28 100644 --- a/thermo_nerf/thermal_nerf/thermal_field.py +++ b/thermo_nerf/thermal_nerf/thermal_field.py @@ -16,15 +16,12 @@ class ThermalFieldHead(BaseThermalFieldHead): - """Thermal output - - Args: - num_classes: Number of semantic classes - in_dim: input dimension. If not defined in constructor, it must be set later. - activation: output head activation - """ + """Thermal output""" def __init__(self, in_dim: int | None = None) -> None: + """`in_dim` is the input dimension. If not defined in the constructor, + it must be set later. + """ super().__init__( in_dim=in_dim, out_dim=1, diff --git a/thermo_nerf/thermal_nerf/thermal_field_head.py b/thermo_nerf/thermal_nerf/thermal_field_head.py index 84c7021..4b4c58a 100644 --- a/thermo_nerf/thermal_nerf/thermal_field_head.py +++ b/thermo_nerf/thermal_nerf/thermal_field_head.py @@ -29,6 +29,10 @@ def __init__( in_dim: Optional[int] = None, activation: Optional[Union[nn.Module, Callable]] = None, ) -> None: + """`out_dim` represents the output dimension for the renderer. + `field_head_name` is the type of field output. + `in_dim` is the input dimension. If not defined in the constructor, it must be + set later. `activation` is the output head activation.""" super().__init__() self.out_dim = out_dim self.activation = activation @@ -49,13 +53,12 @@ def _construct_net(self): def forward( self, in_tensor: Shaped[Tensor, "*bs in_dim"] ) -> Shaped[Tensor, "*bs out_dim"]: - """Process network output for renderer + """ + Process network output for renderer - Args: - in_tensor: Network input + `in_tensor` is the network input. - Returns: - Render head output + :return: Render head output """ if not self.net: raise SystemError(