diff --git a/nodes.py b/nodes.py index 08f6398..0597634 100644 --- a/nodes.py +++ b/nodes.py @@ -40,7 +40,7 @@ def pil2tensor(image): return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)[None,] - + def tensor2pil(image: torch.Tensor) -> Image.Image: """ Accepts either: @@ -61,8 +61,8 @@ def tensor2pil(image: torch.Tensor) -> Image.Image: arr = (t.numpy() * 255.0).clip(0, 255).astype(np.uint8) return Image.fromarray(arr) - raise TypeError(f"tensor2pil expected torch.Tensor, got {type(image)}") - + raise TypeError(f"tensor2pil expected torch.Tensor, got {type(image)}") + def tensor_batch_to_pil_list(images: torch.Tensor, max_views: int = 4) -> list[Image.Image]: """ Converts a ComfyUI IMAGE tensor (B,H,W,C) into a list of PIL images. @@ -79,40 +79,40 @@ def tensor_batch_to_pil_list(images: torch.Tensor, max_views: int = 4) -> list[I if images.ndim == 3: return [tensor2pil(images)] - raise ValueError(f"Unsupported IMAGE tensor shape: {tuple(images.shape)}") - + raise ValueError(f"Unsupported IMAGE tensor shape: {tuple(images.shape)}") + def convert_tensor_images_to_pil(images): pil_array = [] - + for image in images: pil_array.append(tensor2pil(image)) - + return pil_array - + def simplify_with_meshlib(vertices, faces, target=1000000): current_faces_num = len(faces) print(f'Current Faces Number: {current_faces_num}') - + if current_faces_num 0 rast_chunk[..., 3:4] += i # Store face ID in alpha channel rast = torch.where(mask_chunk, rast_chunk, rast) - + # Mask of valid pixels in texture mask = rast[0, ..., 3] > 0 - + # Interpolate 3D positions in UV space (finding 3D coord for every texel) pos = dr.interpolate(out_vertices.unsqueeze(0), rast, out_faces)[0][0] valid_pos = pos[mask] - + # Map these positions back to the *original* high-res mesh to get accurate attributes # This corrects geometric errors introduced by simplification/remeshing _, face_id, uvw = bvh.unsigned_distance(valid_pos, return_uvw=True) orig_tri_verts = vertices[faces[face_id.long()]] # (N_new, 3, 3) valid_pos = (orig_tri_verts * uvw.unsqueeze(-1)).sum(dim=1) - + # Trilinear sampling from the attribute volume (Color, Material props) attrs = torch.zeros(texture_size, texture_size, attr_volume.shape[1], device='cuda') attrs[mask] = grid_sample_3d( @@ -694,31 +727,31 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster grid=((valid_pos - aabb[0]) / voxel_size).reshape(1, -1, 3), mode='trilinear', ) - + # --- Texture Post-Processing & Material Construction --- print("Finalizing mesh...") - + mask = mask.cpu().numpy() - + # Extract channels based on layout (BaseColor, Metallic, Roughness, Alpha) base_color = np.clip(attrs[..., attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) metallic = np.clip(attrs[..., attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) roughness = np.clip(attrs[..., attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) alpha = np.clip(attrs[..., attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) alpha_mode = texture_alpha_mode - + # Inpainting: fill gaps (dilation) to prevent black seams at UV boundaries mask_inv = (~mask).astype(np.uint8) base_color = cv2.inpaint(base_color, mask_inv, 3, cv2.INPAINT_TELEA) metallic = cv2.inpaint(metallic, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] roughness = cv2.inpaint(roughness, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] alpha = cv2.inpaint(alpha, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] - + # Create PBR material # Standard PBR packs Metallic and Roughness into Blue and Green channels baseColorTexture_np = Image.fromarray(np.concatenate([base_color, alpha], axis=-1)) metallicRoughnessTexture_np = Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)) - + material = Trimesh.visual.material.PBRMaterial( baseColorTexture=baseColorTexture_np, baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), @@ -727,34 +760,34 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster roughnessFactor=1.0, alphaMode=alpha_mode, doubleSided=double_side_material, - ) - + ) + vertices_np = out_vertices.cpu().numpy() faces_np = out_faces.cpu().numpy() uvs_np = out_uvs.cpu().numpy() normals_np = out_normals.cpu().numpy() - + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) vertices_np[:, 1], vertices_np[:, 2] = vertices_np[:, 2], -vertices_np[:, 1] normals_np[:, 1], normals_np[:, 2] = normals_np[:, 2], -normals_np[:, 1] uvs_np[:, 1] = 1 - uvs_np[:, 1] # Flip UV V-coordinate - + textured_mesh = Trimesh.Trimesh( vertices=vertices_np, faces=faces_np, vertex_normals=normals_np, process=False, visual=Trimesh.visual.TextureVisuals(uv=uvs_np,material=material) - ) + ) del cumesh - gc.collect() + gc.collect() baseColorTexture = pil2tensor(baseColorTexture_np) metallicRoughnessTexture = pil2tensor(metallicRoughnessTexture_np) - + return (textured_mesh, baseColorTexture, metallicRoughnessTexture, ) - + class Trellis2MeshWithVoxelAdvancedGenerator: @classmethod def INPUT_TYPES(s): @@ -771,11 +804,11 @@ def INPUT_TYPES(s): "shape_steps": ("INT",{"default":12, "min":1, "max":100},), "shape_guidance_strength": ("FLOAT",{"default":7.50}), "shape_guidance_rescale": ("FLOAT",{"default":0.50}), - "shape_rescale_t": ("FLOAT",{"default":3.00}), + "shape_rescale_t": ("FLOAT",{"default":3.00}), "texture_steps": ("INT",{"default":12, "min":1, "max":100},), "texture_guidance_strength": ("FLOAT",{"default":1.00}), "texture_guidance_rescale": ("FLOAT",{"default":0.00}), - "texture_rescale_t": ("FLOAT",{"default":3.00}), + "texture_rescale_t": ("FLOAT",{"default":3.00}), "max_num_tokens": ("INT",{"default":49152,"min":0,"max":999999}), "max_views": ("INT", {"default": 4, "min": 1, "max": 16}), "sparse_structure_resolution": ("INT", {"default":32,"min":8,"max":128,"step":8}), @@ -796,18 +829,18 @@ def INPUT_TYPES(s): CATEGORY = "Trellis2Wrapper" OUTPUT_NODE = True - def process(self, pipeline, image, seed, pipeline_type, sparse_structure_steps, - sparse_structure_guidance_strength, + def process(self, pipeline, image, seed, pipeline_type, sparse_structure_steps, + sparse_structure_guidance_strength, sparse_structure_guidance_rescale, sparse_structure_rescale_t, - shape_steps, - shape_guidance_strength, + shape_steps, + shape_guidance_strength, shape_guidance_rescale, - shape_rescale_t, - texture_steps, - texture_guidance_strength, + shape_rescale_t, + texture_steps, + texture_guidance_strength, texture_guidance_rescale, - texture_rescale_t, + texture_rescale_t, max_num_tokens, max_views, sparse_structure_resolution, @@ -822,25 +855,25 @@ def process(self, pipeline, image, seed, pipeline_type, sparse_structure_steps, images = tensor_batch_to_pil_list(image, max_views=max_views) image_in = images[0] if len(images) == 1 else images - + sparse_structure_guidance_interval = [sparse_structure_guidance_interval_start,sparse_structure_guidance_interval_end] shape_guidance_interval = [shape_guidance_interval_start,shape_guidance_interval_end] texture_guidance_interval = [texture_guidance_interval_start,texture_guidance_interval_end] - - sparse_structure_sampler_params = {"steps":sparse_structure_steps,"guidance_strength":sparse_structure_guidance_strength,"guidance_rescale":sparse_structure_guidance_rescale,"guidance_interval":sparse_structure_guidance_interval,"rescale_t":sparse_structure_rescale_t} - shape_slat_sampler_params = {"steps":shape_steps,"guidance_strength":shape_guidance_strength,"guidance_rescale":shape_guidance_rescale,"guidance_interval":shape_guidance_interval,"rescale_t":shape_rescale_t} + + sparse_structure_sampler_params = {"steps":sparse_structure_steps,"guidance_strength":sparse_structure_guidance_strength,"guidance_rescale":sparse_structure_guidance_rescale,"guidance_interval":sparse_structure_guidance_interval,"rescale_t":sparse_structure_rescale_t} + shape_slat_sampler_params = {"steps":shape_steps,"guidance_strength":shape_guidance_strength,"guidance_rescale":shape_guidance_rescale,"guidance_interval":shape_guidance_interval,"rescale_t":shape_rescale_t} tex_slat_sampler_params = {"steps":texture_steps,"guidance_strength":texture_guidance_strength,"guidance_rescale":texture_guidance_rescale,"guidance_interval":texture_guidance_interval,"rescale_t":texture_rescale_t} - + if generate_texture_slat: num_steps = 5 else: num_steps = 4 pbar = ProgressBar(num_steps) - - mesh = pipeline.run(image=image_in, seed=seed, pipeline_type=pipeline_type, sparse_structure_sampler_params = sparse_structure_sampler_params, shape_slat_sampler_params = shape_slat_sampler_params, tex_slat_sampler_params = tex_slat_sampler_params, max_num_tokens = max_num_tokens, sparse_structure_resolution = sparse_structure_resolution, max_views = max_views, generate_texture_slat=generate_texture_slat, use_tiled=use_tiled_decoder, pbar=pbar)[0] - - return (mesh,) + + mesh = pipeline.run(image=image_in, seed=seed, pipeline_type=pipeline_type, sparse_structure_sampler_params = sparse_structure_sampler_params, shape_slat_sampler_params = shape_slat_sampler_params, tex_slat_sampler_params = tex_slat_sampler_params, max_num_tokens = max_num_tokens, sparse_structure_resolution = sparse_structure_resolution, max_views = max_views, generate_texture_slat=generate_texture_slat, use_tiled=use_tiled_decoder, pbar=pbar)[0] + + return (mesh,) class Trellis2PostProcessAndUnWrapAndRasterizer: @classmethod @@ -851,7 +884,7 @@ def INPUT_TYPES(s): "mesh_cluster_threshold_cone_half_angle_rad": ("FLOAT",{"default":90.0,"min":0.0,"max":359.9}), "mesh_cluster_refine_iterations": ("INT",{"default":0}), "mesh_cluster_global_iterations": ("INT",{"default":1}), - "mesh_cluster_smooth_strength": ("INT",{"default":1}), + "mesh_cluster_smooth_strength": ("INT",{"default":1}), "texture_size": ("INT",{"default":2048, "min":512, "max":16384}), "remesh": ("BOOLEAN",{"default":True}), "remesh_band": ("FLOAT",{"default":1.0}), @@ -876,14 +909,14 @@ def INPUT_TYPES(s): def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster_refine_iterations, mesh_cluster_global_iterations, mesh_cluster_smooth_strength, texture_size, remesh, remesh_band, remesh_project, target_face_num, simplify_method, fill_holes, fill_holes_max_perimeter, texture_alpha_mode, dual_contouring_resolution, double_side_material,remove_floaters): pbar = ProgressBar(5) mesh_copy = copy.deepcopy(mesh) - + aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]] - + attr_volume = mesh_copy.attrs coords = mesh_copy.coords attr_layout = mesh_copy.layout - voxel_size = mesh_copy.voxel_size - + voxel_size = mesh_copy.voxel_size + mesh_cluster_threshold_cone_half_angle_rad = np.radians(mesh_cluster_threshold_cone_half_angle_rad) # --- Input Normalization (AABB, Voxel Size, Grid Size) --- @@ -892,7 +925,7 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster if isinstance(aabb, np.ndarray): aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) - # Calculate grid dimensions based on AABB and voxel size + # Calculate grid dimensions based on AABB and voxel size if voxel_size is not None: if isinstance(voxel_size, float): voxel_size = [voxel_size, voxel_size, voxel_size] @@ -909,50 +942,50 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster if isinstance(grid_size, np.ndarray): grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) voxel_size = (aabb[1] - aabb[0]) / grid_size - + if remove_floaters: mesh_copy = remove_floater(mesh_copy) - + vertices = mesh_copy.vertices faces = mesh_copy.faces - + vertices = vertices.cuda() - faces = faces.cuda() - + faces = faces.cuda() + # Initialize CUDA mesh handler cumesh = CuMesh.CuMesh() cumesh.init(vertices, faces) print(f"Current vertices: {cumesh.num_vertices}, faces: {cumesh.num_faces}") - + # --- Initial Mesh Cleaning --- # Fills holes as much as we can before processing if fill_holes: cumesh.fill_holes(max_hole_perimeter=fill_holes_max_perimeter) - print(f"After filling holes: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") - + print(f"After filling holes: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") + # Build BVH for the current mesh to guide remeshing print(f"Building BVH for current mesh...") bvh = CuMesh.cuBVH(vertices, faces) pbar.update(1) - - print("Cleaning mesh...") + + print("Cleaning mesh...") # --- Branch 1: Standard Pipeline (Simplification & Cleaning) --- - if not remesh: + if not remesh: if simplify_method == 'Cumesh': cumesh.simplify(target_face_num * 3, verbose=True) elif simplify_method == 'Meshlib': # GPU -> CPU -> Meshlib -> CPU -> GPU v, f = cumesh.read() new_vertices, new_faces = simplify_with_meshlib(v.cpu().numpy(), f.cpu().numpy(), target_face_num) - cumesh.init(torch.from_numpy(new_vertices).float().cuda(), torch.from_numpy(new_faces).int().cuda()) - + cumesh.init(torch.from_numpy(new_vertices).float().cuda(), torch.from_numpy(new_faces).int().cuda()) + cumesh.remove_duplicate_faces() cumesh.repair_non_manifold_edges() cumesh.remove_small_connected_components(1e-5) - + if fill_holes: cumesh.fill_holes(max_hole_perimeter=fill_holes_max_perimeter) - + if simplify_method == 'Cumesh': cumesh.simplify(target_face_num, verbose=True) elif simplify_method == 'Meshlib': @@ -960,30 +993,30 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster v, f = cumesh.read() new_vertices, new_faces = simplify_with_meshlib(v.cpu().numpy(), f.cpu().numpy(), target_face_num) cumesh.init(torch.from_numpy(new_vertices).float().cuda(), torch.from_numpy(new_faces).int().cuda()) - + cumesh.remove_duplicate_faces() cumesh.repair_non_manifold_edges() - cumesh.remove_small_connected_components(1e-5) + cumesh.remove_small_connected_components(1e-5) if fill_holes: - cumesh.fill_holes(max_hole_perimeter=fill_holes_max_perimeter) - - print(f"After initial cleanup: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") - + cumesh.fill_holes(max_hole_perimeter=fill_holes_max_perimeter) + + print(f"After initial cleanup: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") + # Step 2: Unify face orientations cumesh.unify_face_orientations() - + # --- Branch 2: Remeshing Pipeline --- else: center = aabb.mean(dim=0) scale = (aabb[1] - aabb[0]).max().item() - + if dual_contouring_resolution == "Auto": resolution = grid_size.max().item() print(f"Dual Contouring resolution: {resolution}") else: resolution = int(dual_contouring_resolution) - + print('Performing Dual Contouring ...') # Perform Dual Contouring remeshing (rebuilds topology) cumesh.init(*CuMesh.remeshing.remesh_narrow_band_dc( @@ -996,11 +1029,11 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster verbose = True, bvh = bvh, )) - + print(f"After remeshing: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") - + # Step 2: Unify face orientations - #cumesh.unify_face_orientations() + #cumesh.unify_face_orientations() if simplify_method == 'Cumesh': cumesh.simplify(target_face_num, verbose=True) @@ -1010,10 +1043,10 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster new_vertices, new_faces = simplify_with_meshlib(v.cpu().numpy(), f.cpu().numpy(), target_face_num) cumesh.init(torch.from_numpy(new_vertices).float().cuda(), torch.from_numpy(new_faces).int().cuda()) - print(f"After simplifying: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") + print(f"After simplifying: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") pbar.update(1) - - print('Unwrapping ...') + + print('Unwrapping ...') out_vertices, out_faces, out_uvs, out_vmaps = cumesh.uv_unwrap( compute_charts_kwargs={ "threshold_cone_half_angle_rad": mesh_cluster_threshold_cone_half_angle_rad, @@ -1025,13 +1058,13 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster verbose=True, ) pbar.update(1) - + out_vertices = out_vertices.cuda() out_faces = out_faces.cuda() out_uvs = out_uvs.cuda() out_vmaps = out_vmaps.cuda() cumesh.compute_vertex_normals() - out_normals = cumesh.read_vertex_normals()[out_vmaps] + out_normals = cumesh.read_vertex_normals()[out_vmaps] print("Sampling attributes...") # Setup differentiable rasterizer context @@ -1039,7 +1072,7 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster # Prepare UV coordinates for rasterization (rendering in UV space) uvs_rast = torch.cat([out_uvs * 2 - 1, torch.zeros_like(out_uvs[:, :1]), torch.ones_like(out_uvs[:, :1])], dim=-1).unsqueeze(0) rast = torch.zeros((1, texture_size, texture_size, 4), device='cuda', dtype=torch.float32) - + # Rasterize in chunks to save memory for i in range(0, out_faces.shape[0], 100000): rast_chunk, _ = dr.rasterize( @@ -1049,20 +1082,20 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster mask_chunk = rast_chunk[..., 3:4] > 0 rast_chunk[..., 3:4] += i # Store face ID in alpha channel rast = torch.where(mask_chunk, rast_chunk, rast) - + # Mask of valid pixels in texture mask = rast[0, ..., 3] > 0 - + # Interpolate 3D positions in UV space (finding 3D coord for every texel) pos = dr.interpolate(out_vertices.unsqueeze(0), rast, out_faces)[0][0] valid_pos = pos[mask] - + # Map these positions back to the *original* high-res mesh to get accurate attributes # This corrects geometric errors introduced by simplification/remeshing _, face_id, uvw = bvh.unsigned_distance(valid_pos, return_uvw=True) orig_tri_verts = vertices[faces[face_id.long()]] # (N_new, 3, 3) valid_pos = (orig_tri_verts * uvw.unsqueeze(-1)).sum(dim=1) - + # Trilinear sampling from the attribute volume (Color, Material props) attrs = torch.zeros(texture_size, texture_size, attr_volume.shape[1], device='cuda') attrs[mask] = grid_sample_3d( @@ -1072,27 +1105,27 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster grid=((valid_pos - aabb[0]) / voxel_size).reshape(1, -1, 3), mode='trilinear', ) - + # --- Texture Post-Processing & Material Construction --- print("Finalizing mesh...") pbar.update(1) - + mask = mask.cpu().numpy() - + # Extract channels based on layout (BaseColor, Metallic, Roughness, Alpha) base_color = np.clip(attrs[..., attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) metallic = np.clip(attrs[..., attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) roughness = np.clip(attrs[..., attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) alpha = np.clip(attrs[..., attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) alpha_mode = texture_alpha_mode - + # Inpainting: fill gaps (dilation) to prevent black seams at UV boundaries mask_inv = (~mask).astype(np.uint8) base_color = cv2.inpaint(base_color, mask_inv, 3, cv2.INPAINT_TELEA) metallic = cv2.inpaint(metallic, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] roughness = cv2.inpaint(roughness, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] alpha = cv2.inpaint(alpha, mask_inv, 1, cv2.INPAINT_TELEA)[..., None] - + # Create PBR material # Standard PBR packs Metallic and Roughness into Blue and Green channels baseColorTexture_np = Image.fromarray(np.concatenate([base_color, alpha], axis=-1)) @@ -1105,18 +1138,18 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster roughnessFactor=1.0, alphaMode=alpha_mode, doubleSided=double_side_material, - ) - + ) + vertices_np = out_vertices.cpu().numpy() faces_np = out_faces.cpu().numpy() uvs_np = out_uvs.cpu().numpy() normals_np = out_normals.cpu().numpy() - + # Swap Y and Z axes, invert Y (common conversion for GLB compatibility) vertices_np[:, 1], vertices_np[:, 2] = vertices_np[:, 2], -vertices_np[:, 1] normals_np[:, 1], normals_np[:, 2] = normals_np[:, 2], -normals_np[:, 1] uvs_np[:, 1] = 1 - uvs_np[:, 1] # Flip UV V-coordinate - + textured_mesh = Trimesh.Trimesh( vertices=vertices_np, faces=faces_np, @@ -1124,15 +1157,15 @@ def process(self, mesh, mesh_cluster_threshold_cone_half_angle_rad, mesh_cluster process=False, visual=Trimesh.visual.TextureVisuals(uv=uvs_np,material=material) ) - pbar.update(1) - + pbar.update(1) + del cumesh - gc.collect() + gc.collect() baseColorTexture = pil2tensor(baseColorTexture_np) metallicRoughnessTexture = pil2tensor(metallicRoughnessTexture_np) - - return (textured_mesh, baseColorTexture, metallicRoughnessTexture, ) + + return (textured_mesh, baseColorTexture, metallicRoughnessTexture, ) class Trellis2Remesh: @classmethod @@ -1145,7 +1178,7 @@ def INPUT_TYPES(s): "fill_holes": ("BOOLEAN", {"default":True}), "fill_holes_max_perimeter": ("FLOAT",{"default":0.03,"min":0.001,"max":99.999,"step":0.001}), "dual_contouring_resolution": (["Auto","128","256","512","1024","2048"],{"default":"Auto"}), - "remove_floaters": ("BOOLEAN",{"default":True}), + "remove_floaters": ("BOOLEAN",{"default":True}), }, } @@ -1157,26 +1190,26 @@ def INPUT_TYPES(s): def process(self, mesh, remesh_band, remesh_project, fill_holes, fill_holes_max_perimeter, dual_contouring_resolution, remove_floaters): mesh_copy = copy.deepcopy(mesh) - + if remove_floaters: mesh_copy = remove_floater(mesh_copy) - + aabb = [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]] - + vertices = mesh_copy.vertices faces = mesh_copy.faces attr_volume = mesh_copy.attrs coords = mesh_copy.coords attr_layout = mesh_copy.layout - voxel_size = mesh_copy.voxel_size - + voxel_size = mesh_copy.voxel_size + # --- Input Normalization (AABB, Voxel Size, Grid Size) --- if isinstance(aabb, (list, tuple)): aabb = np.array(aabb) if isinstance(aabb, np.ndarray): aabb = torch.tensor(aabb, dtype=torch.float32, device='cuda') - # Calculate grid dimensions based on AABB and voxel size + # Calculate grid dimensions based on AABB and voxel size if voxel_size is not None: if isinstance(voxel_size, float): voxel_size = [voxel_size, voxel_size, voxel_size] @@ -1197,34 +1230,34 @@ def process(self, mesh, remesh_band, remesh_project, fill_holes, fill_holes_max_ # Move data to GPU vertices = vertices.cuda() faces = faces.cuda() - + # Initialize CUDA mesh handler cumesh = CuMesh.CuMesh() cumesh.init(vertices, faces) print(f"Current vertices: {cumesh.num_vertices}, faces: {cumesh.num_faces}") - + # --- Initial Mesh Cleaning --- # Fills holes as much as we can before processing if fill_holes: cumesh.fill_holes(max_hole_perimeter=fill_holes_max_perimeter) print(f"After filling holes: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") - + vertices, faces = cumesh.read() - + # Build BVH for the current mesh to guide remeshing print(f"Building BVH for current mesh...") bvh = CuMesh.cuBVH(vertices, faces) - - print("Cleaning mesh...") + + print("Cleaning mesh...") center = aabb.mean(dim=0) scale = (aabb[1] - aabb[0]).max().item() - + if dual_contouring_resolution == "Auto": resolution = grid_size.max().item() print(f"Dual Contouring resolution: {resolution}") else: resolution = int(dual_contouring_resolution) - + print('Performing Dual Contouring ...') # Perform Dual Contouring remeshing (rebuilds topology) cumesh.init(*CuMesh.remeshing.remesh_narrow_band_dc( @@ -1237,22 +1270,22 @@ def process(self, mesh, remesh_band, remesh_project, fill_holes, fill_holes_max_ verbose = True, bvh = bvh, )) - - print(f"After remeshing: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") - + + print(f"After remeshing: {cumesh.num_vertices} vertices, {cumesh.num_faces} faces") + # Step 2: Unify face orientations - #cumesh.unify_face_orientations() - + #cumesh.unify_face_orientations() + new_vertices, new_faces = cumesh.read() - + mesh_copy.vertices = new_vertices.to(mesh_copy.device) - mesh_copy.faces = new_faces.to(mesh_copy.device) - + mesh_copy.faces = new_faces.to(mesh_copy.device) + del cumesh - gc.collect() - + gc.collect() + return (mesh_copy,) - + class Trellis2MeshTexturing: @classmethod def INPUT_TYPES(s): @@ -1269,9 +1302,9 @@ def INPUT_TYPES(s): "resolution": ([512,1024],{"default":1024}), "texture_size": ("INT",{"default":2048,"min":512,"max":16384}), "texture_alpha_mode": (["OPAQUE","MASK","BLEND"],{"default":"OPAQUE"}), - "double_side_material": ("BOOLEAN",{"default":True}), + "double_side_material": ("BOOLEAN",{"default":True}), "texture_guidance_interval_start": ("FLOAT",{"default":0.60,"min":0.00,"max":1.00,"step":0.01}), - "texture_guidance_interval_end": ("FLOAT",{"default":0.90,"min":0.00,"max":1.00,"step":0.01}), + "texture_guidance_interval_end": ("FLOAT",{"default":0.90,"min":0.00,"max":1.00,"step":0.01}), }, } @@ -1284,39 +1317,39 @@ def INPUT_TYPES(s): def process(self, pipeline, image, trimesh, seed, texture_steps, texture_guidance_strength, texture_guidance_rescale, texture_rescale_t, resolution, texture_size, texture_alpha_mode, double_side_material, texture_guidance_interval_start, texture_guidance_interval_end): #image = tensor2pil_v2(image) image = tensor2pil(image) - - texture_guidance_interval = [texture_guidance_interval_start,texture_guidance_interval_end] - + + texture_guidance_interval = [texture_guidance_interval_start,texture_guidance_interval_end] + tex_slat_sampler_params = {"steps":texture_steps,"guidance_strength":texture_guidance_strength,"guidance_rescale":texture_guidance_rescale,"guidance_interval":texture_guidance_interval,"rescale_t":texture_rescale_t} - textured_mesh, baseColorTexture_np, metallicRoughnessTexture_np = pipeline.texture_mesh(mesh=trimesh, - image=image, - seed=seed, + textured_mesh, baseColorTexture_np, metallicRoughnessTexture_np = pipeline.texture_mesh(mesh=trimesh, + image=image, + seed=seed, tex_slat_sampler_params = tex_slat_sampler_params, resolution = resolution, texture_size = texture_size, texture_alpha_mode = texture_alpha_mode, double_side_material = double_side_material ) - + baseColorTexture = pil2tensor(baseColorTexture_np) metallicRoughnessTexture = pil2tensor(metallicRoughnessTexture_np) - + return (textured_mesh, baseColorTexture, metallicRoughnessTexture, ) - + class Trellis2LoadMesh: @classmethod def INPUT_TYPES(s): return { "required": { - "glb_path": ("STRING", {"default": "", "tooltip": "The glb path with mesh to load."}), + "glb_path": ("STRING", {"default": "", "tooltip": "The glb path with mesh to load."}), } } RETURN_TYPES = ("TRIMESH",) RETURN_NAMES = ("trimesh",) OUTPUT_TOOLTIPS = ("The glb model with mesh to texturize.",) - + FUNCTION = "load" CATEGORY = "Trellis2Wrapper" DESCRIPTION = "Loads a glb model from the given path." @@ -1324,22 +1357,22 @@ def INPUT_TYPES(s): def load(self, glb_path): if not os.path.exists(glb_path): glb_path = os.path.join(folder_paths.get_input_directory(), glb_path) - + trimesh = Trimesh.load(glb_path, force="mesh") - - return (trimesh,) - + + return (trimesh,) + class Trellis2PreProcessImage: @classmethod def INPUT_TYPES(s): return { "required": { - "image": ("IMAGE",), + "image": ("IMAGE",), } } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) - + FUNCTION = "process" CATEGORY = "Trellis2Wrapper" @@ -1347,8 +1380,8 @@ def process(self, image): image = tensor2pil(image) image = self.preprocess_image(image) image = pil2tensor(image) - - return (image,) + + return (image,) def preprocess_image(self, input: Image.Image) -> Image.Image: @@ -1387,7 +1420,7 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: output = np.array(output).astype(np.float32) / 255 output = output[:, :, :3] * output[:, :, 3:4] output = Image.fromarray((output * 255).astype(np.uint8)) - return output + return output class Trellis2MeshRefiner: @classmethod @@ -1402,11 +1435,11 @@ def INPUT_TYPES(s): "shape_steps": ("INT",{"default":12, "min":1, "max":100},), "shape_guidance_strength": ("FLOAT",{"default":7.50}), "shape_guidance_rescale": ("FLOAT",{"default":0.50}), - "shape_rescale_t": ("FLOAT",{"default":3.00}), + "shape_rescale_t": ("FLOAT",{"default":3.00}), "texture_steps": ("INT",{"default":12, "min":1, "max":100},), "texture_guidance_strength": ("FLOAT",{"default":1.00}), "texture_guidance_rescale": ("FLOAT",{"default":0.00}), - "texture_rescale_t": ("FLOAT",{"default":3.00}), + "texture_rescale_t": ("FLOAT",{"default":3.00}), "max_num_tokens": ("INT",{"default":49152,"min":0,"max":999999}), "generate_texture_slat": ("BOOLEAN", {"default":True}), "downsampling":([16,32,64],{"default":16}), @@ -1425,14 +1458,14 @@ def INPUT_TYPES(s): OUTPUT_NODE = True def process(self, pipeline, trimesh, image, seed, resolution, - shape_steps, - shape_guidance_strength, + shape_steps, + shape_guidance_strength, shape_guidance_rescale, - shape_rescale_t, - texture_steps, - texture_guidance_strength, + shape_rescale_t, + texture_steps, + texture_guidance_strength, texture_guidance_rescale, - texture_rescale_t, + texture_rescale_t, max_num_tokens, generate_texture_slat, downsampling, @@ -1443,15 +1476,15 @@ def process(self, pipeline, trimesh, image, seed, resolution, use_tiled_decoder): image = tensor2pil(image) - + shape_guidance_interval = [shape_guidance_interval_start,shape_guidance_interval_end] - texture_guidance_interval = [texture_guidance_interval_start,texture_guidance_interval_end] - - shape_slat_sampler_params = {"steps":shape_steps,"guidance_strength":shape_guidance_strength,"guidance_rescale":shape_guidance_rescale,"guidance_interval":shape_guidance_interval,"rescale_t":shape_rescale_t} + texture_guidance_interval = [texture_guidance_interval_start,texture_guidance_interval_end] + + shape_slat_sampler_params = {"steps":shape_steps,"guidance_strength":shape_guidance_strength,"guidance_rescale":shape_guidance_rescale,"guidance_interval":shape_guidance_interval,"rescale_t":shape_rescale_t} tex_slat_sampler_params = {"steps":texture_steps,"guidance_strength":texture_guidance_strength,"guidance_rescale":texture_guidance_rescale,"guidance_interval":texture_guidance_interval,"rescale_t":texture_rescale_t} - - mesh = pipeline.refine_mesh(mesh = trimesh, image=image, seed=seed, shape_slat_sampler_params = shape_slat_sampler_params, tex_slat_sampler_params = tex_slat_sampler_params, resolution = resolution, max_num_tokens = max_num_tokens, generate_texture_slat=generate_texture_slat, downsampling=downsampling, use_tiled=use_tiled_decoder)[0] - + + mesh = pipeline.refine_mesh(mesh = trimesh, image=image, seed=seed, shape_slat_sampler_params = shape_slat_sampler_params, tex_slat_sampler_params = tex_slat_sampler_params, resolution = resolution, max_num_tokens = max_num_tokens, generate_texture_slat=generate_texture_slat, downsampling=downsampling, use_tiled=use_tiled_decoder)[0] + return (mesh,) class Trellis2PostProcess2: @@ -1476,19 +1509,19 @@ def INPUT_TYPES(s): def process(self, mesh, fill_holes, fix_normals, fix_face_orientation, remove_duplicate_faces, remove_infinite_values): mesh_copy = copy.deepcopy(mesh) - + vertices_np = mesh_copy.vertices.cpu().numpy() faces_np = mesh_copy.faces.cpu().numpy() - + trimesh = Trimesh.Trimesh(vertices=vertices_np,faces=faces_np) - + print(f"Initial mesh: {len(trimesh.faces)} faces") - print(f"Is winding consistent? {trimesh.is_winding_consistent}") - + print(f"Is winding consistent? {trimesh.is_winding_consistent}") + if fix_normals: print('Fixing normals ...') - trimesh.fix_normals() - + trimesh.fix_normals() + if fix_face_orientation: if trimesh.is_watertight: print('Mesh is watertight, fixing inversion ...') @@ -1499,25 +1532,25 @@ def process(self, mesh, fill_holes, fix_normals, fix_face_orientation, remove_du if remove_duplicate_faces: print('Removing duplicate faces ...') trimesh.remove_duplicate_faces() - + if remove_infinite_values: print('Removing infinite values ...') trimesh.remove_infinite_values() - + if fill_holes: print('Filling holes ...') - trimesh.fill_holes() - + trimesh.fill_holes() + new_vertices = torch.from_numpy(trimesh.vertices).float() - new_faces = torch.from_numpy(trimesh.faces).int() - + new_faces = torch.from_numpy(trimesh.faces).int() + mesh_copy.vertices = new_vertices.to(mesh_copy.device) - mesh_copy.faces = new_faces.to(mesh_copy.device) - + mesh_copy.faces = new_faces.to(mesh_copy.device) + del trimesh gc.collect() - - return (mesh_copy,) + + return (mesh_copy,) class Trellis2OvoxelExportToGLB: @classmethod @@ -1555,7 +1588,7 @@ def process(self, mesh, resolution, texture_size, target_face_num): remesh_project=0, use_tqdm=True, ) - + return (glb,) class Trellis2TrimeshToMeshWithVoxel: @@ -1574,17 +1607,17 @@ def INPUT_TYPES(s): CATEGORY = "Trellis2Wrapper" OUTPUT_NODE = True - def process(self, trimesh, resolution): + def process(self, trimesh, resolution): mesh_copy = trimesh.copy() - + mvoxel = self.get_voxelmesh_from_trimesh(mesh_copy, resolution) - - return (mvoxel,) - + + return (mvoxel,) + def get_voxelmesh_from_trimesh(self, mesh, resolution): vertices = torch.from_numpy(mesh.vertices).float() faces = torch.from_numpy(mesh.faces).long() - + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( vertices.cpu(), faces.cpu(), grid_size=resolution, @@ -1594,15 +1627,15 @@ def get_voxelmesh_from_trimesh(self, mesh, resolution): regularization_weight=1e-2, timing=True, ) - - coords = torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) + + coords = torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) coords = coords.cpu() del voxel_indices del dual_vertices del intersected gc.collect() - + pbr_attr_layout = { 'base_color': slice(0, 3), 'metallic': slice(3, 4), @@ -1619,7 +1652,7 @@ def get_voxelmesh_from_trimesh(self, mesh, resolution): voxel_shape = None, layout=pbr_attr_layout ) - + return mvoxel NODE_CLASS_MAPPINGS = { diff --git a/trellis2/pipelines/trellis2_image_to_3d.py b/trellis2/pipelines/trellis2_image_to_3d.py index 5414241..3949fd3 100644 --- a/trellis2/pipelines/trellis2_image_to_3d.py +++ b/trellis2/pipelines/trellis2_image_to_3d.py @@ -142,10 +142,10 @@ def from_pretrained(cls, path: str, config_file: str = "pipeline.json", keep_mod #pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args']) #pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args']) - + pipeline.image_cond_model = None pipeline.rembg_model = None - + pipeline.low_vram = args.get('low_vram', True) pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade') pipeline.pbr_attr_layout = { @@ -158,64 +158,88 @@ def from_pretrained(cls, path: str, config_file: str = "pipeline.json", keep_mod pipeline.path = path pipeline.keep_models_loaded = keep_models_loaded pipeline.last_processing = '' - - pipeline._pretrained_args['models']['sparse_structure_decoder'] = os.path.join(folder_paths.models_dir,"microsoft","TRELLIS-image-large","ckpts","ss_dec_conv3d_16l8_fp16") - facebook_model_path = os.path.join(folder_paths.models_dir,"facebook","dinov3-vitl16-pretrain-lvd1689m") - pipeline._pretrained_args['image_cond_model']['args']['model_name'] = facebook_model_path + + # Check for trellis-image-large in registered paths + trellis_image_large_path = None + microsoft_paths = folder_paths.get_folder_paths("microsoft") + for base in microsoft_paths: + candidate = os.path.join(base, "TRELLIS-image-large", "ckpts", "ss_dec_conv3d_16l8_fp16") + if os.path.exists(candidate + ".safetensors"): + trellis_image_large_path = candidate + break + + if trellis_image_large_path is None: + trellis_image_large_path = os.path.join(folder_paths.models_dir, "microsoft", "TRELLIS-image-large", "ckpts", "ss_dec_conv3d_16l8_fp16") + + pipeline._pretrained_args['models']['sparse_structure_decoder'] = trellis_image_large_path + + # Check for dinov3 in registered paths + facebook_model_path = None + facebook_paths = folder_paths.get_folder_paths("facebook") + for base in facebook_paths: + candidate = os.path.join(base, "dinov3-vitl16-pretrain-lvd1689m") + if os.path.exists(os.path.join(candidate, "model.safetensors")): + facebook_model_path = candidate + break + + if facebook_model_path is None: + facebook_model_path = os.path.join(folder_paths.models_dir, "facebook", "dinov3-vitl16-pretrain-lvd1689m") + + pipeline._pretrained_args['image_cond_model']['args']['model_name'] = facebook_model_path return pipeline - - def load_sparse_structure_model(self): + + def load_sparse_structure_model(self): if self.models['sparse_structure_flow_model'] is None: print('Loading Sparse Structure model ...') self.models['sparse_structure_flow_model'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['sparse_structure_flow_model']}") self.models['sparse_structure_flow_model'].eval() self.models['sparse_structure_flow_model'].to(self._device) - - if self.models['sparse_structure_decoder'] is None: + + if self.models['sparse_structure_decoder'] is None: self.models['sparse_structure_decoder'] = models.from_pretrained(self._pretrained_args['models']['sparse_structure_decoder']) - self.models['sparse_structure_decoder'].eval() + self.models['sparse_structure_decoder'].eval() self.models['sparse_structure_decoder'].to(self._device) if hasattr(self.models['sparse_structure_decoder'], 'low_vram'): self.models['sparse_structure_decoder'].low_vram = self.low_vram - + def unload_sparse_structure_model(self): if self.models['sparse_structure_flow_model']: del self.models['sparse_structure_flow_model'] self.models['sparse_structure_flow_model'] = None gc.collect() - + if self.models['sparse_structure_decoder']: del self.models['sparse_structure_decoder'] self.models['sparse_structure_decoder'] = None - gc.collect() - + gc.collect() + def load_image_cond_model(self): if self.image_cond_model is None: print('Loading Image Cond model ...') self.image_cond_model = getattr(image_feature_extractor, self._pretrained_args['image_cond_model']['name'])(**self._pretrained_args['image_cond_model']['args']) self.image_cond_model.to(self._device) - + def unload_image_cond_model(self): if self.image_cond_model is not None: del self.image_cond_model self.image_cond_model = None gc.collect() - - def load_shape_slat_flow_model_512(self): + + def load_shape_slat_flow_model_512(self): if self.models['shape_slat_flow_model_512'] is None: print('Loading Shape Slat Flow 512 model ...') self.models['shape_slat_flow_model_512'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['shape_slat_flow_model_512']}") self.models['shape_slat_flow_model_512'].eval() self.models['shape_slat_flow_model_512'].to(self._device) - + def unload_shape_slat_flow_model_512(self): if self.models['shape_slat_flow_model_512'] is not None: del self.models['shape_slat_flow_model_512'] self.models['shape_slat_flow_model_512'] = None gc.collect() - - def load_tex_slat_flow_model_512(self): + + def load_tex_slat_flow_model_512(self): if self.models['tex_slat_flow_model_512'] is None: print('Loading Texture Slat Flow 512 model ...') self.models['tex_slat_flow_model_512'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['tex_slat_flow_model_512']}") @@ -226,9 +250,9 @@ def unload_tex_slat_flow_model_512(self): if self.models['tex_slat_flow_model_512'] is not None: del self.models['tex_slat_flow_model_512'] self.models['tex_slat_flow_model_512'] = None - gc.collect() + gc.collect() - def load_tex_slat_decoder(self): + def load_tex_slat_decoder(self): if self.models['tex_slat_decoder'] is None: print('Loading Texture Slat decoder model ...') self.models['tex_slat_decoder'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['tex_slat_decoder']}") @@ -242,8 +266,8 @@ def unload_tex_slat_decoder(self): del self.models['tex_slat_decoder'] self.models['tex_slat_decoder'] = None gc.collect() - - def load_shape_slat_decoder(self): + + def load_shape_slat_decoder(self): if self.models['shape_slat_decoder'] is None: print('Loading Shape Slat decoder model ...') self.models['shape_slat_decoder'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['shape_slat_decoder']}") @@ -256,9 +280,9 @@ def unload_shape_slat_decoder(self): if self.models['shape_slat_decoder'] is not None: del self.models['shape_slat_decoder'] self.models['shape_slat_decoder'] = None - gc.collect() + gc.collect() - def load_shape_slat_flow_model_1024(self): + def load_shape_slat_flow_model_1024(self): if self.models['shape_slat_flow_model_1024'] is None: print('Loading Shape Slat Flow 1024 model ...') self.models['shape_slat_flow_model_1024'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['shape_slat_flow_model_1024']}") @@ -269,9 +293,9 @@ def unload_shape_slat_flow_model_1024(self): if self.models['shape_slat_flow_model_1024'] is not None: del self.models['shape_slat_flow_model_1024'] self.models['shape_slat_flow_model_1024'] = None - gc.collect() + gc.collect() - def load_tex_slat_flow_model_1024(self): + def load_tex_slat_flow_model_1024(self): if self.models['tex_slat_flow_model_1024'] is None: print('Loading Texture Slat Flow 1024 model ...') self.models['tex_slat_flow_model_1024'] = models.from_pretrained(f"{self.path}/{self._pretrained_args['models']['tex_slat_flow_model_1024']}") @@ -282,9 +306,9 @@ def unload_tex_slat_flow_model_1024(self): if self.models['tex_slat_flow_model_1024'] is not None: del self.models['tex_slat_flow_model_1024'] self.models['tex_slat_flow_model_1024'] = None - gc.collect() + gc.collect() - def load_shape_slat_encoder(self): + def load_shape_slat_encoder(self): if self.models['shape_slat_encoder'] is None: print('Loading Shape Slat Encoder model ...') self.models['shape_slat_encoder'] = models.from_pretrained(f"{self.path}/ckpts/shape_enc_next_dc_f16c32_fp16") @@ -297,7 +321,7 @@ def unload_shape_slat_encoder(self): if self.models['shape_slat_encoder'] is not None: del self.models['shape_slat_encoder'] self.models['shape_slat_encoder'] = None - gc.collect() + gc.collect() def to(self, device: torch.device) -> None: self._device = device @@ -344,7 +368,7 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: output = output[:, :, :3] * output[:, :, 3:4] output = Image.fromarray((output * 255).astype(np.uint8)) return output - + def get_cond( self, image: Union[torch.Tensor, Image.Image, List[Image.Image]], @@ -378,7 +402,7 @@ def get_cond( # Expect ComfyUI IMAGE tensor: (B,H,W,C) float in [0,1] if image.ndim == 4: # Lazy import to avoid circulars if tensor2pil is in nodes/utils - from .nodes import tensor2pil + from .nodes import tensor2pil images = [tensor2pil(image[i]) for i in range(min(int(image.shape[0]), max_views))] else: raise ValueError(f"Expected image tensor with shape (B,H,W,C), got {tuple(image.shape)}") @@ -440,7 +464,7 @@ def sample_sparse_structure( ) -> torch.Tensor: """ Sample sparse structures with the given conditioning. - + Args: cond (dict): The conditioning information. resolution (int): The resolution of the sparse structure. @@ -448,7 +472,7 @@ def sample_sparse_structure( sampler_params (dict): Additional parameters for the sampler. """ if self.low_vram: - cond = self._cond_to(cond, self.device) + cond = self._cond_to(cond, self.device) # Sample sparse structure latent flow_model = self.models['sparse_structure_flow_model'] reso = flow_model.resolution @@ -467,7 +491,7 @@ def sample_sparse_structure( ).samples if self.low_vram: flow_model.cpu() - self._cleanup_cuda() + self._cleanup_cuda() # Decode sparse structure latent decoder = self.models['sparse_structure_decoder'] if self.low_vram: @@ -497,7 +521,7 @@ def sample_shape_slat( ) -> SparseTensor: """ Sample structured latent with the given conditioning. - + Args: cond (dict): The conditioning information. coords (torch.Tensor): The coordinates of the sparse structure. @@ -506,7 +530,7 @@ def sample_shape_slat( if self.low_vram: cond = self._cond_to(cond, self.device) - coords_dev = coords.to(self.device) + coords_dev = coords.to(self.device) # Sample structured latent noise = SparseTensor( feats=torch.randn(coords.shape[0], flow_model.in_channels, device=self.device), @@ -525,19 +549,19 @@ def sample_shape_slat( ).samples if self.low_vram: flow_model.cpu() - self._cleanup_cuda() + self._cleanup_cuda() std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) slat = slat * std + mean - + del coords_dev if self.low_vram: cond = self._cond_cpu(cond) self._cleanup_cuda() return slat - + def sample_shape_slat_cascade( self, lr_cond: dict, @@ -552,7 +576,7 @@ def sample_shape_slat_cascade( ) -> SparseTensor: """ Sample structured latent with the given conditioning. - + Args: cond (dict): The conditioning information. coords (torch.Tensor): The coordinates of the sparse structure. @@ -564,7 +588,7 @@ def sample_shape_slat_cascade( lr_cond = self._cond_to(lr_cond, self.device) cond = self._cond_to(cond, self.device) - coords_dev = coords.to(self.device) + coords_dev = coords.to(self.device) # Sample structured latent noise = SparseTensor( feats=torch.randn(coords.shape[0], flow_model.in_channels, device=self.device), @@ -583,17 +607,17 @@ def sample_shape_slat_cascade( ).samples if self.low_vram: flow_model_lr.cpu() - self._cleanup_cuda() + self._cleanup_cuda() std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) slat = slat * std + mean - + del coords_dev if self.low_vram: lr_cond = self._cond_cpu(lr_cond) self._cleanup_cuda() - # Upsample + # Upsample self.load_shape_slat_decoder() if self.low_vram: self.models['shape_slat_decoder'].to(self.device) @@ -603,10 +627,10 @@ def sample_shape_slat_cascade( self.models['shape_slat_decoder'].cpu() self.models['shape_slat_decoder'].low_vram = False hr_resolution = resolution - + if not self.keep_models_loaded: self.unload_shape_slat_decoder() - + while True: quant_coords = torch.cat([ hr_coords[:, :1], @@ -625,8 +649,8 @@ def sample_shape_slat_cascade( if hr_resolution < 512: hr_resolution = 512 break - - coords_dev = coords.to(self.device) + + coords_dev = coords.to(self.device) # Sample structured latent noise = SparseTensor( feats=torch.randn(coords.shape[0], flow_model.in_channels, device=self.device), @@ -645,12 +669,12 @@ def sample_shape_slat_cascade( ).samples if self.low_vram: flow_model.cpu() - self._cleanup_cuda() + self._cleanup_cuda() std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) slat = slat * std + mean - + del coords_dev if self.low_vram: cond = self._cond_cpu(cond) @@ -674,9 +698,9 @@ def decode_shape_slat( List[Mesh]: The decoded meshes. List[SparseTensor]: The decoded substructures. """ - + self.load_shape_slat_decoder() - + self.models['shape_slat_decoder'].set_resolution(resolution) if self.low_vram: self.models['shape_slat_decoder'].to(self.device) @@ -685,13 +709,13 @@ def decode_shape_slat( if self.low_vram: self.models['shape_slat_decoder'].cpu() self.models['shape_slat_decoder'].low_vram = False - torch.cuda.empty_cache() - - if not self.keep_models_loaded: + torch.cuda.empty_cache() + + if not self.keep_models_loaded: self.unload_shape_slat_decoder() - + return ret - + def sample_tex_slat( self, cond: dict, @@ -701,14 +725,14 @@ def sample_tex_slat( ) -> SparseTensor: """ Sample structured latent with the given conditioning. - + Args: cond (dict): The conditioning information. shape_slat (SparseTensor): The structured latent for shape sampler_params (dict): Additional parameters for the sampler. """ if self.low_vram: - cond = self._cond_to(cond, self.device) + cond = self._cond_to(cond, self.device) # Sample structured latent std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device) mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device) @@ -730,15 +754,15 @@ def sample_tex_slat( ).samples if self.low_vram: flow_model.cpu() - self._cleanup_cuda() + self._cleanup_cuda() std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device) mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device) slat = slat * std + mean - + if self.low_vram: cond = self._cond_cpu(cond) - self._cleanup_cuda() + self._cleanup_cuda() return slat def decode_tex_slat( @@ -755,28 +779,28 @@ def decode_tex_slat( Returns: SparseTensor: The decoded texture voxels """ - + self.load_tex_slat_decoder() - + if self.low_vram: self.models['tex_slat_decoder'].to(self.device) - self.models['tex_slat_decoder'].low_vram = True - + self.models['tex_slat_decoder'].low_vram = True + if subs is None: ret = self.models['tex_slat_decoder'](slat) * 0.5 + 0.5 else: ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5 - + if self.low_vram: self.models['tex_slat_decoder'].cpu() self.models['tex_slat_decoder'].low_vram = False - torch.cuda.empty_cache() - + torch.cuda.empty_cache() + if not self.keep_models_loaded: self.unload_tex_slat_decoder() - + return ret - + @torch.no_grad() def decode_latent( self, @@ -795,11 +819,11 @@ def decode_latent( """ meshes, subs = self.decode_shape_slat(shape_slat, resolution, use_tiled=use_tiled) if self.low_vram: - self._cleanup_cuda() - + self._cleanup_cuda() + if tex_slat is None: if self.low_vram: - self._cleanup_cuda() + self._cleanup_cuda() out_mesh = [] for m in meshes: m.fill_holes() @@ -815,11 +839,11 @@ def decode_latent( ) ) return out_mesh - - else: + + else: tex_voxels = self.decode_tex_slat(tex_slat, subs) if self.low_vram: - self._cleanup_cuda() + self._cleanup_cuda() out_mesh = [] for m, v in zip(meshes, tex_voxels): m.fill_holes() @@ -835,7 +859,7 @@ def decode_latent( ) ) return out_mesh - + @torch.no_grad() def run( self, @@ -888,7 +912,7 @@ def run( # assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found." # else: # raise ValueError(f"Invalid pipeline type: {pipeline_type}") - + # Accept either a single PIL image or a list of PIL images (multi-view) if isinstance(image, (list, tuple)): images = list(image) @@ -897,51 +921,51 @@ def run( if preprocess_image: images = [self.preprocess_image(im) for im in images] - + torch.manual_seed(seed) - + # Get Image Cond - self.load_image_cond_model() - # Multi-view conditioning happens inside get_cond() - cond_512 = self.get_cond(images, 512, max_views = max_views) + self.load_image_cond_model() + # Multi-view conditioning happens inside get_cond() + cond_512 = self.get_cond(images, 512, max_views = max_views) cond_1024 = self.get_cond(images, 1024, max_views = max_views) if pipeline_type != '512' else None - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_image_cond_model() - + #ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type] - + # Sampling Sparse Structure - self.load_sparse_structure_model() + self.load_sparse_structure_model() coords = self.sample_sparse_structure( cond_512, sparse_structure_resolution, num_samples, sparse_structure_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_sparse_structure_model() - + # Sampling Shape - if pipeline_type == '512': + if pipeline_type == '512': self.unload_shape_slat_flow_model_1024() - self.load_shape_slat_flow_model_512() + self.load_shape_slat_flow_model_512() shape_slat = self.sample_shape_slat( cond_512, self.models['shape_slat_flow_model_512'], coords, shape_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() - + if generate_texture_slat: self.unload_tex_slat_flow_model_1024() self.load_tex_slat_flow_model_512() @@ -949,13 +973,13 @@ def run( cond_512, self.models['tex_slat_flow_model_512'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_512() - + res = 512 elif pipeline_type == '1024': self.unload_shape_slat_flow_model_512() @@ -964,13 +988,13 @@ def run( cond_1024, self.models['shape_slat_flow_model_1024'], coords, shape_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -978,17 +1002,17 @@ def run( cond_1024, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() - + res = 1024 elif pipeline_type == '1024_cascade': self.load_shape_slat_flow_model_512() - self.load_shape_slat_flow_model_1024() + self.load_shape_slat_flow_model_1024() shape_slat, res = self.sample_shape_slat_cascade( cond_512, cond_1024, self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], @@ -996,14 +1020,14 @@ def run( coords, shape_slat_sampler_params, max_num_tokens ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -1011,10 +1035,10 @@ def run( cond_1024, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() elif pipeline_type == '2048_cascade': @@ -1027,14 +1051,14 @@ def run( coords, shape_slat_sampler_params, max_num_tokens ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -1042,10 +1066,10 @@ def run( cond_1024, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() elif pipeline_type == '4096_cascade': @@ -1058,14 +1082,14 @@ def run( coords, shape_slat_sampler_params, max_num_tokens ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -1073,15 +1097,15 @@ def run( cond_1024, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: - pbar.update(1) - + pbar.update(1) + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() elif pipeline_type == '1536_cascade': self.load_shape_slat_flow_model_512() - self.load_shape_slat_flow_model_1024() + self.load_shape_slat_flow_model_1024() shape_slat, res = self.sample_shape_slat_cascade( cond_512, cond_1024, self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'], @@ -1089,35 +1113,35 @@ def run( coords, shape_slat_sampler_params, max_num_tokens ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() self.unload_shape_slat_flow_model_1024() - - if generate_texture_slat: + + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() tex_slat = self.sample_tex_slat( cond_1024, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if pbar is not None: pbar.update(1) - + if not self.keep_models_loaded: - self.unload_tex_slat_flow_model_1024() - + self.unload_tex_slat_flow_model_1024() + torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) torch.cuda.empty_cache() - pbar.update(1) + pbar.update(1) if return_latent: if generate_texture_slat: return out_mesh, (shape_slat, tex_slat, res) @@ -1132,7 +1156,7 @@ def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: """ mesh = mesh.copy() vertices = mesh.vertices.copy() - + vertices_min = vertices.min(axis=0) vertices_max = vertices.max(axis=0) center = (vertices_min + vertices_max) / 2 @@ -1142,7 +1166,7 @@ def preprocess_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: vertices[:, 1] = -vertices[:, 2] vertices[:, 2] = tmp assert np.all(vertices >= -0.5) and np.all(vertices <= 0.5), 'vertices out of range' - + mesh.vertices = vertices return mesh @@ -1157,13 +1181,13 @@ def encode_shape_slat( Args: mesh (trimesh.Trimesh): The mesh to encode. resolution (int): The resolution of mesh - + Returns: SparseTensor: The encoded structured latent. """ vertices = torch.from_numpy(mesh.vertices).float() faces = torch.from_numpy(mesh.faces).long() - + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( vertices.cpu(), faces.cpu(), grid_size=resolution, @@ -1173,24 +1197,24 @@ def encode_shape_slat( regularization_weight=1e-2, timing=True, ) - + vertices = SparseTensor( feats=dual_vertices * resolution - voxel_indices, coords=torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) ).to(self.device) intersected = vertices.replace(intersected).to(self.device) - + self.load_shape_slat_encoder() - + if self.low_vram: self.models['shape_slat_encoder'].to(self.device) shape_slat = self.models['shape_slat_encoder'](vertices, intersected) if self.low_vram: self.models['shape_slat_encoder'].cpu() - + if not self.keep_models_loaded: self.unload_shape_slat_encoder() - + return shape_slat def postprocess_mesh( @@ -1202,11 +1226,11 @@ def postprocess_mesh( texture_alpha_mode = 'OPAQUE', double_side_material = True ): - + vertices = mesh.vertices faces = mesh.faces normals = np.asarray(mesh.vertex_normals).copy() - + vertices_torch = torch.from_numpy(vertices).float().cuda() faces_torch = torch.from_numpy(faces).int().cuda() if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv') and mesh.visual.uv is not None: @@ -1234,7 +1258,7 @@ def postprocess_mesh( faces = faces_torch.cpu().numpy() uvs = uvs_torch.cpu().numpy() normals = normals[vmap.cpu().numpy()] - + # rasterize print('Finalizing mesh ...') ctx = dr.RasterizeCudaContext() @@ -1245,7 +1269,7 @@ def postprocess_mesh( ) mask = rast[0, ..., 3] > 0 pos = dr.interpolate(vertices_torch.unsqueeze(0), rast, faces_torch)[0][0] - + attrs = torch.zeros(texture_size, texture_size, pbr_voxel.shape[1], device=self.device) attrs[mask] = flex_gemm.ops.grid_sample.grid_sample_3d( pbr_voxel.feats, @@ -1254,24 +1278,24 @@ def postprocess_mesh( grid=((pos[mask] + 0.5) * resolution).reshape(1, -1, 3), mode='trilinear', ) - + # construct mesh mask = mask.cpu().numpy() base_color = np.clip(attrs[..., self.pbr_attr_layout['base_color']].cpu().numpy() * 255, 0, 255).astype(np.uint8) metallic = np.clip(attrs[..., self.pbr_attr_layout['metallic']].cpu().numpy() * 255, 0, 255).astype(np.uint8) roughness = np.clip(attrs[..., self.pbr_attr_layout['roughness']].cpu().numpy() * 255, 0, 255).astype(np.uint8) alpha = np.clip(attrs[..., self.pbr_attr_layout['alpha']].cpu().numpy() * 255, 0, 255).astype(np.uint8) - + # extend mask = (~mask).astype(np.uint8) base_color = cv2.inpaint(base_color, mask, 3, cv2.INPAINT_TELEA) metallic = cv2.inpaint(metallic, mask, 1, cv2.INPAINT_TELEA)[..., None] roughness = cv2.inpaint(roughness, mask, 1, cv2.INPAINT_TELEA)[..., None] alpha = cv2.inpaint(alpha, mask, 1, cv2.INPAINT_TELEA)[..., None] - + baseColorTexture = Image.fromarray(np.concatenate([base_color, alpha], axis=-1)) metallicRoughnessTexture = Image.fromarray(np.concatenate([np.zeros_like(metallic), roughness, metallic], axis=-1)) - + material = trimesh.visual.material.PBRMaterial( baseColorTexture=baseColorTexture, baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8), @@ -1286,7 +1310,7 @@ def postprocess_mesh( vertices[:, 1], vertices[:, 2] = vertices[:, 2], -vertices[:, 1] normals[:, 1], normals[:, 2] = normals[:, 2], -normals[:, 1] uvs[:, 1] = 1 - uvs[:, 1] # Flip UV V-coordinate - + textured_mesh = trimesh.Trimesh( vertices=vertices, faces=faces, @@ -1294,7 +1318,7 @@ def postprocess_mesh( process=False, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material) ) - + return textured_mesh, baseColorTexture, metallicRoughnessTexture @torch.no_grad() @@ -1311,51 +1335,51 @@ def texture_mesh( ): mesh = self.preprocess_mesh(mesh) torch.manual_seed(seed) - - self.load_image_cond_model() + + self.load_image_cond_model() cond = self.get_cond(image, resolution) - + if not self.keep_models_loaded: self.unload_image_cond_model() - + shape_slat = self.encode_shape_slat(mesh, resolution) - + if resolution==512: self.unload_tex_slat_flow_model_1024() self.load_tex_slat_flow_model_512() tex_model = self.models['tex_slat_flow_model_512'] - + tex_slat = self.sample_tex_slat( cond, tex_model, shape_slat, tex_slat_sampler_params ) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_512() else: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() tex_model = self.models['tex_slat_flow_model_1024'] - + tex_slat = self.sample_tex_slat( cond, tex_model, shape_slat, tex_slat_sampler_params ) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() torch.cuda.empty_cache() pbr_voxel = self.decode_tex_slat(tex_slat) torch.cuda.empty_cache() - + out_mesh, baseColorTexture, metallicRoughnessTexture = self.postprocess_mesh(mesh, pbr_voxel, resolution, texture_size, texture_alpha_mode, double_side_material) return out_mesh, baseColorTexture, metallicRoughnessTexture - + def get_coords_from_trimesh(self, mesh, resolution): vertices = torch.from_numpy(mesh.vertices).float() faces = torch.from_numpy(mesh.faces).long() - + voxel_indices, dual_vertices, intersected = o_voxel.convert.mesh_to_flexible_dual_grid( vertices.cpu(), faces.cpu(), grid_size=resolution, @@ -1365,21 +1389,21 @@ def get_coords_from_trimesh(self, mesh, resolution): regularization_weight=1e-2, timing=True, ) - - coords = torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) + + coords = torch.cat([torch.zeros_like(voxel_indices[:, 0:1]), voxel_indices], dim=-1) coords = coords.cpu() - + #print(coords) - + del voxel_indices del dual_vertices del intersected - + if self.low_vram: - self._cleanup_cuda() - + self._cleanup_cuda() + return coords; - + def sample_mesh_slat( self, mesh_slat, @@ -1390,7 +1414,7 @@ def sample_mesh_slat( max_num_tokens: int = 49152, downsampling = 16, ) -> SparseTensor: - # Upsample + # Upsample self.load_shape_slat_decoder() if self.low_vram: self.models['shape_slat_decoder'].to(self.device) @@ -1400,10 +1424,10 @@ def sample_mesh_slat( self.models['shape_slat_decoder'].cpu() self.models['shape_slat_decoder'].low_vram = False hr_resolution = resolution - + if not self.keep_models_loaded: self.unload_shape_slat_decoder() - + #downsampling = 16 lr_resolution = hr_resolution # if hr_resolution == 512: @@ -1412,7 +1436,7 @@ def sample_mesh_slat( # downsampling = 32 # elif hr_resolution == 1536: # downsampling = 32 - + while True: quant_coords = torch.cat([ hr_coords[:, :1], @@ -1431,8 +1455,8 @@ def sample_mesh_slat( if hr_resolution < 512: hr_resolution = 512 break - - coords_dev = coords.to(self.device) + + coords_dev = coords.to(self.device) # Sample structured latent noise = SparseTensor( feats=torch.randn(coords.shape[0], flow_model.in_channels, device=self.device), @@ -1451,19 +1475,19 @@ def sample_mesh_slat( ).samples if self.low_vram: flow_model.cpu() - self._cleanup_cuda() + self._cleanup_cuda() std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device) mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device) slat = slat * std + mean - + del coords_dev if self.low_vram: cond = self._cond_cpu(cond) self._cleanup_cuda() - return slat, hr_resolution - + return slat, hr_resolution + @torch.no_grad() def refine_mesh( self, @@ -1481,22 +1505,22 @@ def refine_mesh( ): mesh = self.preprocess_mesh(mesh) torch.manual_seed(seed) - + self.load_image_cond_model() - + if resolution == 512: cond = self.get_cond(image, 512) else: cond = self.get_cond(image, 1024) - + if not self.keep_models_loaded: - self.unload_image_cond_model() - + self.unload_image_cond_model() + mesh_slat = self.encode_shape_slat(mesh, resolution) - + if resolution==512: self.unload_shape_slat_flow_model_1024() - self.load_shape_slat_flow_model_512() + self.load_shape_slat_flow_model_512() shape_slat, res = self.sample_mesh_slat( mesh_slat, cond, @@ -1506,10 +1530,10 @@ def refine_mesh( max_num_tokens, downsampling ) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_512() - + if generate_texture_slat: self.unload_tex_slat_flow_model_1024() self.load_tex_slat_flow_model_512() @@ -1517,12 +1541,12 @@ def refine_mesh( cond, self.models['tex_slat_flow_model_512'], shape_slat, tex_slat_sampler_params ) - + if not self.keep_models_loaded: - self.unload_tex_slat_flow_model_512() + self.unload_tex_slat_flow_model_512() elif resolution == 1024: self.unload_shape_slat_flow_model_512() - self.load_shape_slat_flow_model_1024() + self.load_shape_slat_flow_model_1024() shape_slat, res = self.sample_mesh_slat( mesh_slat, cond, @@ -1532,10 +1556,10 @@ def refine_mesh( max_num_tokens, downsampling ) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -1543,12 +1567,12 @@ def refine_mesh( cond, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if not self.keep_models_loaded: self.unload_tex_slat_flow_model_1024() elif resolution == 1536: self.unload_shape_slat_flow_model_512() - self.load_shape_slat_flow_model_1024() + self.load_shape_slat_flow_model_1024() shape_slat, res = self.sample_mesh_slat( mesh_slat, cond, @@ -1558,10 +1582,10 @@ def refine_mesh( max_num_tokens, downsampling ) - + if not self.keep_models_loaded: self.unload_shape_slat_flow_model_1024() - + if generate_texture_slat: self.unload_tex_slat_flow_model_512() self.load_tex_slat_flow_model_1024() @@ -1569,21 +1593,21 @@ def refine_mesh( cond, self.models['tex_slat_flow_model_1024'], shape_slat, tex_slat_sampler_params ) - + if not self.keep_models_loaded: - self.unload_tex_slat_flow_model_1024() - + self.unload_tex_slat_flow_model_1024() + torch.cuda.empty_cache() if generate_texture_slat: out_mesh = self.decode_latent(shape_slat, tex_slat, res, use_tiled=use_tiled) else: out_mesh = self.decode_latent(shape_slat, None, res, use_tiled=use_tiled) torch.cuda.empty_cache() - + if return_latent: if generate_texture_slat: return out_mesh, (shape_slat, tex_slat, res) else: return out_mesh, (shape_slat, None, res) else: - return out_mesh \ No newline at end of file + return out_mesh \ No newline at end of file