Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
我给这个输出加了个MASK用来支持在更多地方使用,
此外我不知道他的requirements.txt里面的image是啥反正我没装也能用


# ComfyUI BEN - Background Erase Network

****
Expand All @@ -8,7 +12,7 @@ Remove backgrounds from images with [BEN2](https://huggingface.co/PramaLLC/BEN2)
## Installation

```
git clone https://github.com/PramaLLC/BEN2_ComfyUI.git
git clone https://github.com/chenpipi0807/BEN2_ComfyUI.git
```
```
cd BEN2_ComfyUI
Expand Down
38 changes: 18 additions & 20 deletions background_erase_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
import BEN2

class BackgroundEraseNetwork:
# Define these as class variables (outside of any method)
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "process"
RETURN_TYPES = ("IMAGE", "MASK")
RETURN_NAMES = ("image", "mask")
FUNCTION = "process_image"
CATEGORY = "BEN2"

def __init__(self):
Expand All @@ -32,42 +31,41 @@ def INPUT_TYPES(cls):
},
}

RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "process_image"
CATEGORY = "BEN2"

def process_image(self, input_image):
# Handle the input tensor format from ComfyUI
# 处理输入图像
if isinstance(input_image, torch.Tensor):
if input_image.dim() == 4:
input_image = input_image[0]

if input_image.dim() == 3:
input_image = input_image.permute(2, 0, 1)

input_image = self.to_pil(input_image)

# Ensure the image is in RGBA mode
# 转换为RGBA格式
if input_image.mode != 'RGBA':
input_image = input_image.convert("RGBA")

# Run inference to get the foreground image
# 执行推理
foreground = self.model.inference(input_image)

# Convert the foreground to tensor
foreground_tensor = self.to_tensor(foreground)
# 提取alpha通道作为mask
alpha = foreground.split()[-1]
mask_np = np.array(alpha)
mask_tensor = torch.from_numpy(mask_np).float() / 255.0 # 归一化到[0,1]
mask_tensor = mask_tensor.unsqueeze(0) # [B, H, W]

# Convert to ComfyUI format [B, H, W, C]
foreground_tensor = foreground_tensor.permute(1, 2, 0).unsqueeze(0)
# 转换前景图像为tensor
foreground_tensor = self.to_tensor(foreground)
foreground_tensor = foreground_tensor.permute(1, 2, 0).unsqueeze(0) # [B, H, W, C]

return (foreground_tensor,)
return (foreground_tensor, mask_tensor)

# Export mappings for ComfyUI
# ComfyUI节点映射
NODE_CLASS_MAPPINGS = {
"BackgroundEraseNetwork": BackgroundEraseNetwork
}

NODE_DISPLAY_NAME_MAPPINGS = {
"BackgroundEraseNetwork": "Background Erase Network Image"
"BackgroundEraseNetwork": "Background Erase Network (Image+Mask)"
}
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
huggingface_hub
Image
numpy
torch
einops
Expand Down