# Origin: patch "fix recenter" (commit 28287ac2a3d8673c44d31e8eb4d08f5504f407b9,
# HuiwenShi, 2025-01-23). That patch
#   * adds the method below to hy3dgen/texgen/pipelines.py, right after
#     texture_inpaint;
#   * makes __call__ run `image_prompt = self.recenter_image(image_prompt)`
#     immediately before `self.models['delight_model'](image_prompt)`;
#   * removes `image = image.resize((1024, 1024))` from image_to_3d in
#     minimal_demo.py.
# The method relies on `np` and `Image` already being imported by the module
# it is placed in (the patch itself adds no imports).
def recenter_image(self, image, border_ratio=0.2):
    """Crop an image to its visible pixels and pad it onto a square canvas.

    Args:
        image: PIL-style image. An 'RGB' image is returned unchanged and an
            'L' image is only converted to 'RGB'; neither is cropped, since
            the crop is driven by the alpha channel.
        border_ratio: Margin added on every side, expressed as a fraction of
            the cropped width (left/right) and height (top/bottom).

    Returns:
        The untouched/converted image for 'RGB'/'L' input. Otherwise a square
        'RGBA' image holding the cropped content, centred on a canvas filled
        with (255, 255, 255, 0).

    Raises:
        ValueError: If no pixel has an alpha value greater than zero.
    """
    if image.mode == 'RGB':
        return image
    if image.mode == 'L':
        return image.convert('RGB')
    if image.mode != 'RGBA':
        # Modes such as 'P' or 'LA' expose no channel at index 3, so the alpha
        # lookup below would raise IndexError; normalise them to RGBA first.
        image = image.convert('RGBA')

    alpha_channel = np.array(image)[:, :, 3]
    non_zero_indices = np.argwhere(alpha_channel > 0)
    if non_zero_indices.size == 0:
        raise ValueError("Image is fully transparent")

    # argwhere yields (row, col) pairs; the extremes give the bounding box.
    min_row, min_col = non_zero_indices.min(axis=0)
    max_row, max_col = non_zero_indices.max(axis=0)

    # The crop box excludes its right/lower edge, hence the +1.
    cropped_image = image.crop((min_col, min_row, max_col + 1, max_row + 1))

    width, height = cropped_image.size
    border_width = int(width * border_ratio)
    border_height = int(height * border_ratio)

    new_width = width + 2 * border_width
    new_height = height + 2 * border_height

    # Use the larger padded dimension so the result is always square.
    square_size = max(new_width, new_height)

    new_image = Image.new('RGBA', (square_size, square_size), (255, 255, 255, 0))

    # Centre the padded box on the canvas, then step in by the border.
    paste_x = (square_size - new_width) // 2 + border_width
    paste_y = (square_size - new_height) // 2 + border_height

    new_image.paste(cropped_image, (paste_x, paste_y))
    return new_image