Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 77 additions & 7 deletions diffueraser/diffueraser.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,35 @@ def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter

return masks, masked_images

def mask_process(np_masks, n_total_frames, img_size, mask_dilation_iter, frames):
    """Binarize, erode, and dilate per-frame masks and build masked frames.

    Parameters
    ----------
    np_masks : sequence of PIL.Image
        Per-frame masks. Despite the name, PIL images are expected —
        ``.resize``/``.convert`` are called on each entry.
    n_total_frames : int
        Upper bound on how many masks are processed.
    img_size : (int, int)
        Target (width, height); mismatched masks are resized with NEAREST.
    mask_dilation_iter : int
        Number of 3x3 dilation iterations applied after a single erosion.
    frames : sequence of PIL.Image
        Frames matched index-wise with the masks.

    Returns
    -------
    (masks, masked_images)
        Two lists of PIL images: the processed 0/255 masks, and the frames
        with the masked region zeroed out.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    masks = []
    masked_images = []
    for idx, mask in enumerate(np_masks):
        if idx >= n_total_frames:
            break
        if mask.size != img_size:
            mask = mask.resize(img_size, Image.NEAREST)

        # Binarize, erode once to drop speckles, then grow the mask region.
        binary = (np.array(mask.convert('L')) > 0).astype(np.uint8)
        binary = cv2.erode(binary, kernel, iterations=1)
        binary = cv2.dilate(binary, kernel, iterations=mask_dilation_iter)
        mask = Image.fromarray(binary * 255)
        masks.append(mask)

        # Zero out the masked pixels of the matching frame.
        keep = 1 - (np.array(mask)[:, :, np.newaxis].astype(np.float32) / 255)
        blanked = np.array(frames[idx]) * keep
        masked_images.append(Image.fromarray(blanked.astype(np.uint8)))

    return masks, masked_images


def read_priori(priori, fps, n_total_frames, img_size):
cap = cv2.VideoCapture(priori)
if not cap.isOpened():
Expand All @@ -143,10 +172,23 @@ def read_priori(priori, fps, n_total_frames, img_size):
idx += 1
cap.release()

os.remove(priori) # remove priori
# os.remove(priori) # remove priori

return prioris

def read_priori_list(priori_list, fps, n_total_frames, img_size):
    """Convert a list of numpy frames into PIL images at ``img_size``.

    The channel order of each frame is reversed on the last axis
    (presumably BGR -> RGB — confirm against the producer of
    ``priori_list``). ``fps`` is accepted for signature parity with
    ``read_priori`` but is unused here.

    Returns a list of at most ``n_total_frames`` PIL images.
    """
    prioris = []
    for count, frame in enumerate(priori_list):
        if count >= n_total_frames:
            break
        image = Image.fromarray(frame[..., ::-1])  # reverse channel order
        if image.size != img_size:
            image = image.resize(img_size)
        prioris.append(image)
    return prioris


def read_video(validation_image, video_length, nframes, max_img_size):
vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB
fps = info['video_fps']
Expand Down Expand Up @@ -178,6 +220,33 @@ def read_video(validation_image, video_length, nframes, max_img_size):

return frames, fps, img_size, n_clip, n_total_frames

def clip_process(video_info, video_length, nframes, max_img_size):
    """Compute frame/clip counts and normalize the frame size for inference.

    Fix: the original unpacked ``video_info['n_frames']`` into both
    ``n_clip`` and ``n_total_frames`` only to overwrite both on the very
    next lines — the dead, misleading assignments are removed.

    Parameters
    ----------
    video_info : dict
        Must contain ``'frames'`` (sequence of PIL-like images with a
        ``.size`` (width, height) tuple) and ``'fps'`` (float).
    video_length : float
        Requested duration in seconds.
    nframes : int
        Frames per clip.
    max_img_size : int
        Longest allowed side; larger frames are downscaled.

    Returns
    -------
    (frames, fps, img_size, n_clip, n_total_frames)

    Raises
    ------
    ValueError
        If the longest side is below 256 or above 4096 pixels.
    """
    frames = video_info['frames']
    fps = video_info['fps']
    # Frames covered by the requested duration, and how many
    # nframes-sized clips are needed to span them.
    n_total_frames = int(video_length * fps)
    n_clip = int(np.ceil(n_total_frames / nframes))

    max_size = max(frames[0].size)
    if max_size < 256:
        raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
    if max_size > 4096:
        raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")

    if max_size > max_img_size:
        # Downscale so the longest side fits, then snap both dims to a
        # multiple of 8 (model requirement, per the branches below).
        ratio = max_size / max_img_size
        ratio_size = (int(frames[0].size[0] / ratio), int(frames[0].size[1] / ratio))
        img_size = (ratio_size[0] - ratio_size[0] % 8, ratio_size[1] - ratio_size[1] % 8)
        resize_flag = True
    elif (frames[0].size[0] % 8 == 0) and (frames[0].size[1] % 8 == 0):
        # Already aligned — keep the frames untouched.
        img_size = frames[0].size
        resize_flag = False
    else:
        # Keep scale; only snap each dimension down to a multiple of 8.
        ratio_size = frames[0].size
        img_size = (ratio_size[0] - ratio_size[0] % 8, ratio_size[1] - ratio_size[1] % 8)
        resize_flag = True
    if resize_flag:
        frames = resize_frames(frames, img_size)
        img_size = frames[0].size

    return frames, fps, img_size, n_clip, n_total_frames

class DiffuEraser:
def __init__(
Expand Down Expand Up @@ -215,7 +284,8 @@ def __init__(
tokenizer=self.tokenizer,
unet=self.unet_main,
brushnet=self.brushnet
).to(self.device, torch.float16)
).to(torch.float16)
self.pipeline = self.pipeline.to(self.device)
self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
self.pipeline.set_progress_bar_config(disable=True)

Expand Down Expand Up @@ -246,24 +316,24 @@ def __init__(
self.num_inference_steps = checkpoints[ckpt][1]
self.guidance_scale = 0

def forward(self, validation_image, validation_mask, priori, output_path,
def forward(self, video_info, mask_info, priori, output_path,
max_img_size = 1280, video_length=2, mask_dilation_iter=4,
nframes=22, seed=None, revision = None, guidance_scale=None, blended=True):
validation_prompt = "" #
guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale

if (max_img_size<256 or max_img_size>1920):
raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")

################ read input video ################
frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
frames, fps, img_size, n_clip, n_total_frames = clip_process(video_info, video_length, nframes, max_img_size)
video_len = len(frames)

################ read mask ################
validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)
# validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)
validation_masks_input, validation_images_input = mask_process(mask_info['masks'], video_len, img_size, mask_dilation_iter, frames)

################ read priori ################
prioris = read_priori(priori, fps, n_total_frames, img_size)
prioris = read_priori_list(priori, fps, n_total_frames, img_size)

## recheck
n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
Expand Down
79 changes: 62 additions & 17 deletions propainter/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torchvision
import gc


try:
from model.modules.flow_comp_raft import RAFT_bi
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
Expand Down Expand Up @@ -96,10 +97,15 @@ def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5):
masks_img.append(Image.fromarray(frame))
idx += 1
cap.release()
else:
mnames = sorted(os.listdir(mpath))
else:
files = [i for i in os.listdir(mpath) if i.endswith(('.npz', 'jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG'))]
mnames = sorted(files, key=lambda x: int(x.split('_')[1].split('.')[0]))
for mp in mnames:
masks_img.append(Image.open(os.path.join(mpath, mp)))
if mp.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):
masks_img.append(Image.open(os.path.join(mpath, mp)))
else:
mask_array = np.where(np.load(os.path.join(mpath, mp))['mask'] == 255, 0, 1)
masks_img.append(Image.fromarray((mask_array).astype(np.uint8)))
# print(mp)

for mask_img in masks_img:
Expand Down Expand Up @@ -129,6 +135,41 @@ def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5):

return flow_masks, masks_dilated

def mask_process(np_masks, frames_len, size, flow_mask_dilates=8, mask_dilates=5):
    """Produce flow masks and dilated inpainting masks from mask images.

    Parameters
    ----------
    np_masks : sequence of PIL.Image
        Input masks; ``.resize``/``.convert`` are called on each entry.
    frames_len : int
        Upper bound on the number of masks processed; a single input mask
        is broadcast to this many frames.
    size : (int, int) or None
        Target size (NEAREST resampling); ``None`` keeps the original size.
    flow_mask_dilates : int
        Dilation iterations for the flow masks (0 falls back to
        ``binary_mask``).
    mask_dilates : int
        Dilation iterations for the inpainting masks (0 falls back to
        ``binary_mask``).

    Returns
    -------
    (flow_masks, masks_dilated) : two lists of 0/255 PIL images.
    """
    flow_masks = []
    masks_dilated = []

    for frame_idx, mask_image in enumerate(np_masks):
        if frame_idx >= frames_len:
            break
        if size is not None:
            mask_image = mask_image.resize(size, Image.NEAREST)
        gray = np.array(mask_image.convert('L'))

        # Dilate 8 pixels so that every known pixel is trustworthy.
        if flow_mask_dilates > 0:
            flow = scipy.ndimage.binary_dilation(gray, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow = binary_mask(gray).astype(np.uint8)
        flow_masks.append(Image.fromarray(flow * 255))

        if mask_dilates > 0:
            dilated = scipy.ndimage.binary_dilation(gray, iterations=mask_dilates).astype(np.uint8)
        else:
            dilated = binary_mask(gray).astype(np.uint8)
        masks_dilated.append(Image.fromarray(dilated * 255))

    # A single static mask is reused for every frame.
    if len(np_masks) == 1:
        flow_masks = flow_masks * frames_len
        masks_dilated = masks_dilated * frames_len

    return flow_masks, masks_dilated



def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
ref_index = []
if ref_num == -1:
Expand Down Expand Up @@ -172,7 +213,7 @@ def __init__(
model_dir=propainter_model_dir, progress=True, file_name=None)
self.model = InpaintGenerator(model_path=ckpt_path).to(device)
self.model.eval()
def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1,
def forward(self, video_info, mask_info, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1,
mask_dilation=4, ref_stride=10, neighbor_length=10, subvideo_length=80,
raft_iter=20, save_fps=24, save_frames=False, fp16=True):

Expand All @@ -182,7 +223,7 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he
use_half = False

################ read input video ################
frames, fps, size, video_name, nframes = read_frame_from_videos(video, video_length)
frames, fps, size, video_name, nframes = video_info['frames'], video_info['fps'], video_info['size'], video_info['video_name'], video_info['n_frames']
frames = frames[:nframes]
if not width == -1 and not height == -1:
size = (width, height)
Expand All @@ -199,7 +240,7 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he

################ read mask ################
frames_len = len(frames)
flow_masks, masks_dilated = read_mask(mask, frames_len, size,
flow_masks, masks_dilated = mask_process(mask_info['masks'], frames_len, size,
flow_mask_dilates=mask_dilation,
mask_dilates=mask_dilation)
flow_masks = flow_masks[:nframes]
Expand Down Expand Up @@ -490,18 +531,19 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he

torch.cuda.empty_cache()

##save composed video##
# ##save composed video##
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
fps, (comp_frames[0].shape[1],comp_frames[0].shape[0]))
for f in range(video_length):
frame = comp_frames[f].astype(np.uint8)
writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
writer.release()
# writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
# fps, (comp_frames[0].shape[1],comp_frames[0].shape[0]))
# for f in range(video_length):
# frame = comp_frames[f].astype(np.uint8)
# writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# writer.release()

torch.cuda.empty_cache()
# torch.cuda.empty_cache()

return output_path
# return output_path
return comp_frames



Expand All @@ -511,10 +553,13 @@ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, he
propainter_model_dir = "weights/propainter"
propainter = Propainter(propainter_model_dir, device=device)

# TODO(review): machine-specific absolute paths were committed — restore the
# example inputs (see the commented line below) or take these from CLI
# arguments before merging.
video = "/mnt6/jinruh/data/moviiDB/long-context-following/friends/clips/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30.mp4"
# mask = "examples/example1/mask.mp4"
mask = "/mnt6/zengzhuo/projects/VideoDataProcess/split/video_clips/friends/mask_no_limit/Friends.S01E01.1080p.BluRay.x265-RARBG/Friends.S01E01.1080p.BluRay.x265-RARBG_1261.13_1263.30"
output = "results/priori.mp4"
# NOTE(review): forward() now takes video_info/mask_info dicts
# (it reads video_info['frames'], mask_info['masks']) — passing raw path
# strings here will fail at runtime; build the dicts via the reader
# helpers first. Confirm the intended call shape with the PR author.
res = propainter.forward(video, mask, output)
# Removed committed debugger breakpoint (`import pdb; pdb.set_trace()`) —
# it must not ship in a merged script.
print(res)



Loading