From ed5d58ab0fa81b50af77be6b8fa7b299e0d5b8c9 Mon Sep 17 00:00:00 2001 From: 0byte-coding <241179729+0byte-coding@users.noreply.github.com> Date: Fri, 9 Jan 2026 16:52:01 +0100 Subject: [PATCH] trimmed some trailing spaces --- trellis2/datasets/__init__.py | 9 +- trellis2/datasets/components.py | 22 +-- trellis2/datasets/flexi_dual_grid.py | 29 ++-- trellis2/datasets/sparse_structure_latent.py | 21 ++- trellis2/datasets/sparse_voxel_pbr.py | 36 ++--- trellis2/datasets/structured_latent.py | 26 ++-- trellis2/datasets/structured_latent_shape.py | 10 +- trellis2/datasets/structured_latent_svpbr.py | 44 +++--- trellis2/models/__init__.py | 6 +- trellis2/models/sparse_elastic_mixin.py | 2 +- trellis2/models/sparse_structure_flow.py | 14 +- trellis2/models/sparse_structure_vae.py | 24 ++-- trellis2/models/structured_latent_flow.py | 16 +-- trellis2/modules/attention/config.py | 12 +- trellis2/modules/attention/full_attn.py | 4 +- trellis2/modules/attention/modules.py | 14 +- trellis2/modules/attention/rope.py | 8 +- trellis2/modules/image_feature_extractor.py | 18 +-- trellis2/modules/norm.py | 11 +- trellis2/modules/sparse/basic.py | 128 +++++++++--------- trellis2/modules/sparse/config.py | 12 +- trellis2/modules/sparse/linear.py | 4 +- trellis2/modules/sparse/nonlinearity.py | 4 +- trellis2/modules/sparse/spatial/basic.py | 9 +- .../modules/sparse/spatial/spatial2channel.py | 6 +- trellis2/modules/transformer/blocks.py | 5 +- trellis2/modules/transformer/modulated.py | 5 +- trellis2/trainers/__init__.py | 14 +- trellis2/trainers/basic.py | 38 +++--- trellis2/trainers/utils.py | 9 +- trellis2/utils/data_utils.py | 12 +- trellis2/utils/dist_utils.py | 7 +- trellis2/utils/elastic_utils.py | 62 ++++----- trellis2/utils/general_utils.py | 22 +-- trellis2/utils/grad_clip_utils.py | 8 +- trellis2/utils/mesh_utils.py | 40 +++--- trellis2/utils/vis_utils.py | 8 +- 37 files changed, 355 insertions(+), 364 deletions(-) diff --git a/trellis2/datasets/__init__.py b/trellis2/datasets/__init__.py index b8f7d94..3e12157 100644 --- a/trellis2/datasets/__init__.py +++ b/trellis2/datasets/__init__.py @@ -3,11 +3,11 @@ __attributes = { 'FlexiDualGridDataset': 'flexi_dual_grid', 'SparseVoxelPbrDataset':'sparse_voxel_pbr', - + 'SparseStructureLatent': 'sparse_structure_latent', 'TextConditionedSparseStructureLatent': 'sparse_structure_latent', 'ImageConditionedSparseStructureLatent': 'sparse_structure_latent', - + 'SLat': 'structured_latent', 'ImageConditionedSLat': 'structured_latent', 'SLatShape': 'structured_latent_shape', @@ -35,12 +35,11 @@ def __getattr__(name): # For Pylance -if __name__ == '__main__': +if __name__ == '__main__': from .flexi_dual_grid import FlexiDualGridDataset from .sparse_voxel_pbr import SparseVoxelPbrDataset - + from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent from .structured_latent import SLat, ImageConditionedSLat from .structured_latent_shape import SLatShape, ImageConditionedSLatShape from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr - \ No newline at end of file diff --git a/trellis2/datasets/components.py b/trellis2/datasets/components.py index 6c593ce..4c3e228 100644 --- a/trellis2/datasets/components.py +++ b/trellis2/datasets/components.py @@ -30,7 +30,7 @@ def __init__(self, root_type = 'list' self.instances = [] self.metadata = pd.DataFrame() - + self._stats = {} if root_type == 'obj': for key, root in self.roots.items(): @@ -54,15 +54,15 @@ def __init__(self, self.instances.extend([(root, sha256) for 
sha256 in metadata['sha256'].values]) metadata.set_index('sha256', inplace=True) self.metadata = pd.concat([self.metadata, metadata]) - + @abstractmethod def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]: pass - + @abstractmethod def get_instance(self, root, instance: str) -> Dict[str, Any]: pass - + def __len__(self): return len(self.instances) @@ -73,7 +73,7 @@ def __getitem__(self, index) -> Dict[str, Any]: except Exception as e: print(f'Error loading {instance}: {e}') return self.__getitem__(np.random.randint(0, len(self))) - + def __str__(self): lines = [] lines.append(self.__class__.__name__) @@ -90,16 +90,16 @@ class ImageConditionedMixin: def __init__(self, roots, *, image_size=518, **kwargs): self.image_size = image_size super().__init__(roots, **kwargs) - + def filter_metadata(self, metadata): metadata, stats = super().filter_metadata(metadata) metadata = metadata[metadata['cond_rendered'].notna()] stats['Cond rendered'] = len(metadata) return metadata, stats - + def get_instance(self, root, instance): pack = super().get_instance(root, instance) - + image_root = os.path.join(root['render_cond'], instance) with open(os.path.join(image_root, 'transforms.json')) as f: metadata = json.load(f) @@ -128,7 +128,7 @@ def get_instance(self, root, instance): alpha = torch.tensor(np.array(alpha)).float() / 255.0 image = image * alpha.unsqueeze(0) pack['cond'] = image - + return pack @@ -143,10 +143,10 @@ def filter_metadata(self, metadata): metadata = metadata[metadata['cond_rendered'].notna()] stats['Cond rendered'] = len(metadata) return metadata, stats - + def get_instance(self, root, instance): pack = super().get_instance(root, instance) - + image_root = os.path.join(root['render_cond'], instance) with open(os.path.join(image_root, 'transforms.json')) as f: metadata = json.load(f) diff --git a/trellis2/datasets/flexi_dual_grid.py b/trellis2/datasets/flexi_dual_grid.py index f870d83..a650599 100644 --- a/trellis2/datasets/flexi_dual_grid.py +++ b/trellis2/datasets/flexi_dual_grid.py @@ -15,11 +15,11 @@ class FlexiDualGridVisMixin: @torch.no_grad() def visualize_sample(self, x: dict): mesh = x['mesh'] - + renderer = MeshRenderer({'near': 1, 'far': 3}) renderer.rendering_options.resolution = 512 renderer.rendering_options.ssaa = 4 - + # Build camera yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) @@ -39,7 +39,7 @@ def visualize_sample(self, x: dict): intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) exts.append(extrinsics) ints.append(intrinsics) - + # Build each representation images = [] for m in mesh: @@ -50,14 +50,14 @@ def visualize_sample(self, x: dict): renderer.render(m.cuda(), ext, intr)['normal'] images.append(image) images = torch.stack(images) - + return images - + class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase): """ Flexible Dual Grid Dataset - + Args: roots (str): path to the dataset resolution (int): resolution of the voxel grid @@ -79,16 +79,16 @@ def __init__( self.value_range = (0, 1) super().__init__(roots) - + self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256 in self.instances] - + def __str__(self): lines = [ super().__str__(), f' - Resolution: {self.resolution}', ] return '\n'.join(lines) - + def filter_metadata(self, metadata): stats = {} metadata = metadata[metadata[f'dual_grid_converted'] == True] @@ -102,7 +102,7 @@ def filter_metadata(self, metadata): metadata = metadata[metadata['num_faces'] <= self.max_num_faces] 
stats[f'Faces <= {self.max_num_faces}'] = len(metadata) return metadata, stats - + def read_mesh(self, root, instance): with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: dump = pickle.load(f) @@ -124,7 +124,7 @@ def read_mesh(self, root, instance): vertices = (vertices - center) * scale assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' return {'mesh': [Mesh(vertices=vertices, faces=faces)]} - + def read_dual_grid(self, root, instance): coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4) vertices = sp.SparseTensor( @@ -142,7 +142,7 @@ def get_instance(self, root, instance): mesh = self.read_mesh(root['mesh_dump'], instance) dual_grid = self.read_dual_grid(root['dual_grid'], instance) return {**mesh, **dual_grid} - + @staticmethod def collate_fn(batch, split_size=None): if split_size is None: @@ -164,10 +164,9 @@ def collate_fn(batch, split_size=None): pack[k] = sum([b[k] for b in sub_batch], []) else: pack[k] = [b[k] for b in sub_batch] - + packs.append(pack) - + if split_size is None: return packs[0] return packs - \ No newline at end of file diff --git a/trellis2/datasets/sparse_structure_latent.py b/trellis2/datasets/sparse_structure_latent.py index 498e115..ad2e034 100644 --- a/trellis2/datasets/sparse_structure_latent.py +++ b/trellis2/datasets/sparse_structure_latent.py @@ -24,7 +24,7 @@ def __init__( self.pretrained_ss_dec = pretrained_ss_dec self.ss_dec_path = ss_dec_path self.ss_dec_ckpt = ss_dec_ckpt - + def _loading_ss_dec(self): if self.ss_dec is not None: return @@ -57,11 +57,11 @@ def decode_latent(self, z, batch_size=4): def visualize_sample(self, x_0: Union[torch.Tensor, dict]): x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0'] x_0 = self.decode_latent(x_0.cuda()) - + renderer = VoxelRenderer() renderer.rendering_options.resolution = 512 renderer.rendering_options.ssaa = 4 - + # build camera yaw = [0, np.pi/2, np.pi, 3*np.pi/2] yaw_offset = -16 / 180 * np.pi @@ -70,7 +70,7 @@ def visualize_sample(self, x_0: Union[torch.Tensor, dict]): exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) images = [] - + # Build each representation x_0 = x_0.cuda() for i in range(x_0.shape[0]): @@ -92,14 +92,14 @@ def visualize_sample(self, x_0: Union[torch.Tensor, dict]): res = renderer.render(rep, ext, intr, colors_overwrite=color) image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] images.append(image) - + return torch.stack(images) class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase): """ Sparse structure latent dataset - + Args: roots (str): path to the dataset min_aesthetic_score (float): minimum aesthetic score @@ -120,18 +120,18 @@ def __init__(self, self.min_aesthetic_score = min_aesthetic_score self.normalization = normalization self.value_range = (0, 1) - + super().__init__( roots, pretrained_ss_dec=pretrained_ss_dec, ss_dec_path=ss_dec_path, ss_dec_ckpt=ss_dec_ckpt, ) - + if self.normalization is not None: self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1) self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1) - + def filter_metadata(self, metadata): stats = {} metadata = metadata[metadata['ss_latent_encoded'] == True] @@ -139,7 +139,7 @@ def filter_metadata(self, metadata): metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) return 
metadata, stats - + def get_instance(self, root, instance): latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz')) z = torch.tensor(latent['z']).float() @@ -157,4 +157,3 @@ class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructu Image-conditioned sparse structure dataset """ pass - \ No newline at end of file diff --git a/trellis2/datasets/sparse_voxel_pbr.py b/trellis2/datasets/sparse_voxel_pbr.py index 2149b5b..3d5d7c7 100644 --- a/trellis2/datasets/sparse_voxel_pbr.py +++ b/trellis2/datasets/sparse_voxel_pbr.py @@ -31,17 +31,17 @@ def nearest_power_of_two(n: int) -> int: return lower else: return upper - + class SparseVoxelPbrVisMixin: @torch.no_grad() def visualize_sample(self, x: Union[sp.SparseTensor, dict]): x = x if isinstance(x, sp.SparseTensor) else x['x'] - + renderer = VoxelRenderer() renderer.rendering_options.resolution = 512 renderer.rendering_options.ssaa = 4 - + # Build camera yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) @@ -63,7 +63,7 @@ def visualize_sample(self, x: Union[sp.SparseTensor, dict]): ints.append(intrinsics) images = {k: [] for k in self.layout} - + # Build each representation x = x.cuda() for i in range(x.shape[0]): @@ -84,17 +84,17 @@ def visualize_sample(self, x: Union[sp.SparseTensor, dict]): res = renderer.render(rep, ext, intr) image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] images[k].append(image) - + for k in self.layout: images[k] = torch.stack(images[k]) - + return images class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase): """ Sparse Voxel PBR dataset. - + Args: roots (str): path to the dataset resolution (int): resolution of the voxel grid @@ -131,9 +131,9 @@ def __init__( start += self.channels[attr] super().__init__(roots) - + self.loads = [self.metadata.loc[sha256, f'num_pbr_voxels'] for _, sha256 in self.instances] - + def __str__(self): lines = [ super().__str__(), @@ -141,7 +141,7 @@ def __str__(self): f' - Attributes: {list(self.layout.keys())}', ] return '\n'.join(lines) - + def filter_metadata(self, metadata): stats = {} metadata = metadata[metadata['pbr_voxelized'] == True] @@ -181,7 +181,7 @@ def _texture_from_dump(pack) -> Texture: def read_mesh_with_texture(self, root, instance): with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f: dump = pickle.load(f) - + # Fix dump alpha map for mat in dump['materials']: if mat['alphaTexture'] is not None and mat['alphaMode'] == 'OPAQUE': @@ -230,12 +230,12 @@ def read_mesh_with_texture(self, root, instance): material_ids.append(obj['mat_ids']) uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32)) start += len(obj['vertices']) - + vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float() faces = torch.from_numpy(np.concatenate(faces, axis=0)).long() material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long() uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float() - + # Normalize vertices vertices_min = vertices.min(dim=0)[0] vertices_max = vertices.max(dim=0)[0] @@ -243,7 +243,7 @@ def read_mesh_with_texture(self, root, instance): scale = 0.99999 / (vertices_max - vertices_min).max() vertices = (vertices - center) * scale assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range' - + return {'mesh': [MeshWithPbrMaterial( vertices=vertices, faces=faces, @@ -260,7 +260,7 
@@ def read_pbr_voxel(self, root, instance): torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1), ) return {'x': x} - + def get_instance(self, root, instance): if self.with_mesh: mesh = self.read_mesh_with_texture(root['pbr_dump'], instance) @@ -268,7 +268,7 @@ def get_instance(self, root, instance): return {**mesh, **pbr_voxel} else: return self.read_pbr_voxel(root['pbr_voxel'], instance) - + @staticmethod def collate_fn(batch, split_size=None): if split_size is None: @@ -290,9 +290,9 @@ def collate_fn(batch, split_size=None): pack[k] = sum([b[k] for b in sub_batch], []) else: pack[k] = [b[k] for b in sub_batch] - + packs.append(pack) - + if split_size is None: return packs[0] return packs diff --git a/trellis2/datasets/structured_latent.py b/trellis2/datasets/structured_latent.py index 2e18a27..a4234e2 100644 --- a/trellis2/datasets/structured_latent.py +++ b/trellis2/datasets/structured_latent.py @@ -25,7 +25,7 @@ def __init__( self.pretrained_slat_dec = pretrained_slat_dec self.slat_dec_path = slat_dec_path self.slat_dec_ckpt = slat_dec_ckpt - + def _loading_slat_dec(self): if self.slat_dec is not None: return @@ -58,7 +58,7 @@ def decode_latent(self, z, batch_size=4): def visualize_sample(self, x_0: Union[SparseTensor, dict]): x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] reps = self.decode_latent(x_0.cuda()) - + # Build camera yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) @@ -89,14 +89,14 @@ def visualize_sample(self, x_0: Union[SparseTensor, dict]): image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] images.append(image) images = torch.stack(images) - + return images - - + + class SLat(SLatVisMixin, StandardDatasetBase): """ structured latent V2 dataset - + Args: roots (str): path to the dataset min_aesthetic_score (float): minimum aesthetic score @@ -123,7 +123,7 @@ def __init__(self, self.max_tokens = max_tokens self.latent_key = latent_key self.value_range = (0, 1) - + super().__init__( roots, pretrained_slat_dec=pretrained_slat_dec, @@ -132,11 +132,11 @@ def __init__(self, ) self.loads = [self.metadata.loc[sha256, f'{latent_key}_tokens'] for _, sha256 in self.instances] - + if self.normalization is not None: self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) self.std = torch.tensor(self.normalization['std']).reshape(1, -1) - + def filter_metadata(self, metadata): stats = {} metadata = metadata[metadata[f'{self.latent_key}_encoded'] == True] @@ -157,7 +157,7 @@ def get_instance(self, root, instance): 'coords': coords, 'feats': feats, } - + @staticmethod def collate_fn(batch, split_size=None): if split_size is None: @@ -185,7 +185,7 @@ def collate_fn(batch, split_size=None): ) pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) pack['x_0'].register_spatial_cache('layout', layout) - + # collate other data keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] for k in keys: @@ -195,9 +195,9 @@ def collate_fn(batch, split_size=None): pack[k] = sum([b[k] for b in sub_batch], []) else: pack[k] = [b[k] for b in sub_batch] - + packs.append(pack) - + if split_size is None: return packs[0] return packs diff --git a/trellis2/datasets/structured_latent_shape.py b/trellis2/datasets/structured_latent_shape.py index e4a7d88..95fa930 100644 --- a/trellis2/datasets/structured_latent_shape.py +++ b/trellis2/datasets/structured_latent_shape.py @@ -28,14 +28,14 @@ def _loading_slat_dec(self): def 
visualize_sample(self, x_0: Union[SparseTensor, dict]): x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] reps = self.decode_latent(x_0.cuda()) - + # build camera yaw = [0, np.pi/2, np.pi, 3*np.pi/2] yaw_offset = -16 / 180 * np.pi yaw = [y + yaw_offset for y in yaw] pitch = [20 / 180 * np.pi for _ in range(4)] exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) - + # render renderer = get_renderer(reps[0]) images = [] @@ -48,12 +48,12 @@ def visualize_sample(self, x_0: Union[SparseTensor, dict]): images.append(image) images = torch.stack(images) return images - - + + class SLatShape(SLatShapeVisMixin, SLat): """ structured latent for shape generation - + Args: roots (str): path to the dataset resolution (int): resolution of the shape diff --git a/trellis2/datasets/structured_latent_svpbr.py b/trellis2/datasets/structured_latent_svpbr.py index 4c6711e..ca78975 100644 --- a/trellis2/datasets/structured_latent_svpbr.py +++ b/trellis2/datasets/structured_latent_svpbr.py @@ -32,7 +32,7 @@ def __init__( self.pretrained_shape_slat_dec = pretrained_shape_slat_dec self.shape_slat_dec_path = shape_slat_dec_path self.shape_slat_dec_ckpt = shape_slat_dec_ckpt - + def _loading_slat_dec(self): if self.pbr_slat_dec is not None and self.shape_slat_dec is not None: return @@ -60,7 +60,7 @@ def _delete_slat_dec(self): self.pbr_slat_dec = None del self.shape_slat_dec self.shape_slat_dec = None - + @torch.no_grad() def decode_latent(self, z, shape_z, batch_size=4): self._loading_slat_dec() @@ -86,20 +86,20 @@ def decode_latent(self, z, shape_z, batch_size=4): ]) self._delete_slat_dec() return reps - + @torch.no_grad() def visualize_sample(self, sample: dict): shape_z = sample['concat_cond'].cuda() z = sample['x_0'].cuda() reps = self.decode_latent(z, shape_z) - + # build camera yaw = [0, np.pi/2, np.pi, 3*np.pi/2] yaw_offset = -16 / 180 * np.pi yaw = [y + yaw_offset for y in yaw] pitch = [20 / 180 * np.pi for _ in range(4)] exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30) - + # render renderer = get_renderer(reps[0]) images = {k: [] for k in self.layout} @@ -115,12 +115,12 @@ def visualize_sample(self, sample: dict): for k in self.layout: images[k] = torch.stack(images[k], dim=0) return images - - + + class SLatPbr(SLatPbrVisMixin, StandardDatasetBase): """ structured latent for sparse voxel pbr dataset - + Args: roots (str): path to the dataset latent_key (str): key of the latent to be used @@ -149,7 +149,7 @@ def __init__(self, shape_slat_dec_path: Optional[str] = None, shape_slat_dec_ckpt: Optional[str] = None, **kwargs - ): + ): self.resolution = resolution self.pbr_slat_normalization = pbr_slat_normalization self.shape_slat_normalization = shape_slat_normalization @@ -157,7 +157,7 @@ def __init__(self, self.max_tokens = max_tokens self.full_pbr = full_pbr self.value_range = (-1, 1) - + super().__init__( roots, pretrained_pbr_slat_dec=pretrained_pbr_slat_dec, @@ -168,17 +168,17 @@ def __init__(self, shape_slat_dec_ckpt=shape_slat_dec_ckpt, **kwargs ) - + self.loads = [self.metadata.loc[sha256, 'pbr_latent_tokens'] for _, sha256 in self.instances] - + if self.pbr_slat_normalization is not None: self.pbr_slat_mean = torch.tensor(self.pbr_slat_normalization['mean']).reshape(1, -1) self.pbr_slat_std = torch.tensor(self.pbr_slat_normalization['std']).reshape(1, -1) - + if self.shape_slat_normalization is not None: self.shape_slat_mean = torch.tensor(self.shape_slat_normalization['mean']).reshape(1, -1) self.shape_slat_std = 
torch.tensor(self.shape_slat_normalization['std']).reshape(1, -1) - + self.attrs = attrs self.channels = { 'base_color': 3, @@ -192,7 +192,7 @@ def __init__(self, for attr in attrs: self.layout[attr] = slice(start, start + self.channels[attr]) start += self.channels[attr] - + def filter_metadata(self, metadata): stats = {} metadata = metadata[metadata['pbr_latent_encoded'] == True] @@ -209,7 +209,7 @@ def filter_metadata(self, metadata): metadata = metadata[metadata['num_roughness_tex'] > 0] stats['Full PBR'] = len(metadata) return metadata, stats - + def get_instance(self, root, instance): # PBR latent data = np.load(os.path.join(root['pbr_latent'], f'{instance}.npz')) @@ -219,7 +219,7 @@ def get_instance(self, root, instance): if self.pbr_slat_normalization is not None: feats = (feats - self.pbr_slat_mean) / self.pbr_slat_std pbr_z = SparseTensor(feats, coords) - + # Shape latent data = np.load(os.path.join(root['shape_latent'], f'{instance}.npz')) coords = torch.tensor(data['coords']).int() @@ -228,15 +228,15 @@ def get_instance(self, root, instance): if self.shape_slat_normalization is not None: feats = (feats - self.shape_slat_mean) / self.shape_slat_std shape_z = SparseTensor(feats, coords) - + assert torch.equal(shape_z.coords, pbr_z.coords), \ f"Shape latent and PBR latent have different coordinates: {shape_z.coords.shape} vs {pbr_z.coords.shape}" - + return { 'x_0': pbr_z, 'concat_cond': shape_z, } - + @staticmethod def collate_fn(batch, split_size=None): if split_size is None: @@ -258,9 +258,9 @@ def collate_fn(batch, split_size=None): pack[k] = sum([b[k] for b in sub_batch], []) else: pack[k] = [b[k] for b in sub_batch] - + packs.append(pack) - + if split_size is None: return packs[0] return packs diff --git a/trellis2/models/__init__.py b/trellis2/models/__init__.py index d4fed03..8893cb1 100644 --- a/trellis2/models/__init__.py +++ b/trellis2/models/__init__.py @@ -5,11 +5,11 @@ 'SparseStructureEncoder': 'sparse_structure_vae', 'SparseStructureDecoder': 'sparse_structure_vae', 'SparseStructureFlowModel': 'sparse_structure_flow', - + # SLat Generation 'SLatFlowModel': 'structured_latent_flow', 'ElasticSLatFlowModel': 'structured_latent_flow', - + # SC-VAEs 'SparseUnetVaeEncoder': 'sc_vaes.sparse_unet_vae', 'SparseUnetVaeDecoder': 'sc_vaes.sparse_unet_vae', @@ -73,6 +73,6 @@ def from_pretrained(path: str, **kwargs): from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder from .sparse_structure_flow import SparseStructureFlowModel from .structured_latent_flow import SLatFlowModel, ElasticSLatFlowModel - + from .sc_vaes.sparse_unet_vae import SparseUnetVaeEncoder, SparseUnetVaeDecoder from .sc_vaes.fdg_vae import FlexiDualGridVaeEncoder, FlexiDualGridVaeDecoder diff --git a/trellis2/models/sparse_elastic_mixin.py b/trellis2/models/sparse_elastic_mixin.py index 66d204c..cc83d4a 100644 --- a/trellis2/models/sparse_elastic_mixin.py +++ b/trellis2/models/sparse_elastic_mixin.py @@ -8,7 +8,7 @@ class SparseTransformerElasticMixin(ElasticModuleMixin): def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs): return x.feats.shape[0] - + @contextmanager def with_mem_ratio(self, mem_ratio=1.0): if mem_ratio == 1.0: diff --git a/trellis2/models/sparse_structure_flow.py b/trellis2/models/sparse_structure_flow.py index 6c97665..529f494 100644 --- a/trellis2/models/sparse_structure_flow.py +++ b/trellis2/models/sparse_structure_flow.py @@ -111,12 +111,12 @@ def __init__( coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) 
self.register_buffer("rope_phases", rope_phases) - + if pe_mode != "rope": self.rope_phases = None self.input_layer = nn.Linear(in_channels, model_channels) - + self.blocks = nn.ModuleList([ ModulatedTransformerCrossBlock( model_channels, @@ -179,7 +179,7 @@ def _basic_init(module): # Zero-out output layers: nn.init.constant_(self.out_layer.weight, 0) nn.init.constant_(self.out_layer.bias, 0) - + elif self.initialization == 'scaled': # Initialize transformer layers: def _basic_init(module): @@ -188,7 +188,7 @@ def _basic_init(module): if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) - + # Scaled init for to_out and ffn2 def _scaled_init(module): if isinstance(module, nn.Linear): @@ -199,15 +199,15 @@ def _scaled_init(module): block.self_attn.to_out.apply(_scaled_init) block.cross_attn.to_out.apply(_scaled_init) block.mlp.mlp[2].apply(_scaled_init) - + # Initialize input layer to make the initial representation have variance 1 nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) nn.init.zeros_(self.input_layer.bias) - + # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - + # Zero-out adaLN modulation layers in DiT blocks: if self.share_mod: nn.init.constant_(self.adaLN_modulation[-1].weight, 0) diff --git a/trellis2/models/sparse_structure_vae.py b/trellis2/models/sparse_structure_vae.py index c3e0913..9513c4e 100644 --- a/trellis2/models/sparse_structure_vae.py +++ b/trellis2/models/sparse_structure_vae.py @@ -35,7 +35,7 @@ def __init__( self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)) self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() - + def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.norm1(x) h = F.silu(h) @@ -96,12 +96,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return pixel_shuffle_3d(x, 2) else: return F.interpolate(x, scale_factor=2, mode="nearest") - + class SparseStructureEncoder(nn.Module): """ Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3). - + Args: in_channels (int): Channels of the input. latent_channels (int): Channels of the latent representation. @@ -143,7 +143,7 @@ def __init__( self.blocks.append( DownsampleBlock3d(ch, channels[i+1]) ) - + self.middle_block = nn.Sequential(*[ ResBlock3d(channels[-1], channels[-1]) for _ in range(num_res_blocks_middle) @@ -201,16 +201,16 @@ def forward(self, x: torch.Tensor, sample_posterior: bool = False, return_raw: b z = mean + std * torch.randn_like(std) else: z = mean - + if return_raw: return z, mean, logvar return z - + class SparseStructureDecoder(nn.Module): """ Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3). - + Args: out_channels (int): Channels of the output. latent_channels (int): Channels of the latent representation. @@ -219,7 +219,7 @@ class SparseStructureDecoder(nn.Module): num_res_blocks_middle (int): Number of residual blocks in the middle. norm_type (Literal["group", "layer"]): Type of normalization layer. use_fp16 (bool): Whether to use FP16. - """ + """ def __init__( self, out_channels: int, @@ -273,7 +273,7 @@ def device(self) -> torch.device: Return the device of the model. """ return next(self.parameters()).device - + def convert_to_fp16(self) -> None: """ Convert the torso of the model to float16. 
@@ -291,12 +291,12 @@ def convert_to_fp32(self) -> None: self.dtype = torch.float32 self.blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) - + def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.input_layer(x) - + h = h.type(self.dtype) - + h = self.middle_block(h) for block in self.blocks: h = block(h) diff --git a/trellis2/models/structured_latent_flow.py b/trellis2/models/structured_latent_flow.py index 9378ff7..7e84526 100644 --- a/trellis2/models/structured_latent_flow.py +++ b/trellis2/models/structured_latent_flow.py @@ -10,7 +10,7 @@ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock from .sparse_structure_flow import TimestepEmbedder from .sparse_elastic_mixin import SparseTransformerElasticMixin - + class SLatFlowModel(nn.Module): def __init__( @@ -61,7 +61,7 @@ def __init__( self.pos_embedder = AbsolutePositionEmbedder(model_channels) self.input_layer = sp.SparseLinear(in_channels, model_channels) - + self.blocks = nn.ModuleList([ ModulatedSparseTransformerCrossBlock( model_channels, @@ -78,7 +78,7 @@ def __init__( ) for _ in range(num_blocks) ]) - + self.out_layer = sp.SparseLinear(model_channels, out_channels) self.initialize_weights() @@ -124,7 +124,7 @@ def _basic_init(module): # Zero-out output layers: nn.init.constant_(self.out_layer.weight, 0) nn.init.constant_(self.out_layer.bias, 0) - + elif self.initialization == 'scaled': # Initialize transformer layers: def _basic_init(module): @@ -133,7 +133,7 @@ def _basic_init(module): if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) - + # Scaled init for to_out and ffn2 def _scaled_init(module): if isinstance(module, nn.Linear): @@ -144,15 +144,15 @@ def _scaled_init(module): block.self_attn.to_out.apply(_scaled_init) block.cross_attn.to_out.apply(_scaled_init) block.mlp.mlp[2].apply(_scaled_init) - + # Initialize input layer to make the initial representation have variance 1 nn.init.normal_(self.input_layer.weight, std=1.0 / np.sqrt(self.in_channels)) nn.init.zeros_(self.input_layer.bias) - + # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - + # Zero-out adaLN modulation layers in DiT blocks: if self.share_mod: nn.init.constant_(self.adaLN_modulation[-1].weight, 0) diff --git a/trellis2/modules/attention/config.py b/trellis2/modules/attention/config.py index a6d5180..cef21c3 100644 --- a/trellis2/modules/attention/config.py +++ b/trellis2/modules/attention/config.py @@ -1,27 +1,27 @@ from typing import * -BACKEND = 'flash_attn' +BACKEND = 'flash_attn' DEBUG = False def __from_env(): import os - + global BACKEND global DEBUG - + env_attn_backend = os.environ.get('ATTN_BACKEND') env_attn_debug = os.environ.get('ATTN_DEBUG') - + if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3', 'sdpa', 'naive']: BACKEND = env_attn_backend if env_attn_debug is not None: DEBUG = env_attn_debug == '1' print(f"[ATTENTION] Using backend: {BACKEND}") - + __from_env() - + def set_backend(backend: Literal['xformers', 'flash_attn']): global BACKEND diff --git a/trellis2/modules/attention/full_attn.py b/trellis2/modules/attention/full_attn.py index e2f9b2a..56034c5 100644 --- a/trellis2/modules/attention/full_attn.py +++ b/trellis2/modules/attention/full_attn.py @@ -92,7 +92,7 @@ def scaled_dot_product_attention(*args, **kwargs): assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, 
L, H, Ci]" assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" - device = q.device + device = q.device if config.BACKEND == 'xformers': if 'xops' not in globals(): @@ -141,5 +141,5 @@ def scaled_dot_product_attention(*args, **kwargs): out = _naive_sdpa(q, k, v) else: raise ValueError(f"Unknown attention module: {config.BACKEND}") - + return out diff --git a/trellis2/modules/attention/modules.py b/trellis2/modules/attention/modules.py index 492784c..a72968d 100644 --- a/trellis2/modules/attention/modules.py +++ b/trellis2/modules/attention/modules.py @@ -14,7 +14,7 @@ def __init__(self, dim: int, heads: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) - + class MultiHeadAttention(nn.Module): def __init__( @@ -36,10 +36,10 @@ def __init__( assert type in ["self", "cross"], f"Invalid attention type: {type}" assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}" assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention" - + if attn_mode == "windowed": raise NotImplementedError("Windowed attention is not yet implemented") - + self.channels = channels self.head_dim = channels // num_heads self.ctx_channels = ctx_channels if ctx_channels is not None else channels @@ -56,19 +56,19 @@ def __init__( else: self.to_q = nn.Linear(channels, channels, bias=qkv_bias) self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) - + if self.qk_rms_norm: self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - + self.to_out = nn.Linear(channels, channels) - + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape if self._type == "self": qkv = self.to_qkv(x) qkv = qkv.reshape(B, L, 3, self.num_heads, -1) - + if self.attn_mode == "full": if self.qk_rms_norm or self.use_rope: q, k, v = qkv.unbind(dim=2) diff --git a/trellis2/modules/attention/rope.py b/trellis2/modules/attention/rope.py index 1cf6c5b..344d37e 100644 --- a/trellis2/modules/attention/rope.py +++ b/trellis2/modules/attention/rope.py @@ -5,7 +5,7 @@ class RotaryPositionEmbedder(nn.Module): def __init__( - self, + self, head_dim: int, dim: int = 3, rope_freq: Tuple[float, float] = (1.0, 10000.0) @@ -18,20 +18,20 @@ def __init__( self.freq_dim = head_dim // 2 // dim self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) - + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: self.freqs = self.freqs.to(indices.device) phases = torch.outer(indices, self.freqs) phases = torch.polar(torch.ones_like(phases), phases) return phases - + @staticmethod def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = x_complex * phases.unsqueeze(-2) x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) return x_embed - + def forward(self, indices: torch.Tensor) -> torch.Tensor: """ Args: diff --git a/trellis2/modules/image_feature_extractor.py b/trellis2/modules/image_feature_extractor.py index c3cb515..0e56e52 100644 --- a/trellis2/modules/image_feature_extractor.py +++ 
b/trellis2/modules/image_feature_extractor.py @@ -27,15 +27,15 @@ def cuda(self): def cpu(self): self.model.cpu() - + @torch.no_grad() def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: """ Extract features from the image. - + Args: image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. - + Returns: A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. """ @@ -49,12 +49,12 @@ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tenso image = torch.stack(image).cuda() else: raise ValueError(f"Unsupported type of image: {type(image)}") - + image = self.transform(image).cuda() features = self.model(image, is_training=True)['x_prenorm'] patchtokens = F.layer_norm(features, features.shape[-1:]) return patchtokens - + class DinoV3FeatureExtractor: """ @@ -90,15 +90,15 @@ def extract_features(self, image: torch.Tensor) -> torch.Tensor: ) return F.layer_norm(hidden_states, hidden_states.shape[-1:]) - + @torch.no_grad() def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor: """ Extract features from the image. - + Args: image: A batch of images as a tensor of shape (B, C, H, W) or a list of PIL images. - + Returns: A tensor of shape (B, N, D) where N is the number of patches and D is the feature dimension. """ @@ -112,7 +112,7 @@ def __call__(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tenso image = torch.stack(image).cuda() else: raise ValueError(f"Unsupported type of image: {type(image)}") - + image = self.transform(image).cuda() features = self.extract_features(image) return features diff --git a/trellis2/modules/norm.py b/trellis2/modules/norm.py index 2484c54..7368d10 100644 --- a/trellis2/modules/norm.py +++ b/trellis2/modules/norm.py @@ -6,13 +6,13 @@ def chunked_apply(module, x: torch.Tensor, chunk_size: int) -> torch.Tensor: if chunk_size <= 0 or x.shape[0] <= chunk_size: return module(x) - + # Process first chunk to determine output shape and dtype out_0 = module(x[0:chunk_size]) out_shape = (x.shape[0],) + out_0.shape[1:] out = torch.empty(out_shape, device=x.device, dtype=out_0.dtype) out[0:chunk_size] = out_0 - + # Process remaining chunks for i in range(chunk_size, x.shape[0], chunk_size): out[i:i+chunk_size] = module(x[i:i+chunk_size]) @@ -35,7 +35,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.low_vram: return chunked_apply(self._forward, x, self.chunk_size) return self._forward(x) - + class GroupNorm32(nn.GroupNorm): """ @@ -56,8 +56,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.low_vram: return chunked_apply(self._forward, x, self.chunk_size) return self._forward(x) - - + + class ChannelLayerNorm32(LayerNorm32): def forward(self, x: torch.Tensor) -> torch.Tensor: DIM = x.dim() @@ -65,4 +65,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = super().forward(x) x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() return x - \ No newline at end of file diff --git a/trellis2/modules/sparse/basic.py b/trellis2/modules/sparse/basic.py index 880973b..2e63fc8 100644 --- a/trellis2/modules/sparse/basic.py +++ b/trellis2/modules/sparse/basic.py @@ -17,7 +17,7 @@ class VarLenTensor: """ Sequential tensor with variable length. - + Args: feats (torch.Tensor): Features of the varlen tensor. 
layout (List[slice]): Layout of the varlen tensor for each batch @@ -26,7 +26,7 @@ def __init__(self, feats: torch.Tensor, layout: List[slice]=None): self.feats = feats self.layout = layout if layout is not None else [slice(0, feats.shape[0])] self._cache = {} - + @staticmethod def layout_from_seqlen(seqlen: list) -> List[slice]: """ @@ -38,7 +38,7 @@ def layout_from_seqlen(seqlen: list) -> List[slice]: layout.append(slice(start, start + l)) start += l return layout - + @staticmethod def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': """ @@ -51,7 +51,7 @@ def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': layout.append(slice(start, start + tensor.shape[0])) start += tensor.shape[0] return VarLenTensor(feats, layout) - + def to_tensor_list(self) -> List[torch.Tensor]: """ Convert a VarLenTensor to a list of tensors. @@ -60,17 +60,17 @@ def to_tensor_list(self) -> List[torch.Tensor]: for s in self.layout: tensor_list.append(self.feats[s]) return tensor_list - + def __len__(self) -> int: return len(self.layout) - + @property def shape(self) -> torch.Size: return torch.Size([len(self.layout), *self.feats.shape[1:]]) - + def dim(self) -> int: return len(self.shape) - + @property def ndim(self) -> int: return self.dim() @@ -82,13 +82,13 @@ def dtype(self): @property def device(self): return self.feats.device - + @property def seqlen(self) -> torch.LongTensor: if 'seqlen' not in self._cache: self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) return self._cache['seqlen'] - + @property def cum_seqlen(self) -> torch.LongTensor: if 'cum_seqlen' not in self._cache: @@ -97,7 +97,7 @@ def cum_seqlen(self) -> torch.LongTensor: self.seqlen.cumsum(dim=0) ], dim=0) return self._cache['cum_seqlen'] - + @property def batch_boardcast_map(self) -> torch.LongTensor: """ @@ -109,7 +109,7 @@ def batch_boardcast_map(self) -> torch.LongTensor: self.seqlen, ) return self._cache['batch_boardcast_map'] - + @overload def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... 
@@ -134,7 +134,7 @@ def to(self, *args, **kwargs) -> 'VarLenTensor': device = kwargs['device'] non_blocking = kwargs.get('non_blocking', False) copy = kwargs.get('copy', False) - + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) return self.replace(new_feats) @@ -145,7 +145,7 @@ def type(self, dtype): def cpu(self) -> 'VarLenTensor': new_feats = self.feats.cpu() return self.replace(new_feats) - + def cuda(self) -> 'VarLenTensor': new_feats = self.feats.cuda() return self.replace(new_feats) @@ -153,11 +153,11 @@ def cuda(self) -> 'VarLenTensor': def half(self) -> 'VarLenTensor': new_feats = self.feats.half() return self.replace(new_feats) - + def float(self) -> 'VarLenTensor': new_feats = self.feats.float() return self.replace(new_feats) - + def detach(self) -> 'VarLenTensor': new_feats = self.feats.detach() return self.replace(new_feats) @@ -165,7 +165,7 @@ def detach(self) -> 'VarLenTensor': def reshape(self, *shape) -> 'VarLenTensor': new_feats = self.feats.reshape(self.feats.shape[0], *shape) return self.replace(new_feats) - + def unbind(self, dim: int) -> List['VarLenTensor']: return varlen_unbind(self, dim) @@ -176,11 +176,11 @@ def replace(self, feats: torch.Tensor) -> 'VarLenTensor': ) new_tensor._cache = self._cache return new_tensor - + def to_dense(self, max_length=None) -> torch.Tensor: """ Convert a VarLenTensor to a dense representation without for-loop. - + Returns: dense (torch.Tensor): (N, L, C) dense tensor mask (torch.BoolTensor): (N, L) mask indicating valid positions @@ -197,7 +197,7 @@ def to_dense(self, max_length=None) -> torch.Tensor: def __neg__(self) -> 'VarLenTensor': return self.replace(-self.feats) - + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': if isinstance(other, torch.Tensor): try: @@ -216,10 +216,10 @@ def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenT def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': return self.__elemwise__(other, torch.add) - + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': return self.__elemwise__(other, torch.sub) - + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) @@ -252,7 +252,7 @@ def __getitem__(self, idx): raise ValueError(f"Unknown index type: {idx.dtype}") else: raise ValueError(f"Unknown index type: {type(idx)}") - + new_feats = [] new_layout = [] start = 0 @@ -263,11 +263,11 @@ def __getitem__(self, idx): new_feats = torch.cat(new_feats, dim=0).contiguous() new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) return new_tensor - + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: if isinstance(dim, int): dim = (dim,) - + if op =='mean': red = self.feats.mean(dim=dim, keepdim=keepdim) elif op =='sum': @@ -276,28 +276,28 @@ def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keep red = self.feats.prod(dim=dim, keepdim=keepdim) else: raise ValueError(f"Unsupported reduce operation: {op}") - + if dim is None or 0 in dim: return red - + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) return red - + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: return self.reduce(op='mean', dim=dim, keepdim=keepdim) - + def sum(self, dim: Optional[Union[int, 
Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: return self.reduce(op='sum', dim=dim, keepdim=keepdim) - + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: return self.reduce(op='prod', dim=dim, keepdim=keepdim) - + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: mean = self.mean(dim=dim, keepdim=True) mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) std = (mean2 - mean ** 2).sqrt() return std - + def __repr__(self) -> str: return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" @@ -305,7 +305,7 @@ def __repr__(self) -> str: def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: """ Concatenate a list of varlen tensors. - + Args: inputs (List[VarLenTensor]): List of varlen tensors to concatenate. """ @@ -328,7 +328,7 @@ def varlen_cat(inputs: List[VarLenTensor], dim: int = 0) -> VarLenTensor: def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: """ Unbind a varlen tensor along a dimension. - + Args: input (VarLenTensor): Varlen tensor to unbind. dim (int): Dimension to unbind. @@ -338,12 +338,12 @@ def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: else: feats = input.feats.unbind(dim) return [input.replace(f) for f in feats] - + class SparseTensor(VarLenTensor): """ Sparse tensor with support for both torchsparse and spconv backends. - + Parameters: - feats (torch.Tensor): Features of the sparse tensor. - coords (torch.Tensor): Coordinates of the sparse tensor. @@ -371,7 +371,7 @@ def __init__(self, *args, **kwargs): self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor elif config.CONV == 'spconv': self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor - + method_id = 0 if len(args) != 0: method_id = 0 if isinstance(args[0], torch.Tensor) else 1 @@ -430,7 +430,7 @@ def __init__(self, *args, **kwargs): print(f"- Scale: {self._scale}") print(f"- Coords: {self.coords}") raise e - + @staticmethod def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': """ @@ -443,7 +443,7 @@ def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Ten coords.append(coord) coords = torch.cat(coords, dim=0) return SparseTensor(feats, coords) - + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """ Convert a SparseTensor to list of tensors. 
@@ -454,31 +454,31 @@ def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: feats_list.append(self.feats[s]) coords_list.append(self.coords[s]) return feats_list, coords_list - + def __len__(self) -> int: return len(self.layout) - + def __cal_shape(self, feats, coords): shape = [] shape.append(coords[:, 0].max().item() + 1) shape.extend([*feats.shape[1:]]) return torch.Size(shape) - + def __cal_layout(self, coords, batch_size): seq_len = torch.bincount(coords[:, 0], minlength=batch_size) - offset = torch.cumsum(seq_len, dim=0) + offset = torch.cumsum(seq_len, dim=0) layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] return layout - + def __cal_spatial_shape(self, coords): return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) - + @property def shape(self) -> torch.Size: if self._shape is None: self._shape = self.__cal_shape(self.feats, self.coords) return self._shape - + @property def layout(self) -> List[slice]: layout = self.get_spatial_cache('layout') @@ -486,7 +486,7 @@ def layout(self) -> List[slice]: layout = self.__cal_layout(self.coords, self.shape[0]) self.register_spatial_cache('layout', layout) return layout - + @property def spatial_shape(self) -> torch.Size: spatial_shape = self.get_spatial_cache('shape') @@ -503,7 +503,7 @@ def feats(self) -> torch.Tensor: return self.data.features else: return self.data['feats'] - + @feats.setter def feats(self, value: torch.Tensor): if config.CONV == 'torchsparse': @@ -521,7 +521,7 @@ def coords(self) -> torch.Tensor: return self.data.indices else: return self.data['coords'] - + @coords.setter def coords(self, value: torch.Tensor): if config.CONV == 'torchsparse': @@ -538,7 +538,7 @@ def dtype(self): @property def device(self): return self.feats.device - + @property def seqlen(self) -> torch.LongTensor: seqlen = self.get_spatial_cache('seqlen') @@ -546,7 +546,7 @@ def seqlen(self) -> torch.LongTensor: seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) self.register_spatial_cache('seqlen', seqlen) return seqlen - + @property def cum_seqlen(self) -> torch.LongTensor: cum_seqlen = self.get_spatial_cache('cum_seqlen') @@ -557,7 +557,7 @@ def cum_seqlen(self) -> torch.LongTensor: ], dim=0) self.register_spatial_cache('cum_seqlen', cum_seqlen) return cum_seqlen - + @property def batch_boardcast_map(self) -> torch.LongTensor: """ @@ -596,7 +596,7 @@ def to(self, *args, **kwargs) -> 'SparseTensor': device = kwargs['device'] non_blocking = kwargs.get('non_blocking', False) copy = kwargs.get('copy', False) - + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) return self.replace(new_feats, new_coords) @@ -609,7 +609,7 @@ def cpu(self) -> 'SparseTensor': new_feats = self.feats.cpu() new_coords = self.coords.cpu() return self.replace(new_feats, new_coords) - + def cuda(self) -> 'SparseTensor': new_feats = self.feats.cuda() new_coords = self.coords.cuda() @@ -618,11 +618,11 @@ def cuda(self) -> 'SparseTensor': def half(self) -> 'SparseTensor': new_feats = self.feats.half() return self.replace(new_feats) - + def float(self) -> 'SparseTensor': new_feats = self.feats.float() return self.replace(new_feats) - + def detach(self) -> 'SparseTensor': new_coords = self.coords.detach() new_feats = self.feats.detach() @@ -631,7 +631,7 @@ def detach(self) -> 'SparseTensor': def reshape(self, *shape) -> 'SparseTensor': new_feats = 
self.feats.reshape(self.feats.shape[0], *shape) return self.replace(new_feats) - + def unbind(self, dim: int) -> List['SparseTensor']: return sparse_unbind(self, dim) @@ -675,7 +675,7 @@ def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> spatial_cache=self._spatial_cache ) return new_tensor - + def to_dense(self) -> torch.Tensor: if config.CONV == 'torchsparse': return self.data.dense() @@ -713,7 +713,7 @@ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: else: new_cache[k].update(other._spatial_cache[k]) return new_cache - + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': if isinstance(other, torch.Tensor): try: @@ -746,7 +746,7 @@ def __getitem__(self, idx): raise ValueError(f"Unknown index type: {idx.dtype}") else: raise ValueError(f"Unknown index type: {type(idx)}") - + new_coords = [] new_feats = [] new_layout = [] @@ -763,7 +763,7 @@ def __getitem__(self, idx): new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) new_tensor.register_spatial_cache('layout', new_layout) return new_tensor - + def clear_spatial_cache(self) -> None: """ Clear all spatial caches. @@ -790,14 +790,14 @@ def get_spatial_cache(self, key=None): if key is None: return cur_scale_cache return cur_scale_cache.get(key, None) - + def __repr__(self) -> str: return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: """ Concatenate a list of sparse tensors. - + Args: inputs (List[SparseTensor]): List of sparse tensors to concatenate. """ @@ -824,7 +824,7 @@ def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: """ Unbind a sparse tensor along a dimension. - + Args: input (SparseTensor): Sparse tensor to unbind. dim (int): Dimension to unbind. 
diff --git a/trellis2/modules/sparse/config.py b/trellis2/modules/sparse/config.py index a5f4d53..fead5b4 100644 --- a/trellis2/modules/sparse/config.py +++ b/trellis2/modules/sparse/config.py @@ -1,16 +1,16 @@ from typing import * -CONV = 'flex_gemm' +CONV = 'flex_gemm' DEBUG = False ATTN = 'flash_attn' def __from_env(): import os - + global CONV global DEBUG global ATTN - + env_sparse_conv_backend = os.environ.get('SPARSE_CONV_BACKEND') env_sparse_debug = os.environ.get('SPARSE_DEBUG') env_sparse_attn_backend = os.environ.get('SPARSE_ATTN_BACKEND') @@ -23,12 +23,12 @@ def __from_env(): DEBUG = env_sparse_debug == '1' if env_sparse_attn_backend is not None and env_sparse_attn_backend in ['xformers', 'flash_attn', 'flash_attn_3']: ATTN = env_sparse_attn_backend - + print(f"[SPARSE] Conv backend: {CONV}; Attention backend: {ATTN}") - + __from_env() - + def set_conv_backend(backend: Literal['none', 'spconv', 'torchsparse', 'flex_gemm']): global CONV diff --git a/trellis2/modules/sparse/linear.py b/trellis2/modules/sparse/linear.py index bc51f79..391320d 100644 --- a/trellis2/modules/sparse/linear.py +++ b/trellis2/modules/sparse/linear.py @@ -10,13 +10,13 @@ def chunked_apply(module, x: torch.Tensor, chunk_size: int) -> torch.Tensor: if chunk_size <= 0 or x.shape[0] <= chunk_size: return module(x) - + # Process first chunk to determine output shape and dtype out_0 = module(x[0:chunk_size]) out_shape = (x.shape[0],) + out_0.shape[1:] out = torch.empty(out_shape, device=x.device, dtype=out_0.dtype) out[0:chunk_size] = out_0 - + # Process remaining chunks for i in range(chunk_size, x.shape[0], chunk_size): out[i:i+chunk_size] = module(x[i:i+chunk_size]) diff --git a/trellis2/modules/sparse/nonlinearity.py b/trellis2/modules/sparse/nonlinearity.py index 950e5c0..0e78836 100644 --- a/trellis2/modules/sparse/nonlinearity.py +++ b/trellis2/modules/sparse/nonlinearity.py @@ -13,7 +13,7 @@ class SparseReLU(nn.ReLU): def forward(self, input: VarLenTensor) -> VarLenTensor: return input.replace(super().forward(input.feats)) - + class SparseSiLU(nn.SiLU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -32,4 +32,4 @@ def __init__(self, activation: nn.Module): def forward(self, input: VarLenTensor) -> VarLenTensor: return input.replace(self.activation(input.feats)) - + diff --git a/trellis2/modules/sparse/spatial/basic.py b/trellis2/modules/sparse/spatial/basic.py index eaeb8af..0791db6 100644 --- a/trellis2/modules/sparse/spatial/basic.py +++ b/trellis2/modules/sparse/spatial/basic.py @@ -41,7 +41,7 @@ def forward(self, x: SparseTensor) -> SparseTensor: ) else: new_coords, idx = cache - + new_feats = torch.scatter_reduce( torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype), dim=0, @@ -53,7 +53,7 @@ def forward(self, x: SparseTensor) -> SparseTensor: out = SparseTensor(new_feats, new_coords, x._shape) out._scale = tuple([s * self.factor for s in x._scale]) out._spatial_cache = x._spatial_cache - + if cache is None: x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx)) out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx)) @@ -98,12 +98,11 @@ def forward(self, x: SparseTensor, subdivision: Optional[SparseTensor] = None) - idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) else: new_coords, idx = cache - + new_feats = x.feats[idx] out = SparseTensor(new_feats, new_coords, x._shape) out._scale = tuple([s / self.factor for s in x._scale]) if cache is 
not None: # only keep cache when subdiv following it out._spatial_cache = x._spatial_cache - + return out - \ No newline at end of file diff --git a/trellis2/modules/sparse/spatial/spatial2channel.py b/trellis2/modules/sparse/spatial/spatial2channel.py index 577f36d..2eaa60a 100644 --- a/trellis2/modules/sparse/spatial/spatial2channel.py +++ b/trellis2/modules/sparse/spatial/spatial2channel.py @@ -35,14 +35,14 @@ def forward(self, x: SparseTensor) -> SparseTensor: ) else: new_coords, idx, subidx = cache - + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) new_feats[idx * self.factor ** DIM + subidx] = x.feats out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) out._scale = tuple([s * self.factor for s in x._scale]) out._spatial_cache = x._spatial_cache - + if cache is None: x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) @@ -51,7 +51,7 @@ def forward(self, x: SparseTensor) -> SparseTensor: subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool) subdivision[idx, subidx] = True out.register_spatial_cache(f'subdivision', subdivision) - + return out diff --git a/trellis2/modules/transformer/blocks.py b/trellis2/modules/transformer/blocks.py index fb6f5eb..30693e9 100644 --- a/trellis2/modules/transformer/blocks.py +++ b/trellis2/modules/transformer/blocks.py @@ -16,7 +16,7 @@ def __init__(self, channels: int, in_channels: int = 3): self.freq_dim = channels // in_channels // 2 self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = 1.0 / (10000 ** self.freqs) - + def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor: """ Create sinusoidal position embeddings. 
@@ -129,7 +129,7 @@ def __init__( shift_window: Optional[Tuple[int, int, int]] = None, use_checkpoint: bool = False, use_rope: bool = False, - rope_freq: Tuple[int, int] = (1.0, 10000.0), + rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, qkv_bias: bool = True, @@ -183,4 +183,3 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, phases: Optional[torch return torch.utils.checkpoint.checkpoint(self._forward, x, context, phases, use_reentrant=False) else: return self._forward(x, context, phases) - \ No newline at end of file diff --git a/trellis2/modules/transformer/modulated.py b/trellis2/modules/transformer/modulated.py index 0d71e58..2e641ec 100644 --- a/trellis2/modules/transformer/modulated.py +++ b/trellis2/modules/transformer/modulated.py @@ -20,7 +20,7 @@ def __init__( shift_window: Optional[Tuple[int, int, int]] = None, use_checkpoint: bool = False, use_rope: bool = False, - rope_freq: Tuple[int, int] = (1.0, 10000.0), + rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, qkv_bias: bool = True, share_mod: bool = False, @@ -92,7 +92,7 @@ def __init__( shift_window: Optional[Tuple[int, int, int]] = None, use_checkpoint: bool = False, use_rope: bool = False, - rope_freq: Tuple[int, int] = (1.0, 10000.0), + rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, qkv_bias: bool = True, @@ -162,4 +162,3 @@ def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, pha return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) else: return self._forward(x, mod, context, phases) - \ No newline at end of file diff --git a/trellis2/trainers/__init__.py b/trellis2/trainers/__init__.py index 8f25130..6b3554b 100644 --- a/trellis2/trainers/__init__.py +++ b/trellis2/trainers/__init__.py @@ -2,22 +2,22 @@ __attributes = { 'BasicTrainer': 'basic', - + 'SparseStructureVaeTrainer': 'vae.sparse_structure_vae', 'ShapeVaeTrainer': 'vae.shape_vae', 'PbrVaeTrainer': 'vae.pbr_vae', - + 'FlowMatchingTrainer': 'flow_matching.flow_matching', 'FlowMatchingCFGTrainer': 'flow_matching.flow_matching', 'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', 'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching', - + 'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching', 'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', 'MultiImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching', - + 'DinoV2FeatureExtractor': 'flow_matching.mixins.image_conditioned', 'DinoV3FeatureExtractor': 'flow_matching.mixins.image_conditioned', } @@ -47,21 +47,21 @@ def __getattr__(name): from .vae.sparse_structure_vae import SparseStructureVaeTrainer from .vae.shape_vae import ShapeVaeTrainer from .vae.pbr_vae import PbrVaeTrainer - + from .flow_matching.flow_matching import ( FlowMatchingTrainer, FlowMatchingCFGTrainer, TextConditionedFlowMatchingCFGTrainer, ImageConditionedFlowMatchingCFGTrainer, ) - + from .flow_matching.sparse_flow_matching import ( SparseFlowMatchingTrainer, SparseFlowMatchingCFGTrainer, TextConditionedSparseFlowMatchingCFGTrainer, ImageConditionedSparseFlowMatchingCFGTrainer, ) - + from .flow_matching.mixins.image_conditioned import ( DinoV2FeatureExtractor, 
DinoV3FeatureExtractor, diff --git a/trellis2/trainers/basic.py b/trellis2/trainers/basic.py index c8e4b4c..d8333a6 100644 --- a/trellis2/trainers/basic.py +++ b/trellis2/trainers/basic.py @@ -26,7 +26,7 @@ class BasicTrainer: """ Trainer for basic training loop. - + Args: models (dict[str, nn.Module]): Models to train. dataset (torch.utils.data.Dataset): Dataset. @@ -119,7 +119,7 @@ def __init__(self, self.i_log = i_log self.i_sample = i_sample self.i_save = i_save - self.i_ddpcheck = i_ddpcheck + self.i_ddpcheck = i_ddpcheck if dist.is_initialized(): # Multi-GPU params @@ -141,14 +141,14 @@ def __init__(self, self.init_models_and_more(**kwargs) self.prepare_dataloader(**kwargs) - + # Load checkpoint self.step = 0 if load_dir is not None and step is not None: self.load(load_dir, step) elif finetune_ckpt is not None: self.finetune_from(finetune_ckpt) - + if self.is_master: os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True) os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True) @@ -156,7 +156,7 @@ def __init__(self, if self.parallel_mode == 'ddp' and self.world_size > 1: self.check_ddp() - + if self.is_master: print('\n\nTrainer initialized.') print(self) @@ -198,7 +198,7 @@ def device(self): if hasattr(model, 'device'): return model.device return next(list(self.models.values())[0].parameters()).device - + def init_models_and_more(self, **kwargs): """ Initialize models and more. @@ -244,7 +244,7 @@ def init_models_and_more(self, **kwargs): self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args']) else: self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args']) - + # Initalize learning rate scheduler if self.lr_scheduler_config is not None: if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']): @@ -323,7 +323,7 @@ def load(self, load_dir, step=0): """ if self.is_master: print(f'\nLoading checkpoint from step {step}...', end='') - + model_ckpts = {} for name, model in self.models.items(): model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True) @@ -340,7 +340,7 @@ def load(self, load_dir, step=0): ema_ckpts[name] = ema_ckpt self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts) del ema_ckpts - + misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False) self.optimizer.load_state_dict(misc_ckpt['optimizer']) self.step = misc_ckpt['step'] @@ -372,7 +372,7 @@ def save(self, non_blocking=True): """ assert self.is_master, 'save() should be called only by the rank 0 process.' 
print(f'\nSaving checkpoint at step {self.step}...', end='') - + model_ckpts = self._master_params_to_state_dicts(self.master_params) for name, model_ckpt in model_ckpts.items(): model_ckpt = {k: v.cpu() for k, v in model_ckpt.items()} # Move to CPU for saving @@ -383,7 +383,7 @@ def save(self, non_blocking=True): ).start() else: torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt')) - + for i, ema_rate in enumerate(self.ema_rate): ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i]) for name, ema_ckpt in ema_ckpts.items(): @@ -429,7 +429,7 @@ def finetune_from(self, finetune_ckpt): print('\nFinetuning from:') for name, path in finetune_ckpt.items(): print(f' - {name}: {path}') - + model_ckpts = {} for name, model in self.models.items(): model_state_dict = model.state_dict() @@ -649,7 +649,7 @@ def load_data(self): self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) else: data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True) - + # if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu if isinstance(data, dict): if self.batch_split == 1: @@ -664,7 +664,7 @@ def load_data(self): data_list = data else: raise ValueError('Data must be a dict or a list of dicts.') - + return data_list def run_step(self, data_list): @@ -745,7 +745,7 @@ def run_step(self, data_list): if not any(not p.grad.isfinite().all() for p in self.model_params): self.optimizer.step() else: - print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') + print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m') ## adjust learning rate if self.lr_scheduler_config is not None: statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0] @@ -758,7 +758,7 @@ def run_step(self, data_list): step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x)) if self.grad_clip is not None: step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log() - + # Check grad and norm of each param if self.log_param_stats: param_norms = {} @@ -792,7 +792,7 @@ def save_logs(self): for key, value in log_show.items(): self.writer.add_scalar(key, value, self.step) self.log = [] - + def check_abort(self): """ Check if training should be aborted due to certain conditions. @@ -884,7 +884,7 @@ def run(self): # Save checkpoint if self.step % self.i_save == 0: self.save() - + # Check abort self.check_abort() @@ -894,7 +894,7 @@ def run(self): if self.is_master: self.writer.close() print('Training finished.') - + def profile(self, wait=2, warmup=3, active=5): """ Profile the training loop. 
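Aside on the run_step hunk above: the trainer steps the optimizer only when every parameter gradient is finite, and otherwise prints a warning and skips the update. A minimal standalone sketch of that guard, assuming an ordinary model and optimizer (the names below are illustrative, not the trainer's actual attributes):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def guarded_step(loss: torch.Tensor) -> bool:
    """Backpropagate, then update only if all gradients are finite."""
    optimizer.zero_grad()
    loss.backward()
    finite = all(p.grad is None or p.grad.isfinite().all() for p in model.parameters())
    if finite:
        optimizer.step()
    else:
        print('Warning: non-finite gradients detected; skipping update.')
    return finite

loss = model(torch.randn(8, 4)).pow(2).mean()
guarded_step(loss)

Skipping rather than clipping keeps a single bad batch from corrupting the weights while leaving the adaptive gradient clipper shown later in this patch to handle merely large (but finite) norms.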
diff --git a/trellis2/trainers/utils.py b/trellis2/trainers/utils.py index 23e4286..c72068f 100644 --- a/trellis2/trainers/utils.py +++ b/trellis2/trainers/utils.py @@ -16,7 +16,7 @@ def str_to_dtype(dtype_str: str): 'fp32': torch.float32, 'float32': torch.float32, }[dtype_str] - + def make_master_params(model_params): """ @@ -64,7 +64,7 @@ def model_grads_to_master_grads(model_params, master_params): master_params[0].grad = _flatten_dense_tensors( [param.grad.data.detach().float() for param in model_params] ) - + def zero_grad(model_params): for param in model_params: @@ -74,7 +74,7 @@ def zero_grad(model_params): else: param.grad.requires_grad_(False) param.grad.zero_() - + # LR Schedulers from torch.optim.lr_scheduler import LambdaLR @@ -83,9 +83,8 @@ class LinearWarmupLRScheduler(LambdaLR): def __init__(self, optimizer, warmup_steps, last_epoch=-1): self.warmup_steps = warmup_steps super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - + def lr_lambda(self, current_step): if current_step < self.warmup_steps: return float(current_step + 1) / self.warmup_steps return 1.0 - \ No newline at end of file diff --git a/trellis2/utils/data_utils.py b/trellis2/utils/data_utils.py index 805b6cc..d86f86b 100644 --- a/trellis2/utils/data_utils.py +++ b/trellis2/utils/data_utils.py @@ -58,7 +58,7 @@ def cycle(data_loader: DataLoader) -> Iterator: if isinstance(data_loader.sampler, ResumableSampler): data_loader.sampler.epoch += 1 data_loader.sampler.idx = 0 - + class ResumableSampler(Sampler): """ @@ -133,7 +133,7 @@ def __iter__(self) -> Iterator: # subsample indices = indices[self.rank : self.total_size : self.world_size] - + # resume from previous state indices = indices[self.idx:] @@ -147,11 +147,11 @@ def state_dict(self) -> dict[str, int]: 'epoch': self.epoch, 'idx': self.idx, } - + def load_state_dict(self, state_dict): self.epoch = state_dict['epoch'] self.idx = state_dict['idx'] - + class BalancedResumableSampler(ResumableSampler): """ @@ -185,7 +185,7 @@ def __init__( super().__init__(dataset, shuffle, seed, drop_last) self.batch_size = batch_size self.loads = dataset.loads - + def __iter__(self) -> Iterator: if self.shuffle: # deterministically shuffle based on epoch and seed @@ -219,7 +219,7 @@ def __iter__(self) -> Iterator: batch_loads = [self.loads[idx] for idx in batch_indices] groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True) balanced_indices.extend([batch_indices[j] for j in groups[self.rank]]) - + # resume from previous state indices = balanced_indices[self.idx:] diff --git a/trellis2/utils/dist_utils.py b/trellis2/utils/dist_utils.py index 348799c..6ef472e 100644 --- a/trellis2/utils/dist_utils.py +++ b/trellis2/utils/dist_utils.py @@ -14,7 +14,7 @@ def setup_dist(rank, local_rank, world_size, master_addr, master_port): os.environ['LOCAL_RANK'] = str(local_rank) torch.cuda.set_device(local_rank) dist.init_process_group('nccl', rank=rank, world_size=world_size) - + def read_file_dist(path): """ @@ -49,7 +49,7 @@ def read_file_dist(path): data = f.read() data = io.BytesIO(data) return data - + def unwrap_dist(model): """ @@ -74,7 +74,7 @@ def master_first(): else: dist.barrier() yield - + @contextmanager def local_master_first(): @@ -90,4 +90,3 @@ def local_master_first(): else: dist.barrier() yield - \ No newline at end of file diff --git a/trellis2/utils/elastic_utils.py b/trellis2/utils/elastic_utils.py index cba3cf8..e7a0fe4 100644 --- a/trellis2/utils/elastic_utils.py +++ 
b/trellis2/utils/elastic_utils.py @@ -10,29 +10,29 @@ class MemoryController: """ Base class for memory management during training. """ - + _last_input_size = None _last_mem_ratio = [] - + @contextmanager def record(self): pass - + def update_run_states(self, input_size=None, mem_ratio=None): if self._last_input_size is None: self._last_input_size = input_size elif self._last_input_size!= input_size: raise ValueError(f'Input size should not change for different ElasticModules.') self._last_mem_ratio.append(mem_ratio) - + @abstractmethod def get_mem_ratio(self, input_size): pass - + @abstractmethod def state_dict(self): pass - + @abstractmethod def log(self): pass @@ -63,7 +63,7 @@ def __init__( self.target_ratio = target_ratio self.device = device or torch.cuda.current_device() self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3 - + self._memory = np.zeros(buffer_size, dtype=np.float32) self._input_size = np.zeros(buffer_size, dtype=np.float32) self._mem_ratio = np.zeros(buffer_size, dtype=np.float32) @@ -75,14 +75,14 @@ def __init__( def __repr__(self): return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})' - + def _add_sample(self, memory, input_size, mem_ratio): self._memory[self._buffer_ptr] = memory self._input_size[self._buffer_ptr] = input_size self._mem_ratio[self._buffer_ptr] = mem_ratio self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size self._buffer_length = min(self._buffer_length + 1, self.buffer_size) - + @contextmanager def record(self): torch.cuda.reset_peak_memory_stats(self.device) @@ -96,45 +96,45 @@ def record(self): if self.step % self.update_every == 0: self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1) self._fit_params() - + def _fit_params(self): memory_usage = self._memory[:self._buffer_length] input_size = self._input_size[:self._buffer_length] mem_ratio = self._mem_ratio[:self._buffer_length] - + x = input_size * mem_ratio y = memory_usage k, b = np.polyfit(x, y, 1) self._params = (k, b) # self._visualize() - + def _visualize(self): import matplotlib.pyplot as plt memory_usage = self._memory[:self._buffer_length] input_size = self._input_size[:self._buffer_length] mem_ratio = self._mem_ratio[:self._buffer_length] k, b = self._params - + plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis') x = np.array([0.0, 20000.0]) plt.plot(x, k * x + b, c='r') plt.savefig(f'linear_memory_controller_{self.step}.png') plt.cla() - + def get_mem_ratio(self, input_size): k, b = self._params if k == 0: return np.random.rand() * self._max_mem_ratio pred = (self.available_memory * self.target_ratio - b) / (k * input_size) return min(self._max_mem_ratio, max(0.0, pred)) - + def state_dict(self): return { 'params': self._params, } - + def load_state_dict(self, state_dict): self._params = tuple(state_dict['params']) - + def log(self): return { 'params/k': self._params[0], @@ -143,8 +143,8 @@ def log(self): 'input_size': self._last_input_size, 'mem_ratio': self._last_mem_ratio, } - - + + class ElasticModule(nn.Module): """ Module for training with elastic memory management. @@ -152,27 +152,27 @@ class ElasticModule(nn.Module): def __init__(self): super().__init__() self._memory_controller: MemoryController = None - + @abstractmethod def _get_input_size(self, *args, **kwargs) -> int: """ Get the size of the input data. - + Returns: int: The size of the input data. 
""" pass - + @abstractmethod def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]: """ Forward with a given memory ratio. """ pass - + def register_memory_controller(self, memory_controller: MemoryController): self._memory_controller = memory_controller - + def forward(self, *args, **kwargs): if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: _, ret = self._forward_with_mem_ratio(*args, **kwargs) @@ -182,7 +182,7 @@ def forward(self, *args, **kwargs): mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs) self._memory_controller.update_run_states(input_size, mem_ratio) return ret - + class ElasticModuleMixin: """ @@ -191,31 +191,31 @@ class ElasticModuleMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._memory_controller: MemoryController = None - + @abstractmethod def _get_input_size(self, *args, **kwargs) -> int: """ Get the size of the input data. - + Returns: int: The size of the input data. """ pass - + @abstractmethod @contextmanager def with_mem_ratio(self, mem_ratio=1.0) -> float: """ Context manager for training with a reduced memory ratio compared to the full memory usage. - + Returns: float: The exact memory ratio used during the forward pass. """ pass - + def register_memory_controller(self, memory_controller: MemoryController): self._memory_controller = memory_controller - + def forward(self, *args, **kwargs): if self._memory_controller is None or not torch.is_grad_enabled() or not self.training: ret = super().forward(*args, **kwargs) diff --git a/trellis2/utils/general_utils.py b/trellis2/utils/general_utils.py index 589c103..24608d0 100644 --- a/trellis2/utils/general_utils.py +++ b/trellis2/utils/general_utils.py @@ -141,7 +141,7 @@ def make_grid(images, nrow=None, ncol=None, aspect_ratio=None): ncol = (num_images + nrow - 1) // nrow else: assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images' - + if images[0].ndim == 2: grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype) else: @@ -169,14 +169,14 @@ def text_image(text, resolution=(512, 512), max_size=0.5, h_align="left", v_alig and scaled so that it fits completely within the image while preserving any explicit line breaks and original spacing. Horizontal and vertical alignment can be controlled via flags. - + Parameters: text (str): The input text. Newline characters and spacing are preserved. resolution (tuple): The image resolution as (width, height). max_size (float): The maximum font size. h_align (str): Horizontal alignment. Options: "left", "center", "right". v_align (str): Vertical alignment. Options: "top", "center", "bottom". - + Returns: numpy.ndarray: The resulting image (BGR format) with the text drawn. """ @@ -201,14 +201,14 @@ def wrap_line(line, max_width, font, thickness, scale): width (measured at the given scale) does not exceed max_width. This function preserves the original spacing by splitting the line into tokens (words and whitespace) using a regular expression. - + Parameters: line (str): The input text line. max_width (int): Maximum allowed width in pixels. font (int): OpenCV font identifier. thickness (int): Text thickness. scale (float): The current font scale. - + Returns: List[str]: A list of wrapped lines. 
""" @@ -216,7 +216,7 @@ def wrap_line(line, max_width, font, thickness, scale): tokens = re.split(r'(\s+)', line) if not tokens: return [''] - + wrapped_lines = [] current_line = "" for token in tokens: @@ -249,7 +249,7 @@ def compute_text_block(scale): """ Wrap the entire text (splitting at explicit newline characters) using the provided scale, and then compute the overall width and height of the text block. - + Returns: wrapped_lines (List[str]): The list of wrapped lines. block_width (int): Maximum width among the wrapped lines. @@ -263,18 +263,18 @@ def compute_text_block(scale): for line in input_lines: wrapped = wrap_line(line, avail_width, font, thickness, scale) wrapped_lines.extend(wrapped) - + sizes = [] for line in wrapped_lines: (text_size, _) = cv2.getTextSize(line, font, scale, thickness) sizes.append(text_size) # (width, height) - + block_width = max((w for w, h in sizes), default=0) # Use the height of "A" (at the current scale) to compute line spacing base_height = cv2.getTextSize("A", font, scale, thickness)[0][1] spacing = int(line_spacing_ratio * base_height) block_height = sum(h for w, h in sizes) + spacing * (len(sizes) - 1) if sizes else 0 - + return wrapped_lines, block_width, block_height, sizes, spacing # Use binary search to find the maximum scale that allows the text block to fit @@ -298,7 +298,7 @@ def compute_text_block(scale): if best_result is None: best_scale = 0.5 best_result = compute_text_block(best_scale) - + wrapped_lines, block_width, block_height, sizes, spacing = best_result # Compute starting y-coordinate based on vertical alignment flag diff --git a/trellis2/utils/grad_clip_utils.py b/trellis2/utils/grad_clip_utils.py index 990a435..a392d4d 100644 --- a/trellis2/utils/grad_clip_utils.py +++ b/trellis2/utils/grad_clip_utils.py @@ -17,7 +17,7 @@ def __init__( self.max_norm = max_norm self.clip_percentile = clip_percentile self.buffer_size = buffer_size - + self._grad_norm = np.zeros(buffer_size, dtype=np.float32) self._max_norm = max_norm self._buffer_ptr = 0 @@ -25,7 +25,7 @@ def __init__( def __repr__(self): return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})' - + def state_dict(self): return { 'grad_norm': self._grad_norm, @@ -69,7 +69,7 @@ def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach= """ max_norm = self._max_norm if self._max_norm is not None else float('inf') grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach) - + if torch.isfinite(grad_norm): self._grad_norm[self._buffer_ptr] = grad_norm self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size @@ -77,5 +77,5 @@ def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach= if self._buffer_length == self.buffer_size: self._max_norm = np.percentile(self._grad_norm, self.clip_percentile) self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm - + return grad_norm \ No newline at end of file diff --git a/trellis2/utils/mesh_utils.py b/trellis2/utils/mesh_utils.py index a9f1451..9b8e932 100644 --- a/trellis2/utils/mesh_utils.py +++ b/trellis2/utils/mesh_utils.py @@ -9,10 +9,10 @@ def read_ply(filename): """ Read a PLY file and return vertices, triangle faces, and quad faces. - + Args: filename (str): The file path to read from. - + Returns: vertices (np.ndarray): Array of shape [N, 3] containing vertex positions. 
tris (np.ndarray): Array of shape [M, 3] containing triangle face indices (empty if none). @@ -29,40 +29,40 @@ def read_ply(filename): if b"end_header" in line: break header = header_bytes.decode('utf-8') - + # Determine if the file is in ASCII or binary format is_ascii = "ascii" in header - + # Extract the number of vertices and faces from the header using regex vertex_match = re.search(r'element vertex (\d+)', header) if vertex_match: num_vertices = int(vertex_match.group(1)) else: raise ValueError("Vertex count not found in header") - + face_match = re.search(r'element face (\d+)', header) if face_match: num_faces = int(face_match.group(1)) else: raise ValueError("Face count not found in header") - + vertices = [] tris = [] quads = [] - + if is_ascii: # For ASCII format, read each line of vertex data (each line contains 3 floats) for _ in range(num_vertices): line = f.readline().decode('utf-8').strip() - if not line: + if not line: continue parts = line.split() vertices.append([float(parts[0]), float(parts[1]), float(parts[2])]) - + # Read face data, where the first number indicates the number of vertices for the face for _ in range(num_faces): line = f.readline().decode('utf-8').strip() - if not line: + if not line: continue parts = line.split() count = int(parts[0]) @@ -83,7 +83,7 @@ def read_ply(filename): raise ValueError("Insufficient vertex data") v = struct.unpack(' 0 else np.empty((0, 3), dtype=np.int32) quads = np.array(quads, dtype=np.int32) if len(quads) > 0 else np.empty((0, 4), dtype=np.int32) - + return vertices, tris, quads @@ -128,7 +128,7 @@ def write_ply( """ Write a mesh to a PLY file, with the option to save in ASCII or binary format, and optional per-vertex colors. - + Args: filename (str): The filename to write to. vertices (np.ndarray): [N, 3] The vertex positions. @@ -211,7 +211,7 @@ def write_ply( f.write(struct.pack('