diff --git a/diffueraser/diffueraser.py b/diffueraser/diffueraser.py index db8a77e..32779aa 100644 --- a/diffueraser/diffueraser.py +++ b/diffueraser/diffueraser.py @@ -117,6 +117,35 @@ def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter return masks, masked_images +def mask_process(np_masks, n_total_frames, img_size, mask_dilation_iter, frames): + masks = [] + masked_images = [] + # if not np.isclose(mask_fps, fps, rtol=1e-3): # allow 0.1% relative error + # raise ValueError("The frame rate of all input videos needs to be consistent.") + for idx, mask in enumerate(np_masks): + if idx >= n_total_frames: + break + if mask.size != img_size: + mask = mask.resize(img_size, Image.NEAREST) + + m = np.array(mask.convert('L')) > 0 + m = m.astype(np.uint8) + m = cv2.erode(m, + cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)), + iterations=1) + m = cv2.dilate(m, + cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)), + iterations=mask_dilation_iter) + mask = Image.fromarray(m * 255) + masks.append(mask) + + masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255)) + masked_image = Image.fromarray(masked_image.astype(np.uint8)) + masked_images.append(masked_image) + + return masks, masked_images + + def read_priori(priori, fps, n_total_frames, img_size): cap = cv2.VideoCapture(priori) if not cap.isOpened(): @@ -143,10 +172,23 @@ def read_priori(priori, fps, n_total_frames, img_size): idx += 1 cap.release() - os.remove(priori) # remove priori + # os.remove(priori) # remove priori + + return prioris +def read_priori_list(priori_list, fps, n_total_frames, img_size): + prioris = [] + for idx, priori in enumerate(priori_list): + if(idx >= n_total_frames): + break + img = Image.fromarray(priori[...,::-1]) + if img.size != img_size: + img = img.resize(img_size) + prioris.append(img) + # os.remove(priori) # remove priori return prioris + def read_video(validation_image, video_length, nframes, max_img_size): vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB fps = info['video_fps'] @@ -178,6 +220,33 @@ def read_video(validation_image, video_length, nframes, max_img_size): return frames, fps, img_size, n_clip, n_total_frames +def clip_process(video_info, video_length, nframes, max_img_size): + frames, fps, n_clip, n_total_frames = video_info['frames'], video_info['fps'], video_info['n_frames'], video_info['n_frames'] + n_total_frames = int(video_length * fps) + n_clip = int(np.ceil(n_total_frames/nframes)) + + max_size = max(frames[0].size) + if(max_size<256): + raise ValueError("The resolution of the uploaded video must be larger than 256x256.") + if(max_size>4096): + raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.") + if max_size>max_img_size: + ratio = max_size/max_img_size + ratio_size = (int(frames[0].size[0]/ratio),int(frames[0].size[1]/ratio)) + img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8) + resize_flag=True + elif (frames[0].size[0]%8==0) and (frames[0].size[1]%8==0): + img_size = frames[0].size + resize_flag=False + else: + ratio_size = frames[0].size + img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8) + resize_flag=True + if resize_flag: + frames = resize_frames(frames, img_size) + img_size = frames[0].size + + return frames, fps, img_size, n_clip, n_total_frames class DiffuEraser: def __init__( @@ -215,7 +284,8 @@ def __init__( tokenizer=self.tokenizer, unet=self.unet_main, brushnet=self.brushnet - ).to(self.device, torch.float16) + ).to(torch.float16) + self.pipeline = self.pipeline.to(self.device) self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config) self.pipeline.set_progress_bar_config(disable=True) @@ -246,24 +316,24 @@ def __init__( self.num_inference_steps = checkpoints[ckpt][1] self.guidance_scale = 0 - def forward(self, validation_image, validation_mask, priori, output_path, + def forward(self, video_info, mask_info, priori, output_path, max_img_size = 1280, video_length=2, mask_dilation_iter=4, nframes=22, seed=None, revision = None, guidance_scale=None, blended=True): validation_prompt = "" # guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale - if (max_img_size<256 or max_img_size>1920): raise ValueError("The max_img_size must be larger than 256, smaller than 1920.") ################ read input video ################ - frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size) + frames, fps, img_size, n_clip, n_total_frames = clip_process(video_info, video_length, nframes, max_img_size) video_len = len(frames) ################ read mask ################ - validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames) + # validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames) + validation_masks_input, validation_images_input = mask_process(mask_info['masks'], video_len, img_size, mask_dilation_iter, frames) ################ read priori ################ - prioris = read_priori(priori, fps, n_total_frames, img_size) + prioris = read_priori_list(priori, fps, n_total_frames, img_size) ## recheck n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris)) diff --git a/propainter/inference.py b/propainter/inference.py index b180d49..6dcbbfa 100644 --- a/propainter/inference.py +++ b/propainter/inference.py @@ -9,6 +9,7 @@ import torchvision import gc + try: from model.modules.flow_comp_raft import RAFT_bi from model.recurrent_flow_completion import RecurrentFlowCompleteNet @@ -96,10 +97,15 @@ def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5): masks_img.append(Image.fromarray(frame)) idx += 1 cap.release() - else: - mnames = sorted(os.listdir(mpath)) + else: + files = [i for i in os.listdir(mpath) if i.endswith(('.npz', 'jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG'))] + mnames = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0])) for mp in mnames: - masks_img.append(Image.open(os.path.join(mpath, mp))) + if mp.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): + masks_img.append(Image.open(os.path.join(mpath, mp))) + else: + mask_array = np.where(np.load(os.path.join(mpath, mp))['mask'] == 255, 0, 1) + masks_img.append(Image.fromarray((mask_array).astype(np.uint8))) # print(mp) for mask_img in masks_img: @@ -129,6 +135,41 @@ def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5): return flow_masks, masks_dilated +def mask_process(np_masks, frames_len, size, flow_mask_dilates=8, mask_dilates=5): + masks_dilated = [] + flow_masks = [] + + for idx,mask_img in enumerate(np_masks): + if(idx >= frames_len): + break + if size is not None: + mask_img = mask_img.resize(size, Image.NEAREST) + mask_img = np.array(mask_img.convert('L')) + + # Dilate 8 pixel so that all known pixel is trustworthy + if flow_mask_dilates > 0: + flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8) + else: + flow_mask_img = binary_mask(mask_img).astype(np.uint8) + # Close the small holes inside the foreground objects + # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool) + # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8) + flow_masks.append(Image.fromarray(flow_mask_img * 255)) + + if mask_dilates > 0: + mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8) + else: + mask_img = binary_mask(mask_img).astype(np.uint8) + masks_dilated.append(Image.fromarray(mask_img * 255)) + + if len(np_masks) == 1: + flow_masks = flow_masks * frames_len + masks_dilated = masks_dilated * frames_len + + return flow_masks, masks_dilated + + + def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1): ref_index = [] if ref_num == -1: @@ -172,7 +213,7 @@ def __init__( model_dir=propainter_model_dir, progress=True, file_name=None) self.model = InpaintGenerator(model_path=ckpt_path).to(device) self.model.eval() - def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1, + def forward(self, video_info, mask_info, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1, mask_dilation=4, ref_stride=10, neighbor_length=10, subvideo_length=80, raft_iter=20, save_fps=24, save_frames=False, fp16=True): @@ -182,7 +223,7 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he use_half = False ################ read input video ################ - frames, fps, size, video_name, nframes = read_frame_from_videos(video, video_length) + frames, fps, size, video_name, nframes = video_info['frames'], video_info['fps'], video_info['size'], video_info['video_name'], video_info['n_frames'] frames = frames[:nframes] if not width == -1 and not height == -1: size = (width, height) @@ -199,7 +240,7 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he ################ read mask ################ frames_len = len(frames) - flow_masks, masks_dilated = read_mask(mask, frames_len, size, + flow_masks, masks_dilated = mask_process(mask_info['masks'], frames_len, size, flow_mask_dilates=mask_dilation, mask_dilates=mask_dilation) flow_masks = flow_masks[:nframes] @@ -490,18 +531,19 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he torch.cuda.empty_cache() - ##save composed video## + # ##save composed video## comp_frames = [cv2.resize(f, out_size) for f in comp_frames] - writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), - fps, (comp_frames[0].shape[1],comp_frames[0].shape[0])) - for f in range(video_length): - frame = comp_frames[f].astype(np.uint8) - writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - writer.release() + # writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), + # fps, (comp_frames[0].shape[1],comp_frames[0].shape[0])) + # for f in range(video_length): + # frame = comp_frames[f].astype(np.uint8) + # writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + # writer.release() - torch.cuda.empty_cache() + # torch.cuda.empty_cache() - return output_path + # return output_path + return comp_frames @@ -511,10 +553,13 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he propainter_model_dir = "weights/propainter" propainter = Propainter(propainter_model_dir, device=device) - video = "examples/example1/video.mp4" - mask = "examples/example1/mask.mp4" + video = "/mnt6/jinruh/data/moviiDB/long-context-following/friends/clips/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30.mp4" + # mask = "examples/example1/mask.mp4" + mask = "/mnt6/zengzhuo/projects/VideoDataProcess/split/video_clips/friends/mask_no_limit/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30" output = "results/priori.mp4" res = propainter.forward(video, mask, output) + import pdb; pdb.set_trace() + print(res) \ No newline at end of file diff --git a/run_diffueraser_qf.py b/run_diffueraser_qf.py new file mode 100644 index 0000000..a4a7462 --- /dev/null +++ b/run_diffueraser_qf.py @@ -0,0 +1,139 @@ +from diffueraser.diffueraser import DiffuEraser +from propainter.inference import Propainter, get_device +from pathlib import Path + +import torch +import os +import time +import argparse +import shutil +from tools.load_source import read_video_shared, read_mask_shared + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_video', type=str, default="examples/example3/video.mp4", help='Path to the input video') + parser.add_argument('--input_mask', type=str, default="examples/example3/mask.mp4" , help='Path to the input mask') + parser.add_argument('--video_length', type=int, default=30, help='The maximum length of output video') + parser.add_argument('--mask_dilation_iter', type=int, default=8, help='Adjust it to change the degree of mask expansion') + parser.add_argument('--max_img_size', type=int, default=654, help='The maximum length of output width and height') + parser.add_argument('--save_path', type=str, default="results" , help='Path to the output') + parser.add_argument('--ref_stride', type=int, default=10, help='Propainter params') + parser.add_argument('--neighbor_length', type=int, default=10, help='Propainter params') + parser.add_argument('--subvideo_length', type=int, default=50, help='Propainter params') + parser.add_argument('--base_model_path', type=str, default="weights/stable-diffusion-v1-5" , help='Path to sd1.5 base model') + parser.add_argument('--vae_path', type=str, default="weights/sd-vae-ft-mse" , help='Path to vae') + parser.add_argument('--diffueraser_path', type=str, default="weights/diffuEraser" , help='Path to DiffuEraser') + parser.add_argument('--propainter_model_dir', type=str, default="weights/propainter" , help='Path to priori model') + return parser.parse_args() + +def process_single_video(input_video, input_mask, save_path, args, video_inpainting_sd, propainter, copy_inputs=False): + if not os.path.exists(save_path): + os.makedirs(save_path) + priori_path = os.path.join(save_path, "priori.mp4") + output_path = os.path.join(save_path, "diffueraser_result.mp4") + + # Copy input files if requested + if copy_inputs: + input_copy_path = os.path.join(save_path, "input.mp4") + mask_copy_path = os.path.join(save_path, "mask.mp4") + shutil.copy2(input_video, input_copy_path) + shutil.copy2(input_mask, mask_copy_path) + + start_time = time.time() + + video_info = read_video_shared(input_video, args.video_length) + mask_info = read_mask_shared(input_mask) + ## priori + priori_list = propainter.forward(video_info, mask_info, priori_path, video_length=args.video_length, + ref_stride=args.ref_stride, neighbor_length=args.neighbor_length, subvideo_length=args.subvideo_length, + mask_dilation=args.mask_dilation_iter) + + ## diffueraser + guidance_scale = None # The default value is 0. + video_inpainting_sd.forward(video_info, mask_info, priori_list, output_path, + max_img_size=args.max_img_size, video_length=args.video_length, mask_dilation_iter=args.mask_dilation_iter, + guidance_scale=guidance_scale) + + end_time = time.time() + inference_time = end_time - start_time + print(f"DiffuEraser inference time for {input_video}: {inference_time:.4f} s") + +def main(input_videos=None, input_masks=None, save_paths=None, copy_inputs: bool = False): + args = get_args() + + # If no lists provided, use command line arguments + if input_videos is None: + input_videos = [args.input_video] + input_masks = [args.input_mask] + save_paths = [args.save_path] + + # Validate input lists + if not (len(input_videos) == len(input_masks) == len(save_paths)): + raise ValueError("Input lists must have the same length") + + ## model initialization + device = get_device() + # PCM params + ckpt = "2-Step" + video_inpainting_sd = DiffuEraser(device, args.base_model_path, args.vae_path, args.diffueraser_path, ckpt=ckpt) + propainter = Propainter(args.propainter_model_dir, device=device) + + # Process each video + for input_video, input_mask, save_path in zip(input_videos, input_masks, save_paths): + process_single_video(input_video, input_mask, save_path, args, video_inpainting_sd, propainter, copy_inputs=copy_inputs) + + torch.cuda.empty_cache() + +def test_batch_processing(): + # # 设置基础路径 + # mask_dir = Path("/mnt3/qiufeng/documents/code/MoviiDB/data/Friends/mask_video") + # video_dir = Path("/mnt3/qiufeng/documents/code/MoviiDB/data/Friends/clips") + # results_dir = Path("/mnt3/qiufeng/documents/code/MoviiDB/DiffuEraser/results") + + # # 获取所有mask视频的路径 + # mask_paths = list(mask_dir.rglob("*.mp4")) + + # # 准备对应的输入视频和保存路径 + # input_videos = [] + # input_masks = [] + # save_paths = [] + + # for mask_path in mask_paths: + # # 使用相同的文件名获取对应的输入视频 + # video_path = video_dir / mask_path.relative_to(mask_dir) + # # 创建对应的保存目录(使用文件名作为子目录) + # save_path = results_dir / mask_path.relative_to(mask_dir).with_suffix("") + + # # 如果保存路径已经存在,则跳过 + # if (save_path / "diffueraser_result.mp4").exists(): + # print(f"Save path already exists: {save_path}") + # continue + + # # 确保输入视频存在 + # if video_path.exists(): + # input_videos.append(str(video_path)) + # input_masks.append(str(mask_path)) + # save_paths.append(str(save_path)) + # else: + # print(f"Video not found for mask: {video_path}") + + # # 打印处理信息 + # print(f"Found {len(input_videos)} videos to process:") + # for i, (video, mask, save) in enumerate(zip(input_videos, input_masks, save_paths), 1): + # print(f"\nPair {i}:") + # print(f"Video: {video}") + # print(f"Mask: {mask}") + # print(f"Save: {save}") + + # 确认是否继续处理 + response = input("\nDo you want to proceed with processing these videos? (y/n): ") + input_videos = ["/mnt6/jinruh/data/moviiDB/long-context-following/friends/clips/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30.mp4"] + input_masks = ["/mnt6/zengzhuo/projects/VideoDataProcess/split/video_clips/friends/mask_no_limit/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30"] + save_paths = ["results"] + if response.lower() == 'y': + main(input_videos, input_masks, save_paths, copy_inputs=False) + else: + print("Processing cancelled.") + +if __name__ == '__main__': + test_batch_processing() diff --git a/run_diffueraser_qf_multi_processor.py b/run_diffueraser_qf_multi_processor.py new file mode 100644 index 0000000..fe8138e --- /dev/null +++ b/run_diffueraser_qf_multi_processor.py @@ -0,0 +1,198 @@ +from diffueraser.diffueraser import DiffuEraser +from propainter.inference import Propainter, get_device +from pathlib import Path +import sys + +import torch +import os +import time +import argparse +import shutil +from tools.load_source import read_video_shared, read_mask_shared + +import threading +from queue import Queue +import torch.multiprocessing as mp +from typing import List, Optional + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_video', type=str, default="examples/example3/video.mp4", help='Path to the input video') + parser.add_argument('--input_mask', type=str, default="examples/example3/mask.mp4" , help='Path to the input mask') + parser.add_argument('--video_length', type=int, default=30, help='The maximum length of output video') + parser.add_argument('--mask_dilation_iter', type=int, default=8, help='Adjust it to change the degree of mask expansion') + parser.add_argument('--max_img_size', type=int, default=654, help='The maximum length of output width and height') + parser.add_argument('--save_path', type=str, default="results" , help='Path to the output') + parser.add_argument('--ref_stride', type=int, default=10, help='Propainter params') + parser.add_argument('--neighbor_length', type=int, default=10, help='Propainter params') + parser.add_argument('--subvideo_length', type=int, default=50, help='Propainter params') + parser.add_argument('--base_model_path', type=str, default="weights/stable-diffusion-v1-5" , help='Path to sd1.5 base model') + parser.add_argument('--vae_path', type=str, default="weights/sd-vae-ft-mse" , help='Path to vae') + parser.add_argument('--diffueraser_path', type=str, default="weights/diffuEraser" , help='Path to DiffuEraser') + parser.add_argument('--propainter_model_dir', type=str, default="weights/propainter" , help='Path to priori model') + parser.add_argument('--idx', type=int, default=0, help='第几份') + parser.add_argument('--total_piece', type=int, default=None, help='总片数') + parser.add_argument('--gpu_id', type=int, default=0, help='使用的GPU ID') + return parser.parse_args() + +class VideoProcessor: + def __init__(self, num_workers: int, args): + self.task_queue = Queue() + self.args = args + self.num_workers = num_workers + self.workers: List[threading.Thread] = [] + + def producer(self, input_videos: List[str], input_masks: List[str], save_paths: List[str]): + for video, mask, save_path in zip(input_videos, input_masks, save_paths): + self.task_queue.put((video, mask, save_path)) + + # 添加结束标记 + for _ in range(self.num_workers): + self.task_queue.put(None) + + def consumer(self, device_id: int): + torch.cuda.set_device(device_id) + # 在线程开始时初始化模型 + device = get_device(device_id % 8) + ckpt = "2-Step" + video_inpainting_sd = DiffuEraser(device, self.args.base_model_path, + self.args.vae_path, + self.args.diffueraser_path, + ckpt=ckpt) + propainter = Propainter(self.args.propainter_model_dir, device=device) + + while True: + task = self.task_queue.get() + if task is None: # 结束标记 + break + + input_video, input_mask, save_path = task + try: + self.process_single_video(input_video, input_mask, save_path, + video_inpainting_sd, propainter) + print(f"已在 GPU {device_id} 上处理完成 {input_video}") + except Exception as e: + print(f"错误发生在文件 {__file__} 第 {sys.exc_info()[2].tb_lineno} 行") + print(f"在 GPU {device_id} 上处理 {input_video} 时发生错误: {str(e)}") + finally: + torch.cuda.empty_cache() + self.task_queue.task_done() + + def process_single_video(self, input_video, input_mask, save_path, + video_inpainting_sd, propainter): + if not os.path.exists(save_path): + os.makedirs(save_path) + priori_path = '' + output_path = os.path.join(save_path, "diffueraser_result.mp4") + + start_time = time.time() + + video_info = read_video_shared(input_video, self.args.video_length) + mask_info = read_mask_shared(input_mask) + + ## priori + priori_list = propainter.forward(video_info, mask_info, priori_path, + video_length=self.args.video_length, + ref_stride=self.args.ref_stride, + neighbor_length=self.args.neighbor_length, + subvideo_length=self.args.subvideo_length, + mask_dilation=self.args.mask_dilation_iter) + + ## diffueraser + guidance_scale = None # The default value is 0. + video_inpainting_sd.forward(video_info, mask_info, priori_list, output_path, + max_img_size=self.args.max_img_size, + video_length=self.args.video_length, + mask_dilation_iter=self.args.mask_dilation_iter, + guidance_scale=guidance_scale) + + end_time = time.time() + inference_time = end_time - start_time + print(f"DiffuEraser 处理 {input_video} 的推理时间: {inference_time:.4f} 秒") + + def start_processing(self, input_videos: List[str], input_masks: List[str], save_paths: List[str],gpu_id: int): + # 启动消费者线程 + for i in range(self.num_workers): + worker = threading.Thread(target=self.consumer, args=(gpu_id%8,)) + worker.start() + self.workers.append(worker) + + # 启动生产者线程 + producer = threading.Thread(target=self.producer, args=(input_videos, input_masks, save_paths)) + producer.start() + + # 等待所有任务完成 + producer.join() + self.task_queue.join() + + # 等待所有工作线程结束 + for worker in self.workers: + worker.join() + +def main(input_videos=None, input_masks=None, save_paths=None, num_workers: Optional[int] = None): + args = get_args() + idx = args.idx + total_piece = args.total_piece + gpu_id = args.gpu_id + + # 如果没有提供输入列表,使用命令行参数 + if input_videos is None: + input_videos = [args.input_video] + input_masks = [args.input_mask] + save_paths = [args.save_path] + + input_videos_piece = input_videos[idx::total_piece] + input_masks_piece = input_masks[idx::total_piece] + save_paths_piece = save_paths[idx::total_piece] + + # 验证输入列表 + if not (len(input_videos) == len(input_masks) == len(save_paths)): + raise ValueError("输入列表长度必须相同") + + # 如果未指定工作线程数,使用可用的GPU数量 + if num_workers is None: + num_workers = torch.cuda.device_count() + + processor = VideoProcessor(num_workers, args) + processor.start_processing(input_videos_piece, input_masks_piece, save_paths_piece, gpu_id) + +def test_batch_processing(): + # 设置基础路径 + mask_dir = Path("/mnt6/zengzhuo/projects/VideoDataProcess/split/video_clips/friends/mask_no_limit/") + video_dir = Path("/mnt6/jinruh/data/moviiDB/long-context-following/friends/clips/") + results_dir = Path("./results") + + # 获取所有mask视频的路径 + video_paths = list(video_dir.rglob("*.mp4")) + # mask_paths = [mask_dir / video_path.relative_to(video_dir).with_suffix('') for video_path in video_paths] + + # 准备对应的输入视频和保存路径 + input_videos = [] + input_masks = [] + save_paths = [] + + for video_path in video_paths: + mask_path = mask_dir / video_path.relative_to(video_dir).with_suffix("") + if not mask_path.exists(): + print(f"Mask not found for video: {mask_path}") + continue + # 创建对应的保存目录(使用文件名作为子目录) + save_path = results_dir / video_path.relative_to(video_dir).with_suffix("") + + # 如果保存路径已经存在,则跳过 + if (save_path / "diffueraser_result.mp4").exists(): + print(f"Save path already exists: {save_path}") + continue + + # 确保输入视频存在 + if video_path.exists(): + input_videos.append(str(video_path)) + input_masks.append(str(mask_path)) + save_paths.append(str(save_path)) + else: + print(f"Video not found for mask: {video_path}") + + main(input_videos, input_masks, save_paths, num_workers=1) # 设置使用16个GPU + +if __name__ == '__main__': + test_batch_processing() \ No newline at end of file diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/load_source.py b/tools/load_source.py new file mode 100644 index 0000000..d8e8ba0 --- /dev/null +++ b/tools/load_source.py @@ -0,0 +1,71 @@ +import cv2 +import numpy as np +import os +import torchvision +from PIL import Image + +def read_video_shared(video_path, video_length): + """ + 统一的视频读取函数 + """ + if video_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): + video_name = os.path.basename(video_path)[:-4] + vframes, aframes, info = torchvision.io.read_video(filename=video_path, pts_unit='sec', end_pts=video_length) + fps = info['video_fps'] + n_total_frames = int(video_length * fps) if video_length else len(vframes) + + frames = list(vframes.numpy())[:n_total_frames] + frames = [Image.fromarray(f) for f in frames] + + else: # 文件夹输入 + video_name = os.path.basename(video_path) + frames = [] + fr_lst = sorted(os.listdir(video_path)) + for fr in fr_lst: + frame = cv2.imread(os.path.join(video_path, fr)) + frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) + fps = None + n_total_frames = len(frames) + + img_size = frames[0].size + + return { + 'frames': frames, + 'fps': fps, + 'size': img_size, + 'video_name': video_name, + 'n_frames': n_total_frames, + } + + +def read_mask_shared(mpath): + masks_img = [] + + if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path + masks_img = [Image.open(mpath)] + elif mpath.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + cap = cv2.VideoCapture(mpath) + if not cap.isOpened(): + print("Error: Could not open video.") + exit() + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + masks_img.append(Image.fromarray(frame)) + idx += 1 + cap.release() + else: + files = [i for i in os.listdir(mpath) if i.endswith(('.npz', 'jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG'))] + mnames = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0])) + for mp in mnames: + if mp.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): + masks_img.append(Image.open(os.path.join(mpath, mp))) + else: + mask_array = np.where(np.load(os.path.join(mpath, mp))['mask'] == 255, 0, 1) + masks_img.append(Image.fromarray((mask_array).astype(np.uint8))) + return { + 'masks': masks_img + } \ No newline at end of file