Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deoldify extension #11

Merged
merged 5 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/arifScretchRemover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pipeline_stable_diffusion_controlnet_inpaint import *
from scratch_detection import ScratchDetection

from PIL import Image
import cv2
import glob
import os

device = "cuda"


def resize_image(image, target_size):
width, height = image.size
aspect_ratio = float(width) / float(height)
if width > height:
new_width = target_size
new_height = int(target_size / aspect_ratio)
else:
new_width = int(target_size * aspect_ratio)
new_height = target_size
return image.resize((new_width, new_height), Image.BICUBIC)

def remove_all_file_in_dir(folder):
#'/YOUR/PATH/*'
files = glob.glob(folder)
for f in files:
os.remove(f)


def generate_scratch_mask():
# Save the input image to a directory
pngExt = '.png'
jpgExt = '.jpg'
fileName = 'auny'
image_dir = 'Arif'
image_name = fileName+pngExt


image_full_dir = (f'{image_dir}/{image_name}')
img_p = (image_full_dir)
input_image = PIL.Image.open(img_p).convert('RGB')

input_path = "input_images"
output_dir = "output_masks"

remove_all_file_in_dir(folder=("%s/*"%input_path))
input_image_path = (f'{input_path}/{image_name}')
#input_image_resized = resize_image(input_image, 768)
input_image.save(input_image_path)


scratch_detector = ScratchDetection(input_path, output_dir, input_size="scale_256", gpu=0)
scratch_detector.run()
mask_image = scratch_detector.get_mask_image((fileName+pngExt))

# Resize the mask to match the input image size
mask_image = mask_image.resize(input_image.size, Image.BICUBIC)

# Apply dilation to make the lines bigger
kernel = np.ones((5, 5), np.uint8)
mask_image_np = np.array(mask_image)
mask_image_np_dilated = cv2.dilate(mask_image_np, kernel, iterations=2)
mask_image_dilated = Image.fromarray(mask_image_np_dilated)


return mask_image_dilated

# window_name = 'Output Image'
# cv2.imshow(window_name, mask_image_dilated)
# cv2.waitKey(0)
# cv2.destroyAllWindows()

# return mask_image_dilated


# generate_scratch_mask()

filename = os.path.splitext("auny.png")[0]
print(filename)
24 changes: 24 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/arif_install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import subprocess
import os
def runcmd(cmd, verbose = False, *args, **kwargs):

process = subprocess.Popen(
cmd,
stdout = subprocess.PIPE,
stderr = subprocess.PIPE,
text = True,
shell = True
)
std_out, std_err = process.communicate()
if verbose:
print(std_out.strip(), std_err)
pass

def downloadScratchRemoverModel():
curDir = os.getcwd()
command_str = "wget https://www.dropbox.com/s/5jencqq4h59fbtb/FT_Epoch_latest.pt" + " -P " + curDir +"/extensions/arifScratchRemoverWebUIExtention/"
runcmd(command_str, verbose=True)


#runcmd("apt-get update && apt-get install libgl1", verbose = True)

144 changes: 144 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/dataset/CamVid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
import glob
import os
from torchvision import transforms
import cv2
from PIL import Image
import pandas as pd
import numpy as np
from imgaug import augmenters as iaa
import imgaug as ia
from utils import get_label_info, one_hot_it, RandomCrop, reverse_one_hot, one_hot_it_v11, one_hot_it_v11_dice
import random

def augmentation():
# augment images with spatial transformation: Flip, Affine, Rotation, etc...
# see https://github.com/aleju/imgaug for more details
pass


def augmentation_pixel():
# augment images with pixel intensity transformation: GaussianBlur, Multiply, etc...
pass

class CamVid(torch.utils.data.Dataset):
def __init__(self, image_path, label_path, csv_path, scale, loss='dice', mode='train'):
super().__init__()
self.mode = mode
self.image_list = []
if not isinstance(image_path, list):
image_path = [image_path]
for image_path_ in image_path:
self.image_list.extend(glob.glob(os.path.join(image_path_, '*.png')))
self.image_list.sort()
self.label_list = []
if not isinstance(label_path, list):
label_path = [label_path]
for label_path_ in label_path:
self.label_list.extend(glob.glob(os.path.join(label_path_, '*.png')))
self.label_list.sort()
# self.image_name = [x.split('/')[-1].split('.')[0] for x in self.image_list]
# self.label_list = [os.path.join(label_path, x + '_L.png') for x in self.image_list]
self.fliplr = iaa.Fliplr(0.5)
self.label_info = get_label_info(csv_path)
# resize
# self.resize_label = transforms.Resize(scale, Image.NEAREST)
# self.resize_img = transforms.Resize(scale, Image.BILINEAR)
# normalization
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# self.crop = transforms.RandomCrop(scale, pad_if_needed=True)
self.image_size = scale
self.scale = [0.5, 1, 1.25, 1.5, 1.75, 2]
self.loss = loss

def __getitem__(self, index):
# load image and crop
seed = random.random()
img = Image.open(self.image_list[index])
# random crop image
# =====================================
# w,h = img.size
# th, tw = self.scale
# i = random.randint(0, h - th)
# j = random.randint(0, w - tw)
# img = F.crop(img, i, j, th, tw)
# =====================================

scale = random.choice(self.scale)
scale = (int(self.image_size[0] * scale), int(self.image_size[1] * scale))

# randomly resize image and random crop
# =====================================
if self.mode == 'train':
img = transforms.Resize(scale, Image.BILINEAR)(img)
img = RandomCrop(self.image_size, seed, pad_if_needed=True)(img)
# =====================================

img = np.array(img)
# load label
label = Image.open(self.label_list[index])


# crop the corresponding label
# =====================================
# label = F.crop(label, i, j, th, tw)
# =====================================

# randomly resize label and random crop
# =====================================
if self.mode == 'train':
label = transforms.Resize(scale, Image.NEAREST)(label)
label = RandomCrop(self.image_size, seed, pad_if_needed=True)(label)
# =====================================

label = np.array(label)


# augment image and label
if self.mode == 'train':
seq_det = self.fliplr.to_deterministic()
img = seq_det.augment_image(img)
label = seq_det.augment_image(label)


# image -> [C, H, W]
img = Image.fromarray(img)
img = self.to_tensor(img).float()

if self.loss == 'dice':
# label -> [num_classes, H, W]
label = one_hot_it_v11_dice(label, self.label_info).astype(np.uint8)

label = np.transpose(label, [2, 0, 1]).astype(np.float32)
# label = label.astype(np.float32)
label = torch.from_numpy(label)

return img, label

elif self.loss == 'crossentropy':
label = one_hot_it_v11(label, self.label_info).astype(np.uint8)
# label = label.astype(np.float32)
label = torch.from_numpy(label).long()

return img, label

def __len__(self):
return len(self.image_list)


if __name__ == '__main__':
# data = CamVid('/path/to/CamVid/train', '/path/to/CamVid/train_labels', '/path/to/CamVid/class_dict.csv', (640, 640))
data = CamVid(['/data/sqy/CamVid/train', '/data/sqy/CamVid/val'],
['/data/sqy/CamVid/train_labels', '/data/sqy/CamVid/val_labels'], '/data/sqy/CamVid/class_dict.csv',
(720, 960), loss='crossentropy', mode='val')
from model.build_BiSeNet import BiSeNet
from utils import reverse_one_hot, get_label_info, colour_code_segmentation, compute_global_accuracy

label_info = get_label_info('/data/sqy/CamVid/class_dict.csv')
for i, (img, label) in enumerate(data):
print(label.size())
print(torch.max(label))

Empty file.
81 changes: 81 additions & 0 deletions extensions/arifScratchRemoverWebUIExtention/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import cv2
import argparse
from model.build_BiSeNet import BiSeNet
import os
import torch
import cv2
from imgaug import augmenters as iaa
from PIL import Image
from torchvision import transforms
import numpy as np
from utils import reverse_one_hot, get_label_info, colour_code_segmentation

def predict_on_image(model, args):
# pre-processing on image
image = cv2.imread(args.data, -1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
resize = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
resize_det = resize.to_deterministic()
image = resize_det.augment_image(image)
image = Image.fromarray(image).convert('RGB')
image = transforms.ToTensor()(image)
image = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(image).unsqueeze(0)
# read csv label path
label_info = get_label_info(args.csv_path)
# predict
model.eval()
predict = model(image).squeeze()
predict = reverse_one_hot(predict)
predict = colour_code_segmentation(np.array(predict), label_info)
predict = cv2.resize(np.uint8(predict), (960, 720))
cv2.imwrite(args.save_path, cv2.cvtColor(np.uint8(predict), cv2.COLOR_RGB2BGR))

def main(params):
# basic parameters
parser = argparse.ArgumentParser()
parser.add_argument('--image', action='store_true', default=False, help='predict on image')
parser.add_argument('--video', action='store_true', default=False, help='predict on video')
parser.add_argument('--checkpoint_path', type=str, default=None, help='The path to the pretrained weights of model')
parser.add_argument('--context_path', type=str, default="resnet101", help='The context path model you are using.')
parser.add_argument('--num_classes', type=int, default=12, help='num of object classes (with void)')
parser.add_argument('--data', type=str, default=None, help='Path to image or video for prediction')
parser.add_argument('--crop_height', type=int, default=720, help='Height of cropped/resized input image to network')
parser.add_argument('--crop_width', type=int, default=960, help='Width of cropped/resized input image to network')
parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
parser.add_argument('--use_gpu', type=bool, default=True, help='Whether to user gpu for training')
parser.add_argument('--csv_path', type=str, default=None, required=True, help='Path to label info csv file')
parser.add_argument('--save_path', type=str, default=None, required=True, help='Path to save predict image')


args = parser.parse_args(params)

# build model
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
model = BiSeNet(args.num_classes, args.context_path)
if torch.cuda.is_available() and args.use_gpu:
model = torch.nn.DataParallel(model).cuda()

# load pretrained model if exists
print('load model from %s ...' % args.checkpoint_path)
model.module.load_state_dict(torch.load(args.checkpoint_path))
print('Done!')

# predict on image
if args.image:
predict_on_image(model, args)

# predict on video
if args.video:
pass

if __name__ == '__main__':
params = [
'--image',
'--data', 'exp.png',
'--checkpoint_path', '/path/to/ckpt',
'--cuda', '0',
'--csv_path', '/data/sqy/CamVid/class_dict.csv',
'--save_path', 'demo.png',
'--context_path', 'resnet18'
]
main(params)
Empty file.
Loading
Loading