9 changes: 4 additions & 5 deletions trellis2/datasets/__init__.py
@@ -3,11 +3,11 @@
__attributes = {
'FlexiDualGridDataset': 'flexi_dual_grid',
'SparseVoxelPbrDataset':'sparse_voxel_pbr',

'SparseStructureLatent': 'sparse_structure_latent',
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',

'SLat': 'structured_latent',
'ImageConditionedSLat': 'structured_latent',
'SLatShape': 'structured_latent_shape',
@@ -35,12 +35,11 @@ def __getattr__(name):


# For Pylance
if __name__ == '__main__':
if __name__ == '__main__':
from .flexi_dual_grid import FlexiDualGridDataset
from .sparse_voxel_pbr import SparseVoxelPbrDataset

from .sparse_structure_latent import SparseStructureLatent, ImageConditionedSparseStructureLatent
from .structured_latent import SLat, ImageConditionedSLat
from .structured_latent_shape import SLatShape, ImageConditionedSLatShape
from .structured_latent_svpbr import SLatPbr, ImageConditionedSLatPbr
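
Illustrative sketch, not part of this diff: the __attributes mapping above pairs exported class names with the submodules that define them, and the __getattr__ hunk suggests the usual PEP 562 lazy-import pattern. The importlib-based body below is an assumption, since the actual implementation is truncated in this view.

import importlib

__attributes = {
    'FlexiDualGridDataset': 'flexi_dual_grid',
    'SLat': 'structured_latent',
}

def __getattr__(name):
    if name not in __attributes:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Import trellis2.datasets.<submodule> only on first attribute access,
    # then return the requested class from it.
    module = importlib.import_module(f".{__attributes[name]}", __name__)
    return getattr(module, name)

Under this pattern, from trellis2.datasets import SLat imports structured_latent only at that point, and the guarded if __name__ == '__main__': block exists purely so static analyzers such as Pylance can see the real imports (the "For Pylance" comment above).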

22 changes: 11 additions & 11 deletions trellis2/datasets/components.py
@@ -30,7 +30,7 @@ def __init__(self,
root_type = 'list'
self.instances = []
self.metadata = pd.DataFrame()

self._stats = {}
if root_type == 'obj':
for key, root in self.roots.items():
@@ -54,15 +54,15 @@ def __init__(self,
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
metadata.set_index('sha256', inplace=True)
self.metadata = pd.concat([self.metadata, metadata])

@abstractmethod
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
pass

@abstractmethod
def get_instance(self, root, instance: str) -> Dict[str, Any]:
pass

def __len__(self):
return len(self.instances)

@@ -73,7 +73,7 @@ def __getitem__(self, index) -> Dict[str, Any]:
except Exception as e:
print(f'Error loading {instance}: {e}')
return self.__getitem__(np.random.randint(0, len(self)))

def __str__(self):
lines = []
lines.append(self.__class__.__name__)
@@ -90,16 +90,16 @@ class ImageConditionedMixin:
def __init__(self, roots, *, image_size=518, **kwargs):
self.image_size = image_size
super().__init__(roots, **kwargs)

def filter_metadata(self, metadata):
metadata, stats = super().filter_metadata(metadata)
metadata = metadata[metadata['cond_rendered'].notna()]
stats['Cond rendered'] = len(metadata)
return metadata, stats

def get_instance(self, root, instance):
pack = super().get_instance(root, instance)

image_root = os.path.join(root['render_cond'], instance)
with open(os.path.join(image_root, 'transforms.json')) as f:
metadata = json.load(f)
@@ -128,7 +128,7 @@ def get_instance(self, root, instance):
alpha = torch.tensor(np.array(alpha)).float() / 255.0
image = image * alpha.unsqueeze(0)
pack['cond'] = image

return pack


@@ -143,10 +143,10 @@ def filter_metadata(self, metadata):
metadata = metadata[metadata['cond_rendered'].notna()]
stats['Cond rendered'] = len(metadata)
return metadata, stats

def get_instance(self, root, instance):
pack = super().get_instance(root, instance)

image_root = os.path.join(root['render_cond'], instance)
with open(os.path.join(image_root, 'transforms.json')) as f:
metadata = json.load(f)
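
Illustrative sketch, not part of this diff: ImageConditionedMixin above relies on cooperative multiple inheritance; its filter_metadata and get_instance call super() so the concrete dataset base fills the sample first, then the mixin appends its conditioning image. A minimal stand-alone example of that layering, with hypothetical class names:

class BaseDataset:
    def get_instance(self, root, instance):
        return {'instance': instance}                     # base pack

class CondMixin:
    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)       # base runs first
        pack['cond'] = f'conditioning for {instance}'     # mixin augments the pack
        return pack

class CondDataset(CondMixin, BaseDataset):
    pass

print(CondDataset().get_instance('root', 'abc'))
# {'instance': 'abc', 'cond': 'conditioning for abc'}

This is why the mixin is listed first in ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructureLatent) further down: the MRO routes super() from the mixin into the concrete dataset.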
29 changes: 14 additions & 15 deletions trellis2/datasets/flexi_dual_grid.py
@@ -15,11 +15,11 @@ class FlexiDualGridVisMixin:
@torch.no_grad()
def visualize_sample(self, x: dict):
mesh = x['mesh']

renderer = MeshRenderer({'near': 1, 'far': 3})
renderer.rendering_options.resolution = 512
renderer.rendering_options.ssaa = 4

# Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
@@ -39,7 +39,7 @@ def visualize_sample(self, x: dict):
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
exts.append(extrinsics)
ints.append(intrinsics)

# Build each representation
images = []
for m in mesh:
@@ -50,14 +50,14 @@ def visualize_sample(self, x: dict):
renderer.render(m.cuda(), ext, intr)['normal']
images.append(image)
images = torch.stack(images)

return images


class FlexiDualGridDataset(FlexiDualGridVisMixin, StandardDatasetBase):
"""
Flexible Dual Grid Dataset

Args:
roots (str): path to the dataset
resolution (int): resolution of the voxel grid
@@ -79,16 +79,16 @@ def __init__(
self.value_range = (0, 1)

super().__init__(roots)

self.loads = [self.metadata.loc[sha256, f'dual_grid_size'] for _, sha256 in self.instances]

def __str__(self):
lines = [
super().__str__(),
f' - Resolution: {self.resolution}',
]
return '\n'.join(lines)

def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata[f'dual_grid_converted'] == True]
@@ -102,7 +102,7 @@ def filter_metadata(self, metadata):
metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
return metadata, stats

def read_mesh(self, root, instance):
with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
dump = pickle.load(f)
@@ -124,7 +124,7 @@ def read_mesh(self, root, instance):
vertices = (vertices - center) * scale
assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
return {'mesh': [Mesh(vertices=vertices, faces=faces)]}

def read_dual_grid(self, root, instance):
coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
vertices = sp.SparseTensor(
@@ -142,7 +142,7 @@ def get_instance(self, root, instance):
mesh = self.read_mesh(root['mesh_dump'], instance)
dual_grid = self.read_dual_grid(root['dual_grid'], instance)
return {**mesh, **dual_grid}

@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
@@ -164,10 +164,9 @@ def collate_fn(batch, split_size=None):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]

packs.append(pack)

if split_size is None:
return packs[0]
return packs
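
Illustrative sketch, not part of this diff: the tail of collate_fn above concatenates list-valued fields and keeps everything else as per-item lists, splitting the batch when split_size is given. The full loop below is an assumption reconstructed from that tail.

def collate_fn(batch, split_size=None):
    chunk = len(batch) if split_size is None else split_size
    packs = []
    for start in range(0, len(batch), chunk):
        sub_batch = batch[start:start + chunk]
        pack = {}
        for k in sub_batch[0].keys():
            if isinstance(sub_batch[0][k], list):
                pack[k] = sum([b[k] for b in sub_batch], [])   # flatten list fields
            else:
                pack[k] = [b[k] for b in sub_batch]            # keep one entry per item
        packs.append(pack)
    if split_size is None:
        return packs[0]   # unsplit call: a single collated dict
    return packs          # split call: one dict per sub-batch of split_size items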

21 changes: 10 additions & 11 deletions trellis2/datasets/sparse_structure_latent.py
@@ -24,7 +24,7 @@ def __init__(
self.pretrained_ss_dec = pretrained_ss_dec
self.ss_dec_path = ss_dec_path
self.ss_dec_ckpt = ss_dec_ckpt

def _loading_ss_dec(self):
if self.ss_dec is not None:
return
@@ -57,11 +57,11 @@ def decode_latent(self, z, batch_size=4):
def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
x_0 = self.decode_latent(x_0.cuda())

renderer = VoxelRenderer()
renderer.rendering_options.resolution = 512
renderer.rendering_options.ssaa = 4

# build camera
yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
yaw_offset = -16 / 180 * np.pi
@@ -70,7 +70,7 @@ def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
exts, ints = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, 2, 30)

images = []

# Build each representation
x_0 = x_0.cuda()
for i in range(x_0.shape[0]):
@@ -92,14 +92,14 @@ def visualize_sample(self, x_0: Union[torch.Tensor, dict]):
res = renderer.render(rep, ext, intr, colors_overwrite=color)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images.append(image)

return torch.stack(images)


class SparseStructureLatent(SparseStructureLatentVisMixin, StandardDatasetBase):
"""
Sparse structure latent dataset

Args:
roots (str): path to the dataset
min_aesthetic_score (float): minimum aesthetic score
@@ -120,26 +120,26 @@ def __init__(self,
self.min_aesthetic_score = min_aesthetic_score
self.normalization = normalization
self.value_range = (0, 1)

super().__init__(
roots,
pretrained_ss_dec=pretrained_ss_dec,
ss_dec_path=ss_dec_path,
ss_dec_ckpt=ss_dec_ckpt,
)

if self.normalization is not None:
self.mean = torch.tensor(self.normalization['mean']).reshape(-1, 1, 1, 1)
self.std = torch.tensor(self.normalization['std']).reshape(-1, 1, 1, 1)

def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata['ss_latent_encoded'] == True]
stats['With latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
return metadata, stats

def get_instance(self, root, instance):
latent = np.load(os.path.join(root['ss_latent'], f'{instance}.npz'))
z = torch.tensor(latent['z']).float()
@@ -157,4 +157,3 @@ class ImageConditionedSparseStructureLatent(ImageConditionedMixin, SparseStructu
Image-conditioned sparse structure dataset
"""
pass
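
Illustrative sketch, not part of this diff: the mean/std buffers in SparseStructureLatent.__init__ are reshaped to (-1, 1, 1, 1), which points at channel-wise standardization of the dense latent; the truncated get_instance most likely applies it as below. The values here are hypothetical stand-ins.

import torch

mean = torch.tensor([0.1, 0.2, 0.3, 0.4]).reshape(-1, 1, 1, 1)  # hypothetical per-channel stats
std  = torch.tensor([1.0, 1.1, 0.9, 1.2]).reshape(-1, 1, 1, 1)
z = torch.randn(4, 16, 16, 16)   # stand-in for latent['z'] with shape (C, D, H, W)
z = (z - mean) / std             # channel-wise standardization via broadcasting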
