-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconcept_guided_purification.py
75 lines (57 loc) · 2.68 KB
/
concept_guided_purification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Concept-guided purification (CosalPure) of perturbed co-saliency datasets.

For every (dataset, img2img strength, distortion) combination configured
below, each image group (a sub-folder of the perturbed dataset) is
re-generated with the Stable Diffusion img2img pipeline stored in that
group's learned-concept model directory. The prompt is the group's
placeholder token ``my_<group>``. Results are written to
``<dataset_root>_CosalPure/<group>/<image file name>``.
"""
import os

# Restrict the process to GPU 1. This is set before torch/diffusers are
# imported so the mask is in place before anything can initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import argparse
import glob
import itertools
import math
import random
from io import BytesIO

import numpy as np
import requests
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image
from torch.utils.data import Dataset
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import DPMSolverMultistepScheduler

# NOTE(review): many of the imports above are unused by this script; they are
# kept so the file's import surface is unchanged.

# Global seed. No per-call generator is passed to the pipeline, so sampling
# presumably draws from this global RNG stream (confirm against diffusers).
torch.manual_seed(39)

DATA_NAMES = ["Cosal2015", "iCoseg", "CoCA", "CoSOD3k"]
STRENGTHS = [0.25]          # img2img `strength` values to sweep
DISTORTIONS = ["jadena"]    # perturbation names; inputs live in img_<distortion>
GUIDANCE_SCALE = 7.5        # `guidance_scale` passed to the pipeline
USE_LIIF = ""               # suffix appended to the input image folder name


def _list_sorted(path):
    """Return the entry names of *path* in sorted order.

    ``os.listdir`` order is filesystem-dependent. Sorting makes the processing
    order reproducible across machines, and with it the part of the seeded
    RNG stream that each image consumes.
    """
    return sorted(os.listdir(path))


def _load_pipeline(model_root):
    """Load the fp16 img2img pipeline and its scheduler from *model_root*, moved to CUDA."""
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_root,
        scheduler=DPMSolverMultistepScheduler.from_pretrained(model_root, subfolder="scheduler"),
        torch_dtype=torch.float16,
    )
    return pipe.to("cuda")


def _purify_folder(pipe, placeholder_token, src_folder, dst_folder, strength, guidance_scale):
    """Re-generate every file in *src_folder* with *pipe*.

    Each result is saved under the same file name in *dst_folder*, which must
    already exist.
    """
    for img_name in _list_sorted(src_folder):
        src_path = os.path.join(src_folder, img_name)
        if not os.path.isfile(src_path):
            continue  # skip stray sub-directories
        # The context manager closes the source file handle promptly.
        with Image.open(src_path) as raw:
            image_inp = raw.convert("RGB")
        purified = pipe(
            prompt=placeholder_token,
            image=image_inp,
            strength=strength,
            guidance_scale=guidance_scale,
        ).images[0]
        purified.save(os.path.join(dst_folder, img_name))


def main():
    """Run purification over every configured dataset/strength/distortion."""
    # Same iteration order as nested loops: data_name outermost, distortion innermost.
    for data_name, strength, distortion in itertools.product(DATA_NAMES, STRENGTHS, DISTORTIONS):
        version = "{}_224_mixdataset_ratio0.5_768".format(data_name)
        dataset_root = "dataset/{}/img_{}{}".format(version, distortion, USE_LIIF)
        learned_concept_root = "textual_models/{}_object/img_{}".format(version, distortion)
        # NOTE(review): the output root encodes neither strength nor guidance
        # scale, so sweeping several strengths overwrites earlier results.
        # The path is left as-is because downstream consumers read from it.
        target_concept_root = "{}_CosalPure".format(dataset_root)

        for folder_name in _list_sorted(dataset_root):
            dataset_folder = os.path.join(dataset_root, folder_name)
            if not os.path.isdir(dataset_folder):
                continue  # skip stray files next to the group folders
            target_concept_path = os.path.join(target_concept_root, folder_name)
            os.makedirs(target_concept_path, exist_ok=True)

            pipe = _load_pipeline("{}/{}/model".format(learned_concept_root, folder_name))
            _purify_folder(
                pipe,
                "my_{}".format(folder_name),
                dataset_folder,
                target_concept_path,
                strength,
                GUIDANCE_SCALE,
            )
            # Drop the reference before the next folder's pipeline is loaded,
            # so two fp16 pipelines are never resident on the GPU at once.
            del pipe


if __name__ == "__main__":
    main()