diff --git a/README.md b/README.md index 68ac2c0..ed95ced 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +我给这个输出加了个MASK用来支持在更多地方使用, +此外我不知道他的requirements.txt里面的image是啥反正我没装也能用 + + # ComfyUI BEN - Background Erase Network **** @@ -8,7 +12,7 @@ Remove backgrounds from images with [BEN2](https://huggingface.co/PramaLLC/BEN2) ## Installation ``` -git clone https://github.com/PramaLLC/BEN2_ComfyUI.git +git clone https://github.com/chenpipi0807/BEN2_ComfyUI.git ``` ``` cd BEN2_ComfyUI diff --git a/background_erase_network.py b/background_erase_network.py index 4be60a9..c12115c 100644 --- a/background_erase_network.py +++ b/background_erase_network.py @@ -11,10 +11,9 @@ import BEN2 class BackgroundEraseNetwork: - # Define these as class variables (outside of any method) - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("image",) - FUNCTION = "process" + RETURN_TYPES = ("IMAGE", "MASK") + RETURN_NAMES = ("image", "mask") + FUNCTION = "process_image" CATEGORY = "BEN2" def __init__(self): @@ -32,42 +31,41 @@ def INPUT_TYPES(cls): }, } - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("image",) - FUNCTION = "process_image" - CATEGORY = "BEN2" - def process_image(self, input_image): - # Handle the input tensor format from ComfyUI + # 处理输入图像 if isinstance(input_image, torch.Tensor): if input_image.dim() == 4: input_image = input_image[0] - + if input_image.dim() == 3: input_image = input_image.permute(2, 0, 1) input_image = self.to_pil(input_image) - # Ensure the image is in RGBA mode + # 转换为RGBA格式 if input_image.mode != 'RGBA': input_image = input_image.convert("RGBA") - # Run inference to get the foreground image + # 执行推理 foreground = self.model.inference(input_image) - # Convert the foreground to tensor - foreground_tensor = self.to_tensor(foreground) + # 提取alpha通道作为mask + alpha = foreground.split()[-1] + mask_np = np.array(alpha) + mask_tensor = torch.from_numpy(mask_np).float() / 255.0 # 归一化到[0,1] + mask_tensor = mask_tensor.unsqueeze(0) # [B, H, W] - # Convert to ComfyUI format [B, H, W, C] - foreground_tensor = foreground_tensor.permute(1, 2, 0).unsqueeze(0) + # 转换前景图像为tensor + foreground_tensor = self.to_tensor(foreground) + foreground_tensor = foreground_tensor.permute(1, 2, 0).unsqueeze(0) # [B, H, W, C] - return (foreground_tensor,) + return (foreground_tensor, mask_tensor) -# Export mappings for ComfyUI +# ComfyUI节点映射 NODE_CLASS_MAPPINGS = { "BackgroundEraseNetwork": BackgroundEraseNetwork } NODE_DISPLAY_NAME_MAPPINGS = { - "BackgroundEraseNetwork": "Background Erase Network Image" + "BackgroundEraseNetwork": "Background Erase Network (Image+Mask)" } diff --git a/requirements.txt b/requirements.txt index 5c556ec..c1f09c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ huggingface_hub -Image numpy torch einops