Skip to content

Commit

Permalink
updated scratch extension and added new api which is simple and retur…
Browse files Browse the repository at this point in the history
…ns image url
  • Loading branch information
jahangir091 committed Mar 4, 2024
1 parent 603a4a6 commit 0c77e9f
Show file tree
Hide file tree
Showing 2 changed files with 327 additions and 0 deletions.
89 changes: 89 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/scratch_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,76 @@ def process_images(test_path, output_dir, input_size="scale_256", gpu=0):
gc.collect()
torch.cuda.empty_cache()


def process_images_new(scratched_image, input_size="scale_256", gpu=0):
print("initializing the dataloader")

# Initialize the model
model = networks.UNet(
in_channels=1,
out_channels=1,
depth=4,
conv_num=2,
wf=6,
padding=True,
batch_norm=True,
up_mode="upsample",
with_tanh=False,
sync_bn=True,
antialiasing=True,
)

## load model
checkpoint_path = os.path.join(os.path.dirname(__file__), "FT_Epoch_latest.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
print("model weights loaded")

if gpu >= 0:
model.to(gpu)
else:
model.cpu()
model.eval()

transformed_image_PIL = data_transforms(scratched_image, input_size)
scratch_image = transformed_image_PIL.convert("L")
scratch_image = tv.transforms.ToTensor()(scratch_image)
scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
scratch_image = torch.unsqueeze(scratch_image, 0)
_, _, ow, oh = scratch_image.shape
scratch_image_scale = scale_tensor(scratch_image)

if gpu >= 0:
scratch_image_scale = scratch_image_scale.to(gpu)
else:
scratch_image_scale = scratch_image_scale.cpu()
with torch.no_grad():
P = torch.sigmoid(model(scratch_image_scale))

P = P.data.cpu()
P = F.interpolate(P, [ow, oh], mode="nearest")
gc.collect()
torch.cuda.empty_cache()
import torchvision.transforms as transforms
from PIL import Image
transform = transforms.ToPILImage()
pil_mask = transform(P.squeeze())
return transformed_image_PIL, pil_mask
# tv.utils.save_image(
# (P >= 0.4).float(),
# os.path.join(
# output_dir,
# image_name[:-4] + ".png",
# ),
# nrow=1,
# padding=0,
# normalize=True,
# )
# transformed_image_PIL.save(os.path.join(input_dir, image_name[:-4] + ".png"))
# gc.collect()
# torch.cuda.empty_cache()


# Wrap the scratch detection in a class
class ScratchDetection:
def __init__(self, test_path, output_dir, input_size="scale_256", gpu=0):
Expand All @@ -169,6 +239,24 @@ def get_mask_image(self, image_name):
mask_image_path = os.path.join(self.output_dir, "mask", image_name)
return Image.open(mask_image_path)


class ScratchDetectionNew:
def __init__(self, scratched_image, input_size="scale_256", gpu=0):
self.scratched_image = scratched_image
self.input_size = input_size
self.gpu = gpu

def run(self):
pil_image, pil_mask = process_images_new(self.scratched_image, self.input_size, self.gpu)
return pil_image, pil_mask

# Add a function to get the mask image from the output directory
def get_mask_image(self, image_name):
mask_image_path = os.path.join(self.output_dir, "mask", image_name)
return Image.open(mask_image_path)



# Keep the __main__ part, but modify it to use the new ScratchDetection class
if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand All @@ -180,3 +268,4 @@ def get_mask_image(self, image_name):

scratch_detector = ScratchDetection(args.test_path, args.output_dir, args.input_size, args.GPU)
scratch_detector.run()

238 changes: 238 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/scripts/api2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from fastapi import FastAPI, Body, HTTPException
from fastapi.responses import RedirectResponse, FileResponse
from fastapi import File, UploadFile, Form
import gradio as gr
import time
from datetime import datetime, timezone
from pipeline_stable_diffusion_controlnet_inpaint import *
from scratch_detection import ScratchDetection, ScratchDetectionNew
from diffusers import ControlNetModel, DEISMultistepScheduler
from arif_install import downloadScratchRemoverModel
from PIL import Image
import cv2
import glob
import shutil
import os
from os.path import exists
import subprocess
import base64
from io import BytesIO
import numpy
from sympy import true, false

# new
from modules.api.models import *
from modules.api import api
from modules.api import models
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, \
shared_items, postprocessing
from typing_extensions import Literal

import uuid


def get_img_path(directory_name):
current_dir = '/tmp'
img_directory = current_dir + '/.temp' + directory_name
os.makedirs(img_directory, exist_ok=True)
img_file_name = uuid.uuid4().hex[:20] + '.jpg'
return img_directory + img_file_name


device = "cuda"

# load control net and stable diffusion v1-5
controlnet = ControlNetModel.from_pretrained("thepowefuldeez/sd21-controlnet-canny", torch_dtype=torch.float16)

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=torch.float16
)

pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)

# speed up diffusion process with faster scheduler and memory optimization
# remove following line if xformers is not installed
# pipe.enable_xformers_memory_efficient_attention()
pipe.to('cuda')


def scratch_remove_api(_: gr.Blocks, app: FastAPI):
@app.post('/sdapi/ai/v1/scratch_remove_new')
async def generate_mask_image(
input_image: str = Body("", title='scratch remove input image'),
upscale: bool = Body(False, title='input image name')
):
start_time = time.time()

downloadScratchRemoverModelModel()
pil_image = api.decode_base64_to_image(input_image).convert("RGB")
out_image = remove_scratch_using_mask(pil_image, upscale)

out_images_directory_name = '/scratch_images/'

out_image_path = get_img_path(out_images_directory_name)
out_image.save(out_image_path)

return {
"success": True,
"message": "Returned output successfully",
"server_process_time": time.time() - start_time,
"output_image_urls": '/media' + out_images_directory_name + out_image_path.split('/')[-1]
}

def remove_scratch_using_mask(source_image: Image, upscale: bool):
start_time = time.time()
# curDir = os.getcwd()
# fileName = "arif.png"

# input_path = curDir + "/extensions/arifScratchRemoverWebUIExtention/input_images"
# output_dir = curDir + "/extensions/arifScratchRemoverWebUIExtention/output_masks"

# remove previous image from output directory
# remove_all_file_in_dir(folder=("%s/*" % input_path))
# remove_all_file_in_dir(folder=(curDir + "/extensions/arifScratchRemoverWebUIExtention/output_masks/mask/*"))
# remove_all_file_in_dir(folder=(curDir + "/extensions/arifScratchRemoverWebUIExtention/output_masks/input/*"))

# Save the input image to a directory
# source_file_location = input_path + "/" + fileName
# image = source_image.save(f"{source_file_location}")

scratch_detector = ScratchDetectionNew(source_image, input_size="scale_256", gpu=0)
main_image, mask_image = scratch_detector.run()
# mask_image = scratch_detector.get_mask_image(fileName)

# Resize the mask to match the input image size
mask_image = mask_image.resize(mask_image.size, Image.BICUBIC)

# Apply dilation to make the lines bigger
kernel = np.ones((5, 5), np.uint8)
mask_image_np = np.array(mask_image)
mask_image_np_dilated = cv2.dilate(mask_image_np, kernel, iterations=2)
mask_image_dilated = Image.fromarray(mask_image_np_dilated)

##scratck removing
# main_image_dir = curDir + "/extensions/arifScratchRemoverWebUIExtention/output_masks/input/" + fileName
# main_image = Image.open(main_image_dir).convert("RGB")
main_image = main_image.convert("RGB")
main_image = resize_image(main_image, 768)

main_mask = mask_image_dilated
main_mask = resize_image(main_mask, 768)

image = np.array(main_image)
low_threshold = 100
high_threshold = 200
canny = cv2.Canny(image, low_threshold, high_threshold)
canny = canny[:, :, None]
canny = np.concatenate([canny, canny, canny], axis=2)
canny_image = Image.fromarray(canny)
generator = torch.manual_seed(0)

without_scratch_Image_output = pipe(
prompt="",
num_inference_steps=20,
generator=generator,
image=main_image,
control_image=canny_image,
controlnet_conditioning_scale=0,
mask_image=main_mask
).images[0]
# return base64 image
if upscale == False:
return without_scratch_Image_output

args = scripts.scripts_postproc.create_args_for_run({
"Upscale": {
"upscale_mode": 0,
"upscale_by": 1,
"upscale_to_width": 512,
"upscale_to_height": 512,
"upscale_crop": True,
"upscaler_1_name": "R-ESRGAN 4x+",
"upscaler_2_name": "None",
"upscaler_2_visibility": 0,
},
"GFPGAN": {
"gfpgan_visibility": 1,
},
"CodeFormer": {
"codeformer_visibility": 0.75,
"codeformer_weight": 0,
},
})

result = postprocessing.run_postprocessing(0, without_scratch_Image_output, "", "", "", True, *args,
save_output=False)
# img_str = api.encode_pil_to_base64(result[0][0])
print(result[0][0])
return result[0][0]

def save_file(file: UploadFile, path: str):
with open(path, "wb+") as file_object:
shutil.copyfileobj(file.file, file_object)

def resize_image(image, target_size):
width, height = image.size
aspect_ratio = float(width) / float(height)
if width > height:
new_width = target_size
new_height = int(target_size / aspect_ratio)
else:
new_width = int(target_size * aspect_ratio)
new_height = target_size
return image.resize((new_width, new_height), Image.BICUBIC)

def remove_all_file_in_dir(folder):
# '/YOUR/PATH/*'
files = glob.glob(folder)
for f in files:
os.remove(f)

def downloadScratchRemoverModelModel():
curDir = os.getcwd()
model_name = "FT_Epoch_latest.pt"
model_dir = curDir + "/extensions/arifScratchRemoverWebUIExtention/"
model_path = model_dir + model_name

if exists(model_path):
print("model already downloaded")
return
else:
command_str = "wget https://www.dropbox.com/s/5jencqq4h59fbtb/FT_Epoch_latest.pt" + " -P " + model_dir
runcmd(command_str, verbose=True)
print("model downloaded done")

# new
upscaleDir = curDir + "/extensions/arifScratchRemoverWebUIExtention/Bringing-Old-Photos-Back-to-Life/"
check_file = "/extensions/arifScratchRemoverWebUIExtention/Bringing-Old-Photos-Back-to-Life/Global/global_checkpoints.zip"
if exists(check_file):
print("all MS upscale already downloaded")
return
else:
shDir = upscaleDir + "download-weights.sh"
command_str = "sudo chmod +x " + shDir + " " + upscaleDir
runcmd(command_str, verbose=True)
print("all MS scr model downloaded done")

def runcmd(cmd, verbose=False, *args, **kwargs):

process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
shell=True
)
std_out, std_err = process.communicate()
if verbose:
print(std_out.strip(), std_err)
pass


try:
import modules.script_callbacks as script_callbacks

script_callbacks.on_app_started(scratch_remove_api)

except:
pass

0 comments on commit 0c77e9f

Please sign in to comment.