Skip to content

Commit

Permalink
Update threaded-main.py
Browse files: browse the repository at this point in the history
  • Loading branch information
TheElevatedOne committed Jun 11, 2024
1 parent 2a93893 commit f236318
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions threaded-main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from tagger.interrogator import Interrogator, WaifuDiffusionInterrogator
from PIL import Image
from pathlib import Path
from send2trash import send2trash
import argparse
import multiprocessing
import os
import os.path as op
import time
import statistics

from tagger.interrogators import interrogators


class Iterrogate():
def __init__(self):
self.start_time = time.time()
self.thread_time = []
self.position = 0
self.args = self.parse()
self.interrogator = interrogators[self.args.model]

if self.args.ext not in [".txt", ".caption"]:
raise ValueError(f'"{self.args.ext}" is not a valid caption file extension')
if self.args.cpu:
self.interrogator.use_cpu()
if self.args.dir is not None:
Expand All @@ -39,6 +47,7 @@ def parse(self):
parser.add_argument(
'--ext',
default='.txt',
choices=[".txt", ".caption"],
help='Extension to add to caption file in case of dir option (default is .txt)')
parser.add_argument(
'--overwrite',
Expand Down Expand Up @@ -74,6 +83,11 @@ def parse(self):
help='Ppecify the number of threads you want to run it with (multithreading)')
return parser.parse_args()

def time_taken(self):
    """Print the mean recorded finish time relative to the start time.

    Reads ``self.thread_time`` (a list of ``time.time()`` stamps) and
    ``self.start_time``. Nothing is printed while no stamp has been
    recorded. The reported value is the arithmetic mean of the stamps minus
    the start stamp, rounded to two decimals.
    """
    finish_stamps = self.thread_time
    if not finish_stamps:
        # No job has reported completion yet, so there is nothing to report.
        return

    mean_finish = statistics.fmean(finish_stamps)
    elapsed = round(mean_finish - self.start_time, 2)
    print(f"Time taken: {elapsed} s")

def image_interrogate(self, image_path: Path):
"""
Predictions from a image path
Expand All @@ -88,11 +102,19 @@ def image_interrogate(self, image_path: Path):

def chunks(self, lst, n) -> list:
    """Split ``lst`` into consecutive slices of ``len(lst) // n`` items each.

    A trailing slice that holds a single item is folded into the slice
    before it, so no job ends up with just one leftover element. Because
    the slice size is floored, the result is not guaranteed to contain
    exactly ``n`` slices.

    Args:
        lst: Sequence to partition (the caller passes a list of file names).
        n: Requested number of slices; must be at least 1.

    Returns:
        A list of slices of ``lst``; empty when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    if not lst:
        return []

    # Floor division gives 0 when n > len(lst); range() rejects a zero step,
    # so fall back to one item per slice.
    size = max(1, len(lst) // n)
    parts = [lst[i:i + size] for i in range(0, len(lst), size)]

    # Merge a lone trailing item into its neighbour instead of giving it a
    # slice of its own.
    if len(parts) > 1 and len(parts[-1]) == 1:
        parts[-2:] = [parts[-2] + parts[-1]]
    return parts

def dir_thread_main(self):
d = op.abspath(self.args.dir)
if self.args.overwrite is not None:
[send2trash(op.join(d, x)) for x in os.listdir(d) if (".txt" in x) or (".caption" in x)]
q = self.chunks(os.listdir(d), int(self.args.threads))
self.position = len(q)
jobs = []

for i, s in enumerate(q):
Expand All @@ -118,18 +140,23 @@ def directory_iter(self, job, chunk):

caption_path = op.join(d, f.replace(f'.{f.split(".")[-1]}', "")+self.args.ext)

if op.isfile(caption_path) and not self.args.overwrite:
if op.isfile(caption_path):
# skip if file exists
print('skip:', image_path)
continue

print(f'processing: {image_path} | {"0"*(len(str(len(chunk)))-len(str(chunk.index(f)))+1)+str(chunk.index(f))}/{len(chunk)} | Job: {job}')
print(f'processing: {image_path} | {"0"*(len(str(len(chunk)))-len(str(chunk.index(f)+1)))+str(chunk.index(f)+1)}/{len(chunk)} | Job: {job}')
tags = self.image_interrogate(Path(image_path))

tags_str = self.additional_tags(self.args.prepend, True)+(', '.join(tags.keys()))+self.additional_tags(self.args.append, False)

with open(caption_path, 'w') as fp:
fp.write(tags_str)
if f == chunk[-1]:
time.sleep(job/4)
self.thread_time.append(time.time())
if self.position == (job+1):
self.time_taken()


def file_iter(self):
tags = self.image_interrogate(Path(self.args.file))
Expand All @@ -140,3 +167,4 @@ def file_iter(self):

# Script entry point: only runs when the file is executed directly, not when
# it is imported. Constructing Iterrogate runs its __init__, which parses the
# command-line arguments and selects the interrogator model.
if __name__ == "__main__":
Iterrogate()

0 comments on commit f236318

Please sign in to comment.