From 94d76a9091b1d338f47b7625afd550a62080469e Mon Sep 17 00:00:00 2001 From: Kyobo <1366213780@qq.com> Date: Wed, 16 Jul 2025 20:23:12 +0800 Subject: [PATCH] [feat&bugfix] 1. add gradio version to use model conveniently 2. fix using cpu to inference causing error --- BEN2.py | 144 ++++++++++++++++++++++++++++++++++++---------- app.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 287 insertions(+), 32 deletions(-) create mode 100644 app.py diff --git a/BEN2.py b/BEN2.py index 41d4ecb..acb3133 100644 --- a/BEN2.py +++ b/BEN2.py @@ -922,53 +922,140 @@ def __init__(self): m.inplace = True + #################################### original code below ################################################## + # @torch.inference_mode() + # @torch.autocast(device_type="cuda",dtype=torch.float16) + # def forward(self, x): + # real_batch = x.size(0) + # + # shallow_batch = self.shallow(x) + # glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear') + # + # + # + # final_input = None + # for i in range(real_batch): + # start = i * 4 + # end = (i + 1) * 4 + # loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0)) + # input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0) + # + # + # if final_input == None: + # final_input= input_ + # else: final_input = torch.cat((final_input, input_), dim=0) + # + # features = self.backbone(final_input) + # outputs = [] + # + # for i in range(real_batch): + # + # start = i * 5 + # end = (i + 1) * 5 + # + # f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W] + # f3 = features[3][start:end, :, :, :] + # f2 = features[2][start:end, :, :, :] + # f1 = features[1][start:end, :, :, :] + # f0 = features[0][start:end, :, :, :] + # e5 = self.output5(f4) + # e4 = self.output4(f3) + # e3 = self.output3(f2) + # e2 = self.output2(f1) + # e1 = self.output1(f0) + # loc_e5, glb_e5 = e5.split([4, 1], dim=0) + # e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16) + # + # + # e4, tokenattmap4 = 
self.dec_blk4(e4 + resize_as(e5, e4)) + # e4 = self.conv4(e4) + # e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3)) + # e3 = self.conv3(e3) + # e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2)) + # e2 = self.conv2(e2) + # e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1)) + # e1 = self.conv1(e1) + # + # loc_e1, glb_e1 = e1.split([4, 1], dim=0) + # + # output1_cat = patches2image(loc_e1) # (1,128,256,256) + # + # # add glb feat in + # output1_cat = output1_cat + resize_as(glb_e1, output1_cat) + # # merge + # final_output = self.insmask_head(output1_cat) # (1,128,256,256) + # # shallow feature merge + # shallow = shallow_batch[i,:,:,:].unsqueeze(dim=0) + # final_output = final_output + resize_as(shallow, final_output) + # final_output = self.upsample1(rescale_to(final_output)) + # final_output = rescale_to(final_output + resize_as(shallow, final_output)) + # final_output = self.upsample2(final_output) + # final_output = self.output(final_output) + # mask = final_output.sigmoid() + # outputs.append(mask) + # + # return torch.cat(outputs, dim=0) + + #################################### original code above ################################################## @torch.inference_mode() - @torch.autocast(device_type="cuda",dtype=torch.float16) def forward(self, x): - real_batch = x.size(0) + # Pick the dtype automatically: float16 on GPU, float32 on CPU + if x.device.type == 'cuda': + # GPU path: run under autocast with float16 + with torch.autocast(device_type="cuda", dtype=torch.float16): + return self._forward_impl(x) + else: + # CPU path: force float32 + x = x.float() + return self._forward_impl(x) + def _forward_impl(self, x): + real_batch = x.size(0) shallow_batch = self.shallow(x) - glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear') - + # Make sure rescale_to receives a float32 input on CPU + if x.device.type != 'cuda': + glb_batch = rescale_to(x.float(), scale_factor=0.5, interpolation='bilinear') + else: + glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear') final_input = None for i in range(real_batch): 
start = i * 4 - end = (i + 1) * 4 - loc_batch = image2patches(x[i,:,:,:].unsqueeze(dim=0)) - input_ = torch.cat((loc_batch, glb_batch[i,:,:,:].unsqueeze(dim=0)), dim=0) - - - if final_input == None: - final_input= input_ - else: final_input = torch.cat((final_input, input_), dim=0) + end = (i + 1) * 4 + loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0)) + input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0) + + if final_input is None: + final_input = input_ + else: + final_input = torch.cat((final_input, input_), dim=0) features = self.backbone(final_input) outputs = [] - - for i in range(real_batch): + for i in range(real_batch): start = i * 5 - end = (i + 1) * 5 - - f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W] + end = (i + 1) * 5 + + f4 = features[4][start:end, :, :, :] f3 = features[3][start:end, :, :, :] f2 = features[2][start:end, :, :, :] f1 = features[1][start:end, :, :, :] f0 = features[0][start:end, :, :, :] + e5 = self.output5(f4) e4 = self.output4(f3) e3 = self.output3(f2) e2 = self.output2(f1) e1 = self.output1(f0) - loc_e5, glb_e5 = e5.split([4, 1], dim=0) - e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16) + loc_e5, glb_e5 = e5.split([4, 1], dim=0) + e5 = self.multifieldcrossatt(loc_e5, glb_e5) - e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4)) - e4 = self.conv4(e4) + e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4)) + e4 = self.conv4(e4) e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3)) e3 = self.conv3(e3) e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2)) @@ -977,15 +1064,11 @@ def forward(self, x): e1 = self.conv1(e1) loc_e1, glb_e1 = e1.split([4, 1], dim=0) - - output1_cat = patches2image(loc_e1) # (1,128,256,256) - - # add glb feat in + output1_cat = patches2image(loc_e1) output1_cat = output1_cat + resize_as(glb_e1, output1_cat) - # merge - final_output = self.insmask_head(output1_cat) # (1,128,256,256) - # shallow feature merge - shallow = 
shallow_batch[i,:,:,:].unsqueeze(dim=0) + + final_output = self.insmask_head(output1_cat) + shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0) final_output = final_output + resize_as(shallow, final_output) final_output = self.upsample1(rescale_to(final_output)) final_output = rescale_to(final_output + resize_as(shallow, final_output)) @@ -996,9 +1079,6 @@ def forward(self, x): return torch.cat(outputs, dim=0) - - - def loadcheckpoints(self,model_path): model_dict = torch.load(model_path, map_location="cpu", weights_only=True) self.load_state_dict(model_dict['model_state_dict'], strict=True) diff --git a/app.py b/app.py new file mode 100644 index 0000000..c69a379 --- /dev/null +++ b/app.py @@ -0,0 +1,175 @@ +import os +import torch +import argparse +import gradio as gr +import numpy as np +from PIL import Image +import BEN2 +import logging +from datetime import datetime + + +# Configure logging: console handler plus a file handler at ../log/matting_tool.log (NOTE(review): the ../log directory is not created here) +def setup_logging(): + log_format = '\033[33m%(asctime)s \033[0m[%(threadName)s] %(levelname)s \033[32m[%(filename)s-%(funcName)s-%(lineno)d]\033[0m - %(message)s' + logging.basicConfig( + level=logging.INFO, + format=log_format, + handlers=[ + logging.StreamHandler(), + logging.FileHandler("../log/matting_tool.log") + ] + ) + + +setup_logging() +logger = logging.getLogger(__name__) + +# Parse command-line arguments +parser = argparse.ArgumentParser(description="BEN2: Background Erase Network") +parser.add_argument('--port', type=int, required=True, help="Gradio port") +parser.add_argument('--device', type=str, default='cuda') +args = parser.parse_args() + +# Select the device (falls back to CPU when CUDA is requested but unavailable) +if args.device == 'cuda': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +else: + device = torch.device('cpu') +logger.info(f"using device: {device}") + +# Load the model (NOTE(review): model_path below is a hard-coded absolute path) +try: + model = BEN2.BEN_Base().to(device).eval() + model_path = "/home/hx/workspace/projects/video_python/ai_change_cloth_auto/model/matting_model/BEN2/BEN2_Base.pth" + model.loadcheckpoints(model_path) + + if device.type == 'cpu': + model = model.float() # 
convert the model to single precision + torch.set_float32_matmul_precision('high') # optional: set the float32 matmul precision + logger.info(f"successfully load: {model_path}") +except Exception as e: + logger.error(f"fail to load: {str(e)}") + raise + + +def process_single_image(image): + """Process a single image.""" + try: + start_time = datetime.now() + logger.info("开始处理单张图片...") + + # Normalize the input to a PIL image + if isinstance(image, np.ndarray): + logger.debug("输入为numpy数组,转换为PIL图像") + image = Image.fromarray(image) + elif isinstance(image, str): + logger.debug(f"输入为文件路径: {image}") + image = Image.open(image) + else: + error_msg = f"不支持的图片类型: {type(image)}" + logger.error(error_msg) + raise TypeError(error_msg) + + logger.debug("将图片转换为RGB格式") + image = image.convert("RGB") + + # Run the matting inference + logger.info("正在进行抠图处理...") + foreground = model.inference(image, refine_foreground=False) + + # Save the result (NOTE(review): fixed path, so each call overwrites the previous output) + output_path = "/tmp/foreground.png" + foreground.save(output_path, format="PNG") + logger.info(f"抠图完成! 结果已保存到: {output_path}") + + elapsed = (datetime.now() - start_time).total_seconds() + logger.info(f"单张图片处理耗时: {elapsed:.2f}秒") + + return foreground, output_path + + except Exception as e: + logger.error(f"处理单张图片时出错: {str(e)}", exc_info=True) + raise + + +def process_folder(folder_path, output_folder): + """Batch-process the images in a folder.""" + try: + start_time = datetime.now() + logger.info(f"开始批量处理文件夹: {folder_path}") + + if not os.path.isdir(folder_path): + error_msg = f"输入文件夹不存在: {folder_path}" + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + os.makedirs(output_folder, exist_ok=True) + logger.info(f"创建输出文件夹: {output_folder}") + + processed_count = 0 + for image_item in os.listdir(folder_path): + image_path = os.path.join(folder_path, image_item) + if image_path.lower().endswith(('png', 'jpg', 'jpeg')): + try: + logger.info(f"正在处理: {image_path}") + image = Image.open(image_path) + foreground = model.inference(image, refine_foreground=False) + + output_path = os.path.join(output_folder, f"foreground-{image_item}") + foreground.save(output_path) + processed_count += 1 + 
logger.debug(f"已保存结果到: {output_path}") + except Exception as e: + logger.error(f"处理图片 {image_path} 时出错: {str(e)}", exc_info=True) + continue + + elapsed = (datetime.now() - start_time).total_seconds() + logger.info(f"批量处理完成! 共处理 {processed_count} 张图片, 总耗时: {elapsed:.2f}秒") + return f"所有图片处理完成! 共处理 {processed_count} 张图片, 结果保存在: {output_folder}" + + except Exception as e: + logger.error(f"批量处理文件夹时出错: {str(e)}", exc_info=True) + raise + + +# Gradio UI +with gr.Blocks(title="BEN2: Background Erase Network") as app: + gr.Markdown("# 🖼️ BEN2: Background Erase Network") + gr.Markdown("Support single image file or Batch image folder") + + with gr.Row(): + with gr.Column(): + image_input = gr.Image(label="上传图片/Upload image") + process_button = gr.Button("开始抠图/Start", variant="primary") + with gr.Column(): + image_output = gr.Image(label="抠图结果/Result") + download_file = gr.File(label="下载 PNG/ Download") + + process_button.click( + process_single_image, + inputs=image_input, + outputs=[image_output, download_file], + api_name="single_image_matting" + ) + + with gr.Row(): + folder_input = gr.Textbox(label="输入图片文件夹路径/Input image folder path", placeholder="请输入包含图片的文件夹路径/Please enter the folder path containing the picture") + output_folder_input = gr.Textbox(label="输出文件夹路径/Output image folder path", placeholder="请输入保存结果的文件夹路径/Please enter the folder path where you will save the result") + process_folder_button = gr.Button("批量抠图/Batch process", variant="primary") + + folder_output_info = gr.Textbox(label="处理结果/Result") + + process_folder_button.click( + process_folder, + inputs=[folder_input, output_folder_input], + outputs=folder_output_info, + api_name="batch_matting" + ) + +# Launch the app (NOTE(review): a launch failure is logged but not re-raised) +try: + logger.info(f"Start Gradio service, port: {args.port}") + app.launch(server_name="0.0.0.0", server_port=args.port) +except Exception as e: + logger.error(f"Fail to start Gradio service: {str(e)}", exc_info=True)