diff --git a/.gitignore b/.gitignore index 8115b62..6b071ed 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,5 @@ yarn-error.log* manga-translator-service-account-data.json exa/* local/* -trans-test/* \ No newline at end of file +trans-test/* +*.env \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 74438b0..ce2bc38 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,57 +1,108 @@ # Build ui using node -FROM node:18.12.1 +# FROM node:18.12.1 -COPY package.json . -COPY package-lock.json . -COPY public public -COPY src src -COPY tsconfig.json . -COPY .eslintrc.json . +# WORKDIR /app -RUN npm install +# COPY ui ui -RUN npm run build +# WORKDIR /app/ui + +# RUN npm install + +# RUN npm run build # Use the NVIDIA CUDA base image FROM nvidia/cuda:11.7.1-runtime-ubuntu20.04 -COPY --from=0 build build +WORKDIR /app + +# COPY --from=0 /app/ui/build /ui/build # Set the working directory to /app #WORKDIR /app # Update package lists and install required packages -RUN DEBIAN_FRONTEND=noninteractive apt-get update && \ - apt-get install -y python3.9 python3-pip -# # Make sure Python 3.9 is the default python3 version -# RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 +# RUN apt-get update +# RUN apt-get remove python +# RUN apt-get remove python-pip +# RUN apt-get -y install software-properties-common +# RUN add-apt-repository ppa:deadsnakes/ppa +# RUN apt-get -y install python3.9 +# RUN apt-get -y install python3-pip -# Create a symbolic link for pip (optional) -RUN ln -s /usr/bin/pip3 /usr/bin/pip +COPY translator translator +COPY server.py . +COPY fonts fonts +COPY models models +COPY requirements.txt . + +# # Install base utilities +# RUN apt-get update \ +# && apt-get install -y build-essential \ +# && apt-get install -y wget \ +# && apt-get clean \ +# && rm -rf /var/lib/apt/lists/* + +# # Install miniconda +# ENV CONDA_DIR /opt/conda +# RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ +# /bin/bash ~/miniconda.sh -b -p /opt/conda -# Verify Python and pip versions -RUN python3 --version && pip --version +# # Put conda in path so we can use conda activate +# ENV PATH=$CONDA_DIR/bin:$PATH # # Create symbolic links to set Python 3.9 as the default # RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && \ # update-alternatives --config python3 # RUN apt-get update && apt-get install -y python3.9 python3.9-dev +# RUN conda create -n translator python=3.9 -y -COPY translator translator -COPY server.py . -COPY fonts fonts -COPY models models -COPY requirements.txt . +# SHELL ["conda", "run", "-n", "translator", "/bin/bash", "-c"] + +# RUN pip install -r requirements.txt +# RUN pip uninstall -y torch torchvision torchaudio +# RUN pip install opencv-python +# RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 + +# RUN apt-get update && apt-get install -y --no-install-recommends \ +# software-properties-common \ +# libsm6 libxext6 ffmpeg libfontconfig1 libxrender1 libgl1-mesa-glx \ +# curl python3-pip + +# RUN pip3 install --upgrade pip + +# RUN pip3 install -r requirements.txt +# RUN pip3 uninstall -y torch torchvision torchaudio +# RUN pip3 install opencv-python +# RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + software-properties-common \ + libsm6 libxext6 ffmpeg libfontconfig1 libxrender1 libgl1-mesa-glx \ + curl python3-pip + +RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \ + && sh ~/miniconda.sh -b -p /opt/conda \ + && rm ~/miniconda.sh + +ENV PATH /opt/conda/bin:$PATH + +RUN conda update -n base -c defaults conda + +COPY conda.yml conda.yml + +RUN conda env create -f conda.yml --name translator +RUN conda activate translator -RUN pip install -r requirements.txt +# RUN python3.9 -m pip install -r requirements.txt -RUN pip uninstall -y torch torchvision torchaudio +# RUN python3.9 -m -RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 +# RUN python3.9 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 -CMD ["python","server.py"] +CMD ["python3","server.py"] diff --git a/extractor.py b/extractor.py index 10d47f9..e1d17da 100644 --- a/extractor.py +++ b/extractor.py @@ -4,11 +4,13 @@ from typing import Any import os import re +import sys +from importlib.resources import is_resource from typing import Union import pkg_resources from functools import cmp_to_key INSTALLED_PACKAGES = {pkg.key for pkg in pkg_resources.working_set} -base_dir = "./" +BASE_DIR = os.getcwd() # This file is an attempt to extract a given function or name from a module with all its depencencies. it has trouble with relative imports and needs more work @@ -213,35 +215,12 @@ def get_other_refs(tree: ast.AST,exclude: list[str]) -> dict[str,ast.AST]: deps[node_name] = node return deps - -def try_build_module_path(current: str,remaining: list[str],max_lookaheads = 1) -> Union[str,None]: - - if len(remaining) == 0: - if os.path.exists(current) and os.path.isfile(current): - return current - return None - - new_path = os.path.join(current,remaining[0]) - - if os.path.exists(new_path): - return try_build_module_path(new_path,remaining[1:]) - else: - for i in range(max_lookaheads): - delta = i + 1 - if len(remaining) <= delta: - break - - test_path = os.path.join(current,remaining[delta]) - if os.path.exists(delta): - return try_build_module_path(test_path,remaining[delta:]) - - return None paths_cache: dict[str,str] = {} def module_to_file_path(cur_file_path: str,module: str) -> str: - global base_dir + global BASE_DIR global paths_cache global INSTALLED_PACKAGES @@ -249,7 +228,7 @@ def module_to_file_path(cur_file_path: str,module: str) -> str: if cache_key in paths_cache: return paths_cache[cache_key] - start_path = base_dir + start_path = BASE_DIR if module.split('.')[0] in INSTALLED_PACKAGES: return None @@ -443,17 +422,11 @@ def comp_a_b(a:str,b: str): print(key) file_parts.append(content) import_parts.update(col_info.imports) - - return list(import_parts),file_parts - + + + return [f"# This file was created using extractor.py\n"] + list(import_parts) + [""],file_parts - # # imported_refs = filter(lambda a: a in parsed_imports.keys(),map(lambda a: parsed_imports_aliases.get(a,a),related_refs)) -# print(extract_from_file(filename="module_test/a.py",names=["A"])) -# print(extract_from_file(filename="translator/core/pipelines.py",names=["FullConversion"])) -# "D:\\Github\\manga-translator\\deepfillv2-pytorch\\test.py" -# with open("translator/core/pipelines.py",'r') as f: -# base_dir = './lama' with open("out.py",'w',encoding='utf8') as out_file: - file_imports,file_content = extract_from_file(filename="D:\\Github\\manga-translator\\extractor.py",names=["extract_from_file"]) + file_imports,file_content = extract_from_file(filename="D:\\Github\\manga-translator\\d.py",names=["non_max_suppression","scale_boxes","process_mask"]) out_file.write("\n".join(file_imports + file_content)) \ No newline at end of file diff --git a/out.py b/out.py deleted file mode 100644 index e53d281..0000000 --- a/out.py +++ /dev/null @@ -1,391 +0,0 @@ -from functools import cmp_to_key -import os -from typing import Any -import pkg_resources -from ast import AnnAssign, Attribute, Call, ClassDef, Expr, FunctionDef, Name, arguments -import ast -from typing import Union -import re -INSTALLED_PACKAGES = {pkg.key for pkg in pkg_resources.working_set} -base_dir = "./" -class ReferencesVisitor(ast.NodeVisitor): - def __init__(self,deps_for: str = "",relevant_refs: list[str] = [],checked: set[str] = set()): - self.deps_for = deps_for - self.relevant_refs = relevant_refs - self.checked = checked - - - def visit_ClassDef(self, node: ClassDef) -> Any: - self.relevant_refs.append(node.name) - for x in node.bases: - self.relevant_refs.append(get_nested_name(x)) - - return self.generic_visit(node) - - def visit_FunctionDef(self, node: FunctionDef) -> Any: - self.relevant_refs.append(node.name) - return self.generic_visit(node) - - def visit_AnnAssign(self, node: AnnAssign) -> Any: - if isinstance(node.annotation,ast.Name): - self.relevant_refs.append(node.annotation.id) - elif isinstance(node.annotation,ast.Attribute): - item_name = get_nested_name(node.annotation) - if item_name is not None and item_name != 'self': - self.relevant_refs.append(item_name) - - return self.generic_visit(node) - - def visit_Name(self, node: Name) -> Any: - self.relevant_refs.append(get_nested_name(node)) - - return self.generic_visit(node) - - def visit_Call(self, node: Call) -> Any: - item_name = get_call_name(node) - if item_name is not None and item_name != 'self': - self.relevant_refs.append(item_name) - return self.generic_visit(node) - - def visit_Attribute(self, node: Attribute) -> Any: - item_name = get_nested_name(node) - if item_name is not None and item_name != 'self': - self.relevant_refs.append(item_name) - return self.generic_visit(node) - - - def visit_Expr(self, node: Expr) -> Any: - item_name = get_nested_name(node.value) - if item_name is not None and item_name != 'self': - self.relevant_refs.append(item_name) - return self.generic_visit(node) -class ImportsVisitor(ast.NodeVisitor): - relativity_regex = r"(?:from|import)(?:[\s]+)?(\.\.|\.|)?([\w\d_\.]+).*" - def __init__(self,file: str): - self.file = file - self.imports_parsed = {} - - def visit_Import(self, node): - - import_str = ast.get_source_segment(self.file,node) - - match = re.search(ImportsVisitor.relativity_regex,import_str) - - if match is None: - self.generic_visit(node) - return - - relativity,module_name = match[1],match[2] - module = f"{relativity}{module_name}" - - for alias in node.names: - # print(alias.__dict__) - self.imports_parsed[alias.name] = (module,import_str,alias.name) - if alias.asname is not None: - self.imports_parsed[alias.asname] = (module,import_str,alias.name) - - self.generic_visit(node) - - def visit_ImportFrom(self, node): - import_str = ast.get_source_segment(self.file,node) - - match = re.search(ImportsVisitor.relativity_regex,import_str) - - if match is None: - self.generic_visit(node) - return - - relativity,module_name = match[1],match[2] - module = f"{relativity}{module_name}" - - for alias in node.names: - self.imports_parsed[alias.name] = (module,import_str,alias.name) - if alias.asname is not None: - self.imports_parsed[alias.asname] = (module,import_str,alias.name) - - self.generic_visit(node) -def is_class(node: ast.AST): - return isinstance(node,ast.ClassDef) -def is_function(node: ast.AST): - return isinstance(node,ast.FunctionDef) or isinstance(node,ast.AsyncFunctionDef) -def is_call(node: ast.AST): - return isinstance(node,ast.Call) -def get_nested_name(start: Union[ast.Attribute,ast.Name,ast.Call]): - if isinstance(start,ast.Name): - return start.id - elif isinstance(start,ast.Attribute): - return get_nested_name(start.value) - elif isinstance(start,ast.Call): - get_nested_name(start.func) - elif isinstance(start,ast.BinOp): - return None - else: - # print("UNKNOWN NAME",start,start.__dict__) - pass - - - return None -def get_call_name(node: ast.Call) -> str: - return get_nested_name(node.func) -def get_class_name(node: ast.ClassDef): - return node.name -def get_function_name(node: ast.FunctionDef): - return node.name -def get_name_for(node: ast.AST): - if is_class(node): - return get_class_name(node) - - if is_function(node): - return get_function_name(node) - - if is_call(node): - return get_call_name(node) - - if isinstance(node,ast.Assign): - return get_nested_name(node.targets[0]) - - if isinstance(node,ast.Attribute): - return get_nested_name(node) - - # print(node,node.__dict__) - return None -def parse_file_imports(file: str,node: ast.AST) -> dict[str,str]: - vis = ImportsVisitor(file=file) - vis.visit(node) - return vis.imports_parsed -def get_target_refs(tree: ast.AST,names: list[str]) -> dict[str,ast.AST]: - deps = {} - for node in tree.body: - # if isinstance(node,ast.Assign): - # node_name = get_name_for(node.value) - - # if node_name is not None and node_name in names: - - # deps[node_name] = node - node_name = get_name_for(node) - # if (isinstance(node,ast.ClassDef) or isinstance(node,ast.FunctionDef) or isinstance(node,ast.AsyncFunctionDef)) and node.name in names: - # deps[node.name] = node - - if node_name is not None and node_name in names: - deps[node_name] = node - - return deps -def get_other_refs(tree: ast.AST,exclude: list[str]) -> dict[str,ast.AST]: - deps = {} - for node in tree.body: - if isinstance(node,ast.Import) or isinstance(node,ast.ImportFrom): - continue - - if isinstance(node,ast.Constant): - print(node.__dict__) - # node_name = get_name_for(node.value) - - # if node_name is not None and node_name not in exclude: - - # deps[node_name] = node - - node_name = get_name_for(node) - # if (isinstance(node,ast.ClassDef) or isinstance(node,ast.FunctionDef) or isinstance(node,ast.AsyncFunctionDef)) and node.name not in exclude: - # deps[node.name] = node - - if node_name is not None and node_name not in exclude: - deps[node_name] = node - - return deps -def module_to_file_path(cur_file_path: str,module: str) -> str: - global base_dir - global paths_cache - global INSTALLED_PACKAGES - - cache_key = module - if cache_key in paths_cache: - return paths_cache[cache_key] - - start_path = base_dir - - if module.split('.')[0] in INSTALLED_PACKAGES: - return None - - - - if module.startswith('..'): - start_path = os.path.join(os.path.dirname(cur_file_path),'..') - module = module[2:] - elif module.startswith('.'): - start_path = os.path.join(os.path.dirname(cur_file_path),'.') - module = module[1:] - - - parts = module.split(".") - - search_path = os.path.abspath(os.path.join(start_path,f"{os.path.sep}".join(parts))) - - #print(search_path + ".py") - if os.path.exists(search_path + ".py"): - - paths_cache[cache_key] = search_path + ".py" - return paths_cache[cache_key] - - #print(os.path.join(search_path,"__init__.py")) - if os.path.exists(os.path.join(search_path,"__init__.py")): - paths_cache[cache_key] = os.path.join(search_path,"__init__.py") - return paths_cache[cache_key] - - #print("I give up",module) - paths_cache[cache_key] = None - return None -def get_refs_for_root_item(item: ast.AST,other_root_refs: dict[str,ast.AST],checked: set[str] = set(),other_root_checked: set[str] = set()) -> list[str]: - # if get_name_for(item) in checked: - # return [] - - my_refs = [] - visitor = ReferencesVisitor(checked=checked,relevant_refs=my_refs) - visitor.visit(item) - checked.add(get_name_for(item)) - - for ref in my_refs: - if ref in other_root_refs.keys() and ref not in checked: - my_refs.extend(get_refs_for_root_item(other_root_refs[ref],other_root_refs=other_root_refs,checked=checked,other_root_checked=other_root_checked)) - return my_refs -class ExtractedRawInfo: - def __init__(self,filename: str,imports: list[str],file_depencencies: dict[str, set[str]],content: list[tuple[str,str,int]]) -> None: - self.filename = filename - self.imports = imports - self.file_dependencies = file_depencencies - self.content = content - - def merge(self,other: 'ExtractedRawInfo'): - self.imports = set(self.imports) - self.imports.update(other.imports) - self.imports = list(self.imports) - - self.file_dependencies.update(other.file_dependencies) - - self.content = set(self.content) - self.content.update(other.content) - self.content = list(self.content) -def extract_raw_info_from_file(filename:str,names: list[str],files_checked_for_refs: set[str] = set()) -> list[ExtractedRawInfo]: - for name in names.copy(): - if f"{filename}=>{name}" in files_checked_for_refs: - names.remove(name) - - for name in names: - files_checked_for_refs.add(f"{filename}=>{name}") - - - # print("CHECKING FILE",filename,names) - with open(filename,'r',encoding="utf8") as f: - file = f.read() - - refs_checked_in_this_file: set[str] = set() - - refs_extracted: set[tuple[str,str,int]] = set() - - tree = ast.parse(file) - - parsed_imports = parse_file_imports(file,tree) - - root_refs = get_target_refs(tree=tree,names=names) - - other_refs = get_other_refs(tree=tree,exclude=names) - - imported_refs = set(parsed_imports.keys()) - - - relevant_refs = [] - others_checked = set() - for root_ref in root_refs.values(): - - relevant_refs.extend(get_refs_for_root_item(root_ref,other_root_refs=other_refs,checked=refs_checked_in_this_file,other_root_checked=others_checked)) - refs_extracted.add((ast.get_source_segment(file,root_ref),root_ref.lineno)) - - relevant_refs = set(filter(lambda a: a is not None,relevant_refs)) - - - imported_refs_from_local_files = set(filter(lambda a: a in relevant_refs and module_to_file_path(filename,parsed_imports[a][0]) is not None,imported_refs)) - - imported_refs_installed = set(filter(lambda a: a in relevant_refs and module_to_file_path(filename,parsed_imports[a][0]) is None,imported_refs)) - - for ref in relevant_refs: - if ref in other_refs.keys(): - ref_node = other_refs.get(ref,None) - if ref_node is not None: - refs_extracted.add((ast.get_source_segment(file,ref_node),ref_node.lineno)) - - - relevant_refs.difference_update(root_refs.keys()) - - relevant_refs_from_file_imports = relevant_refs.intersection(imported_refs_from_local_files) - - files_to_get_refs_from: dict[str,set[str]] = {} - - - for ref in relevant_refs_from_file_imports: - if parsed_imports[ref][2] != ref: - refs_extracted.add((f"{ref} = {parsed_imports[ref][2]}",-1)) - module_path = module_to_file_path(filename,parsed_imports[ref][0]) - if module_path not in files_to_get_refs_from.keys(): - files_to_get_refs_from[module_path] = set() - - files_to_get_refs_from[module_path].add(parsed_imports[ref][2]) - - - all_extracted = [ExtractedRawInfo(filename=filename,imports=list(map(lambda a: parsed_imports.get(a)[1],imported_refs_installed)),file_depencencies=files_to_get_refs_from,content=list(refs_extracted))] - - for module_path in files_to_get_refs_from.keys(): - print("Extracting",files_to_get_refs_from[module_path],"From",module_path) - all_extracted.extend(extract_raw_info_from_file(module_path,names=list(files_to_get_refs_from[module_path]),files_checked_for_refs=files_checked_for_refs)) - - return all_extracted -def extract_from_file(filename:str,names: list[str]): - filename = os.path.abspath(filename) - extracted = extract_raw_info_from_file(filename=filename,names=names) - collated_extracted: dict[str,ExtractedRawInfo] = {} - - for info in extracted: - if info.filename not in collated_extracted.keys(): - collated_extracted[info.filename] = info - else: - collated_extracted[info.filename].merge(info) - - def comp_a_b(a:str,b: str): - a_data = collated_extracted[a] - b_data = collated_extracted[b] - - if a == filename: - return -1 - - if b == filename: - return 1 - - - if a in b_data.file_dependencies.keys(): - return 1 - - if b in a_data.file_dependencies.keys(): - return -1 - - - return len(a_data.file_dependencies.keys()) - len(b_data.file_dependencies.keys()) - collated_keys = list(collated_extracted.keys()) - - for x in collated_keys: - collated_keys.sort(key=cmp_to_key(comp_a_b)) - - for x in collated_extracted: - col_info = collated_extracted[x] - col_info.content.sort(key=lambda a: a[1]) - # print(collated_extracted[x][2]) - - import_parts: set() = set() - file_parts = [] - - for key in reversed(collated_keys): - col_info = collated_extracted[key] - #print(key,col_content) - for content,line_no in col_info.content: - print(key) - file_parts.append(content) - import_parts.update(col_info.imports) - - return list(import_parts),file_parts \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 65ef632..5aacd85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,14 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.9" +torch = { version = "2.0.1", source="torch"} +torchvision = { version = "0.15.2", source="torch"} + +[[tool.poetry.source]] +name = "torch" +url = "https://download.pytorch.org/whl/cu117" +secondary = true + # ultralytics = "^8.0.118" # pillow = "^9.5.0" # pyhyphen = "^4.0.3" diff --git a/requirements.txt b/requirements.txt index 0287351..e30214d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,23 @@ -ultralytics == 8.0.118 -pillow == 9.5.0 -pyhyphen == 4.0.3 -google-cloud-translate == 3.11.1 -mss == 9.0.1 -manga-ocr == 0.1.10 -largestinteriorrectangle == 0.2.0 -sentencepiece == 0.1.99 -tornado == 6.3.2 -easyocr == 1.7.0 -pytesseract == 0.3.10 -pycountry == 22.3.5 -opencv-python == 4.8.0.74 -pysimplegui == 4.60.5 -timm == 0.9.2 -Faker == 19.1.0 -roboflow == 1.1.3 -simple-lama-inpainting == 0.1.0 \ No newline at end of file +easyocr==1.7.0 +Faker==19.1.0 +jaconv==0.3.4 +largestinteriorrectangle==0.2.0 +mss==9.0.1 +numpy==1.24.3 +opencv_python==4.8.0.74 +opencv_python_headless==4.8.0.74 +Pillow==9.5.0 +Pillow==10.0.1 +pycountry==22.3.5 +PyHyphen==4.0.3 +PySimpleGUI==4.60.5 +pytesseract==0.3.10 +Requests==2.31.0 +setuptools==67.8.0 +simple_lama_inpainting==0.1.0 +timm==0.9.2 +torch==2.0.1+cu117 +torchvision==0.15.2+cu117 +tornado==6.3.2 +tqdm==4.65.0 +ultralytics==8.0.118 diff --git a/server.py b/server.py index 88b97cb..9f460a0 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,7 @@ +from typing import Union +from dotenv import load_dotenv + +load_dotenv() import os import io import urllib.parse @@ -6,14 +10,13 @@ import numpy as np import asyncio from tornado.web import RequestHandler, Application -from threading import Thread -from translator.utils import cv2_to_pil, pil_to_cv2 +from translator.utils import cv2_to_pil, pil_to_cv2, run_in_thread_decorator from translator.pipelines import FullConversion from translator.translators.get import get_translators from translator.translators.deepl import DeepLTranslator from translator.ocr.get import get_ocr from translator.ocr.clean import CleanOcr -from translator.ocr.manga import MangaOcr +from translator.ocr.huggingface_ja import JapaneseOcr from translator.drawers.get import get_drawers from translator.cleaners.get import get_cleaners from PIL import Image @@ -24,40 +27,8 @@ import os -def run_in_thread(func): - async def wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - pending_task = asyncio.Future() - - def run_task(): - nonlocal loop - loop.call_soon_threadsafe(pending_task.set_result, func(*args, **kwargs)) - - Thread(target=run_task, group=None, daemon=True).start() - result = await pending_task - return result - - return wrapper - -def run_in_thread(func): - async def wrapper(*args, **kwargs): - loop = asyncio.get_event_loop() - pending_task = asyncio.Future() - - def run_task(): - nonlocal loop - loop.call_soon_threadsafe(pending_task.set_result, func(*args, **kwargs)) - - Thread(target=run_task, group=None, daemon=True).start() - result = await pending_task - return result - - return wrapper - - - def cv2_image_from_url(url: str): - if url.startswith('http'): + if url.startswith("http"): return pil_to_cv2(Image.open(io.BytesIO(requests.get(url).content))) else: sanitized = urllib.parse.unquote(url.split("?")[0]) @@ -77,7 +48,9 @@ def extract_params(data: str) -> tuple[int, dict[str, str]]: params = {} if len(params_to_parse.strip()) > 0: - for param_name, param_value in re.findall(REQUEST_SECTION_PARAMS_REGEX, params_to_parse.strip()): + for param_name, param_value in re.findall( + REQUEST_SECTION_PARAMS_REGEX, params_to_parse.strip() + ): if len(param_value.strip()) > 0: params[param_name] = param_value @@ -85,7 +58,7 @@ def extract_params(data: str) -> tuple[int, dict[str, str]]: def send_file_in_chunks(request: RequestHandler, file_path): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: while True: data = f.read(16384) # or some other nice-sized chunk if not data: @@ -97,26 +70,28 @@ class CleanFromWebHandler(RequestHandler): def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") self.set_header("Access-Control-Allow-Headers", "*") - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header("Content-Type", 'image/png') + self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + self.set_header("Content-Type", "image/png") def options(self): self.set_status(200) async def post(self): try: - - image = self.request.files.get('file') + image = self.request.files.get("file") if image is None: raise BaseException("No Image Sent") - data = json.loads(self.get_argument('data')) + data = json.loads(self.get_argument("data")) - - cleaner_id, cleaner_params = data.get('cleaner', 0), data.get('cleanerArgs', {}) - image_cv2 = pil_to_cv2(Image.open(io.BytesIO(image[0]['body']))) - converter = FullConversion(ocr=CleanOcr(),cleaner=get_cleaners()[cleaner_id](**cleaner_params)) + cleaner_id, cleaner_params = data.get("cleaner", 0), data.get( + "cleanerArgs", {} + ) + image_cv2 = pil_to_cv2(Image.open(io.BytesIO(image[0]["body"]))) + converter = FullConversion( + ocr=CleanOcr(), cleaner=get_cleaners()[cleaner_id](**cleaner_params) + ) results = await converter([image_cv2]) converted_pil = cv2_to_pil(results[0]) img_byte_arr = io.BytesIO() @@ -124,48 +99,52 @@ async def post(self): # Create response given the bytes self.write(img_byte_arr.getvalue()) except: - self.set_header("Content-Type", 'text/html') + self.set_header("Content-Type", "text/html") self.set_status(500) traceback.print_exc() self.write(traceback.format_exc()) - - - class TranslateFromWebHandler(RequestHandler): - def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") self.set_header("Access-Control-Allow-Headers", "*") - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header("Content-Type", 'image/png') + self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + self.set_header("Content-Type", "image/png") def options(self): self.set_status(200) async def post(self): try: - - image = self.request.files.get('file') + image = self.request.files.get("file") if image is None: raise BaseException("No Image Sent") - data = json.loads(self.get_argument('data')) + data = json.loads(self.get_argument("data")) - translator_id, translator_params = data.get('translator', 0), data.get('translatorArgs', {}) + translator_id, translator_params = data.get("translator", 0), data.get( + "translatorArgs", {} + ) - ocr_id, ocr_params = data.get('ocr', 0), data.get('ocrArgs', {}) + ocr_id, ocr_params = data.get("ocr", 0), data.get("ocrArgs", {}) - drawer_id, drawer_params = data.get('drawer', 0), data.get('drawerArgs', {}) + drawer_id, drawer_params = data.get("drawer", 0), data.get("drawerArgs", {}) - cleaner_id, cleaner_params = data.get('cleaner', 0), data.get('cleanerArgs', {}) + cleaner_id, cleaner_params = data.get("cleaner", 0), data.get( + "cleanerArgs", {} + ) - image_cv2 = pil_to_cv2(Image.open(io.BytesIO(image[0]['body']))) + image_cv2 = pil_to_cv2(Image.open(io.BytesIO(image[0]["body"]))) - converter = FullConversion(translator=get_translators()[translator_id](**translator_params), - ocr=get_ocr()[ocr_id](**ocr_params), drawer=get_drawers()[drawer_id](**drawer_params),cleaner=get_cleaners()[cleaner_id](**cleaner_params),color_detect_model=None) + converter = FullConversion( + translator=get_translators()[translator_id](**translator_params), + ocr=get_ocr()[ocr_id](**ocr_params), + drawer=get_drawers()[drawer_id](**drawer_params), + cleaner=get_cleaners()[cleaner_id](**cleaner_params), + color_detect_model=None, + ) results = await converter([image_cv2]) @@ -175,21 +154,20 @@ async def post(self): # Create response given the bytes self.write(img_byte_arr.getvalue()) except: - self.set_header("Content-Type", 'text/html') + self.set_header("Content-Type", "text/html") self.set_status(500) self.write(traceback.format_exc()) traceback.print_exc() class ImageHandler(RequestHandler): - def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") self.set_header("Access-Control-Allow-Headers", "*") - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header("Content-Type", 'image/*') + self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + self.set_header("Content-Type", "image/*") - @run_in_thread + @run_in_thread_decorator def get(self): try: full_url = self.request.full_url() @@ -205,19 +183,18 @@ def get(self): else: send_file_in_chunks(self, item_path) except: - self.set_header("Content-Type", 'text/html') + self.set_header("Content-Type", "text/html") self.set_status(500) self.write(traceback.format_exc()) traceback.print_exc() class BaseHandler(RequestHandler): - def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") self.set_header("Access-Control-Allow-Headers", "*") - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header("Content-Type", 'application/json') + self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + self.set_header("Content-Type", "application/json") def get(self): try: @@ -226,73 +203,90 @@ def get(self): translators = get_translators() for x in range(len(translators)): - data["translators"].append({ - "id": x, - "name": translators[x].get_name(), - "description": translators[x].__doc__, - "args": [x.get() for x in translators[x].get_arguments()] - }) + data["translators"].append( + { + "id": x, + "name": translators[x].get_name(), + "description": translators[x].__doc__, + "args": [x.get() for x in translators[x].get_arguments()], + } + ) ocr = get_ocr() for x in range(len(ocr)): - data["ocr"].append({ - "id": x, - "name": ocr[x].get_name(), - "description": ocr[x].__doc__, - "args": [x.get() for x in ocr[x].get_arguments()] - }) + data["ocr"].append( + { + "id": x, + "name": ocr[x].get_name(), + "description": ocr[x].__doc__, + "args": [x.get() for x in ocr[x].get_arguments()], + } + ) drawers = get_drawers() for x in range(len(drawers)): - data["drawers"].append({ - "id": x, - "name": drawers[x].get_name(), - "description": drawers[x].__doc__, - "args": [x.get() for x in drawers[x].get_arguments()] - }) + data["drawers"].append( + { + "id": x, + "name": drawers[x].get_name(), + "description": drawers[x].__doc__, + "args": [x.get() for x in drawers[x].get_arguments()], + } + ) cleaners = get_cleaners() for x in range(len(cleaners)): - data["cleaners"].append({ - "id": x, - "name": cleaners[x].get_name(), - "description": cleaners[x].__doc__, - "args": [x.get() for x in cleaners[x].get_arguments()] - }) + data["cleaners"].append( + { + "id": x, + "name": cleaners[x].get_name(), + "description": cleaners[x].__doc__, + "args": [x.get() for x in cleaners[x].get_arguments()], + } + ) self.write(json.dumps(data)) except: - self.set_header("Content-Type", 'text/html') + self.set_header("Content-Type", "text/html") self.set_status(500) self.write(traceback.format_exc()) traceback.print_exc() class MiraTranslateWebHandler(RequestHandler): + converter: Union[FullConversion, None] = None + def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") self.set_header("Access-Control-Allow-Headers", "*") - self.set_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS') - self.set_header("Content-Type", 'image/png') + self.set_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") + self.set_header("Content-Type", "image/png") def options(self): self.set_status(200) + @run_in_thread_decorator async def post(self): try: - - image = self.request.files.get('file') + image = self.request.files.get("file") if image is None: raise BaseException("No Image Sent") - to_convert = pil_to_cv2(Image.open(io.BytesIO(image[0]['body']))) + if MiraTranslateWebHandler.converter is None: + MiraTranslateWebHandler.converter = FullConversion( + color_detect_model=None, + # translator=OpenAiTranslator(api_key=os.getenv("GPT_AUTH")), + translator=DeepLTranslator(auth_token=os.getenv("DEEPL_AUTH")), + ocr=JapaneseOcr(), + translate_free_text=True, + ) - converter = FullConversion(color_detect_model=None,translator=DeepLTranslator(auth_token=os.environ['DEEPL_AUTH']),ocr=MangaOcr(),translate_free_text=True) + to_convert = pil_to_cv2(Image.open(io.BytesIO(image[0]["body"]))) - translated = await converter([to_convert]) + translated = await MiraTranslateWebHandler.converter([to_convert]) # display_image(translated,"Translated") converted_pil = cv2_to_pil(translated[0]) @@ -306,17 +300,17 @@ async def post(self): self.write(img_byte_arr.getvalue()) except: - self.set_header("Content-Type", 'text/html') + self.set_header("Content-Type", "text/html") self.set_status(500) traceback.print_exc() self.write(traceback.format_exc()) -class UiFilesHandler(RequestHandler): +class UiFilesHandler(RequestHandler): def initialize(self, build_path) -> None: self.build_path = build_path - @run_in_thread + @run_in_thread_decorator def get(self, target_file): send_file_in_chunks(self, os.path.join(self.build_path, target_file)) @@ -328,23 +322,26 @@ def get(self): async def main(): app_port = 5000 - build_path = os.path.join(os.path.dirname(__file__),'ui', "build") + build_path = os.path.join(os.path.dirname(__file__), "ui", "build") settings = { "template_path": build_path, - "static_path": os.path.join(build_path, 'static'), + "static_path": os.path.join(build_path, "static"), } - app = Application([ - (r"/", UiHandler), - (r"/info", BaseHandler), - (r"/clean", CleanFromWebHandler), - (r"/translate", TranslateFromWebHandler), - # (r"/images/.*", ImageHandler), - (r"/mira/translate", MiraTranslateWebHandler), - # (r"/(.*)", UiFilesHandler, dict(build_path=build_path)), - ], **settings) + app = Application( + [ + (r"/", UiHandler), + (r"/info", BaseHandler), + (r"/clean", CleanFromWebHandler), + (r"/translate", TranslateFromWebHandler), + # (r"/images/.*", ImageHandler), + (r"/mira/translate", MiraTranslateWebHandler), + # (r"/(.*)", UiFilesHandler, dict(build_path=build_path)), + ], + **settings, + ) app.listen(app_port) - webbrowser.open(f'http://localhost:{app_port}') + webbrowser.open(f"http://localhost:{app_port}") await asyncio.Event().wait() -asyncio.run(main()) \ No newline at end of file +asyncio.run(main()) diff --git a/test.py b/test.py deleted file mode 100644 index 7b544a3..0000000 --- a/test.py +++ /dev/null @@ -1,7 +0,0 @@ -from translator.utils import resize_and_pad -import numpy as np - - -img = np.zeros((300,180,3),dtype=np.uint8) - -print(resize_and_pad(img,target_size=(1000,500)).shape) \ No newline at end of file diff --git a/translator/cleaners/deepfillv2.py b/translator/cleaners/deepfillv2.py index b4102c0..aead787 100644 --- a/translator/cleaners/deepfillv2.py +++ b/translator/cleaners/deepfillv2.py @@ -24,16 +24,10 @@ class DeepFillV2Cleaner(Cleaner): DEFAULT_MODEL_PATH = os.path.join("models", "inpainting.pth") - _job_queue = queue.Queue() - _model = None _model_path = "" - _pending_tasks = [] - - _thread: Union[None, threading.Thread] = None - @staticmethod def get_model(path: str): if path == DeepFillV2Cleaner._model_path: @@ -99,82 +93,26 @@ def in_paint( return img_out - @staticmethod - def _in_paint_thread(): - payload = DeepFillV2Cleaner._job_queue.get() - - while payload is not None: - data = payload - pil_image, pil_mask, model_path, callback = data - - callback(DeepFillV2Cleaner.in_paint(pil_image, pil_mask, model_path)) - - payload = DeepFillV2Cleaner._job_queue.get() - - @staticmethod - async def add_in_paint_task( - image: np.ndarray, mask: np.ndarray, model_path: str = DEFAULT_MODEL_PATH - ): - loop = asyncio.get_event_loop() - - pending_task = asyncio.Future() - - def callback(in_paint_result): - nonlocal loop - - loop.call_soon_threadsafe(pending_task.set_result, pil_to_cv2(in_paint_result)) - - DeepFillV2Cleaner._job_queue.put((cv2_to_pil(image), cv2_to_pil(mask), model_path, callback)) - DeepFillV2Cleaner._pending_tasks.append(pending_task) - - result = await pending_task - DeepFillV2Cleaner._pending_tasks.remove(pending_task) - return result - - @staticmethod - def _stop_in_paint_thread(signum, frame): - sys.exit(signum) - - @staticmethod - def _exit_thread(): - try: - for task in DeepFillV2Cleaner._pending_tasks: - task.cancel() - # callback_queue.put(None) - if DeepFillV2Cleaner._thread.is_alive(): - DeepFillV2Cleaner._job_queue.put(None) - DeepFillV2Cleaner._thread.join() - except: - pass def __init__(self) -> None: super().__init__() - if DeepFillV2Cleaner._thread is None: - DeepFillV2Cleaner._thread = threading.Thread( - target=lambda: DeepFillV2Cleaner._in_paint_thread(), - group=None, - daemon=True, - ) - DeepFillV2Cleaner._thread.start() - atexit.register(DeepFillV2Cleaner._exit_thread) - @staticmethod def get_name() -> str: return "Deep Fill V2" - + + def clean_section(self,frame: np.ndarray,mask: np.ndarray) -> np.ndarray: + return pil_to_cv2(DeepFillV2Cleaner.in_paint(cv2_to_pil(frame),cv2_to_pil(mask))) + async def clean( self, frame: ndarray, mask: ndarray, detection_results: list[tuple[tuple[int, int, int, int], str, float]] = [], ) -> tuple[ndarray, ndarray]: - loop = asyncio.get_event_loop() return await in_paint_optimized( frame, mask=mask, filtered=detection_results, # segmentation_results.boxes.xyxy.cpu().numpy() - inpaint_fun=lambda frame, mask: loop.create_task(DeepFillV2Cleaner.add_in_paint_task( - frame, mask - )), + inpaint_fun=lambda frame, mask: self.clean_section(frame,mask), ) diff --git a/translator/cleaners/lama.py b/translator/cleaners/lama.py index 38b144d..4deb874 100644 --- a/translator/cleaners/lama.py +++ b/translator/cleaners/lama.py @@ -21,7 +21,7 @@ def get_name() -> str: def get_arguments() -> list[PluginArgument]: return [PluginTextArgument(id="dilation", name="Mask Dilation",description="The dilation used for the text mask", default="9")] - async def clean_with_lama(self,frame,mask): + def clean_with_lama(self,frame,mask): return pil_to_cv2( self.lama(cv2_to_pil(frame), cv2_to_pil(mask).convert("L")) ) @@ -32,11 +32,10 @@ async def clean( mask: ndarray, detection_results: list[tuple[tuple[int, int, int, int], str, float]] = ..., ) -> tuple[ndarray, ndarray]: - loop = asyncio.get_event_loop() return await in_paint_optimized( frame=frame, mask=mask, filtered=detection_results, mask_dilation_kernel_size=self.dilation, - inpaint_fun=lambda f, m: loop.create_task(self.clean_with_lama(f,m)), + inpaint_fun=lambda f, m: self.clean_with_lama(f,m), ) diff --git a/translator/core/plugin.py b/translator/core/plugin.py index d19d4f2..e60b4d4 100644 --- a/translator/core/plugin.py +++ b/translator/core/plugin.py @@ -1,5 +1,5 @@ import numpy as np -from translator.utils import resize_and_pad, display_image +from translator.utils import run_in_thread_decorator class PluginArgumentType: TEXT = 0 @@ -90,11 +90,11 @@ class Ocr(BasePlugin): def __init__(self) -> None: super().__init__() - async def __call__(self, text: np.ndarray) -> OcrResult: - return await self.do_ocr(text) + async def __call__(self, texts: list[np.ndarray]) -> list[OcrResult]: + return await self.do_ocr(texts) - async def do_ocr(self, text: np.ndarray): - return OcrResult("Sample") + async def do_ocr(self, texts: list[np.ndarray]): + return [OcrResult("Sample") for _ in texts] @staticmethod def get_name() -> str: @@ -113,30 +113,36 @@ class Translator(BasePlugin): def __init__(self) -> None: super().__init__() - async def __call__(self, ocr_result: OcrResult) -> str: - return await self.translate(ocr_result) + async def __call__(self, ocr_results: list[OcrResult]) -> list[TranslatorResult]: + return await self.translate(ocr_results) - async def translate(self, ocr_result: OcrResult) -> TranslatorResult: - return TranslatorResult(ocr_result.text) + async def translate(self, ocr_results: list[OcrResult]) -> list[TranslatorResult]: + return [TranslatorResult(x.text) for x in ocr_results] @staticmethod def get_name() -> str: return "Base Translator" +class Drawable: + def __init__(self,color: np.ndarray, translation: TranslatorResult,frame: np.ndarray) -> None: + self.color = color + self.translation = translation + self.frame = frame class Drawer(BasePlugin): def __init__(self) -> None: super().__init__() + async def draw( - self, draw_color: np.ndarray, translation: TranslatorResult, frame: np.ndarray - ) -> np.ndarray: - return frame + self, to_draw: list[Drawable] + ) -> list[tuple[np.ndarray,np.ndarray]]: + return [x.frame for x in to_draw] async def __call__( - self, draw_color: np.ndarray, translation: TranslatorResult, frame: np.ndarray - ) -> np.ndarray: - return await self.draw(draw_color=draw_color, translation=translation, frame=frame) + self, to_draw: list[Drawable] + ) -> list[tuple[np.ndarray,np.ndarray]]: + return await self.draw(to_draw=to_draw) class Cleaner(BasePlugin): diff --git a/translator/drawers/horizontal.py b/translator/drawers/horizontal.py index 45fcd57..0c7f85c 100644 --- a/translator/drawers/horizontal.py +++ b/translator/drawers/horizontal.py @@ -3,7 +3,9 @@ from PIL import ImageFont, ImageDraw from numpy import ndarray from hyphen import Hyphenator +import asyncio from translator.core.plugin import ( + Drawable, Drawer, PluginArgument, PluginSelectArgument, @@ -31,13 +33,19 @@ def __init__( self.line_spacing = round(float(line_spacing)) async def draw( - self, draw_color: np.ndarray, translation: TranslatorResult, frame: np.ndarray - ) -> ndarray: - # print(color_diff(np.array(color),np.array((0,0,0)))) - if len(translation.text.strip()) <= 0: - return frame - - frame_h, frame_w, _ = frame.shape + self,to_draw: list[Drawable] + ) -> list[tuple[ndarray,ndarray]]: + return await asyncio.gather(*[self.draw_one(x) for x in to_draw]) + + + async def draw_one( + self, item: Drawable + ) -> tuple[ndarray,ndarray]: + item_mask = np.zeros_like(item.frame) + if len(item.translation.text.strip()) <= 0: + return (item.frame,item_mask) + + frame_h, frame_w, _ = item.frame.shape # fill background incase of segmentation errors # cv2.rectangle(frame, pt1, pt2, (255, 255, 255), -1) @@ -46,7 +54,7 @@ async def draw( hyphenator = Hyphenator("en_US") font_size, chars_per_line, line_height, iters = get_best_font_size( - translation.text, + item.translation.text, (frame_w, frame_h), font_file=self.font_file, space_between_lines=self.line_spacing, @@ -56,19 +64,24 @@ async def draw( ) if not font_size: - return frame - - frame_as_pil = cv2_to_pil(frame) + return (item.frame,item_mask) font = ImageFont.truetype(self.font_file, font_size) draw_x = 0 draw_y = 0 - wrapped = wrap_text(translation.text, chars_per_line, hyphenator=hyphenator) + wrapped = wrap_text(item.translation.text, chars_per_line, hyphenator=hyphenator) + + frame_as_pil = cv2_to_pil(item.frame) + + mask_as_pil = cv2_to_pil(item_mask) image_draw = ImageDraw.Draw(frame_as_pil) + mask_draw = ImageDraw.Draw(mask_as_pil) + + stroke_width = 2 for line_no in range(len(wrapped)): line = wrapped[line_no] x, y, w, h = font.getbbox(line) @@ -92,13 +105,39 @@ async def draw( + (self.line_spacing * line_no), ), str(line), - fill=(*draw_color, 255), + fill=(*item.color, 255), + font=font, + stroke_width=stroke_width, + stroke_fill=(255, 255, 255), + ) + + mask_draw.text( + ( + draw_x + abs(((frame_w - w) / 2)), + draw_y + + self.line_spacing + + ( + ( + frame_h + - ( + (len(wrapped) * line_height) + + (len(wrapped) * self.line_spacing) + ) + ) + / 2 + ) + + (line_no * line_height) + + (self.line_spacing * line_no), + ), + str(line), + fill=(255, 255, 255, 255), font=font, - stroke_width=2, + stroke_width=stroke_width, stroke_fill=(255, 255, 255), ) - return pil_to_cv2(frame_as_pil) + return (pil_to_cv2(frame_as_pil),pil_to_cv2(mask_as_pil)) + @staticmethod def get_arguments() -> list[PluginArgument]: diff --git a/translator/drawers/vertical.py b/translator/drawers/vertical.py index 275ba7c..aff0419 100644 --- a/translator/drawers/vertical.py +++ b/translator/drawers/vertical.py @@ -1,6 +1,6 @@ from typing import Any from numpy import ndarray -from translator.core.plugin import Drawer, TranslatorResult +from translator.core.plugin import Drawable, Drawer, TranslatorResult from translator.utils import ( get_best_font_size, cv2_to_pil, @@ -13,11 +13,6 @@ class VerticalDrawer(Drawer): """Draws text vertically""" - async def draw( - self, draw_color: ndarray, translation: TranslatorResult, frame: ndarray - ) -> ndarray: - return super().draw(draw_color, translation, frame) - @staticmethod def is_valid() -> bool: return False diff --git a/translator/ocr/clean.py b/translator/ocr/clean.py index c254b93..54de63f 100644 --- a/translator/ocr/clean.py +++ b/translator/ocr/clean.py @@ -14,8 +14,8 @@ class CleanOcr(Ocr): def __init__(self) -> None: super().__init__() - async def do_ocr(self, text: numpy.ndarray): - return OcrResult("", "") + async def do_ocr(self, texts: list[numpy.ndarray]): + return [OcrResult("", "") for _ in texts] @staticmethod def get_name() -> str: diff --git a/translator/ocr/easy_ocr.py b/translator/ocr/easy_ocr.py index 71e8650..6231514 100644 --- a/translator/ocr/easy_ocr.py +++ b/translator/ocr/easy_ocr.py @@ -105,11 +105,11 @@ def __init__(self, lang=languages[0]) -> None: self.easy = easyocr.Reader([lang]) self.language = lang - async def do_ocr(self, text: numpy.ndarray): - return OcrResult( - text=self.easy.readtext(text, detail=0, paragraph=True)[0], + async def do_ocr(self, texts: list[numpy.ndarray]): + return [OcrResult( + text=self.easy.readtext(x, detail=0, paragraph=True)[0], language=self.language, - ) # self.language) + ) for x in texts] @staticmethod def get_name() -> str: diff --git a/translator/ocr/get.py b/translator/ocr/get.py index c69dd24..b289c77 100644 --- a/translator/ocr/get.py +++ b/translator/ocr/get.py @@ -1,11 +1,11 @@ from translator.core.plugin import Ocr from translator.ocr.clean import CleanOcr -from translator.ocr.manga import MangaOcr +from translator.ocr.huggingface_ja import JapaneseOcr from translator.ocr.easy_ocr import EasyOcr from translator.ocr.tessaract_ocr import TesseractOcr def get_ocr() -> list[Ocr]: return list( - filter(lambda a: a.is_valid(), [CleanOcr, MangaOcr, EasyOcr, TesseractOcr]) + filter(lambda a: a.is_valid(), [CleanOcr, JapaneseOcr, EasyOcr, TesseractOcr]) ) diff --git a/translator/ocr/huggingface_ja.py b/translator/ocr/huggingface_ja.py new file mode 100644 index 0000000..5c89621 --- /dev/null +++ b/translator/ocr/huggingface_ja.py @@ -0,0 +1,40 @@ +import numpy +import torch +import re +import jaconv +from transformers import pipeline +from translator.utils import cv2_to_pil, get_torch_device +from translator.core.plugin import Ocr, OcrResult + + +class JapaneseOcr(Ocr): + """Only Supports Japanese""" + + def __init__(self,model='TareHimself/manga-ocr-base') -> None: + super().__init__() + self.pipeline = pipeline("image-to-text", model=model, device=get_torch_device()) + + async def do_ocr(self, texts: list[numpy.ndarray]): + + with torch.inference_mode(): + frames = [cv2_to_pil(x).convert('L').convert('RGB') for x in texts] + results = self.pipeline(frames,max_new_tokens=300) + + return [OcrResult(self._post_process(x[0]['generated_text']), "ja") for x in results] + + def _preprocess(self, frame: numpy.ndarray): + + pixel_values = self.feature_extractor(frame, return_tensors="pt").pixel_values + return pixel_values.squeeze() + + def _post_process(self,text: str): + text = ''.join(text.split()) + text = text.replace('…', '...') + text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text) + text = jaconv.h2z(text, ascii=True, digit=True) + + return text + + @staticmethod + def get_name() -> str: + return "Japanese Ocr" diff --git a/translator/ocr/manga.py b/translator/ocr/manga.py deleted file mode 100644 index 39786bc..0000000 --- a/translator/ocr/manga.py +++ /dev/null @@ -1,20 +0,0 @@ -import numpy -from translator.utils import cv2_to_pil -from translator.core.plugin import Ocr, OcrResult - - -class MangaOcr(Ocr): - """Only Supports Japanese""" - - def __init__(self) -> None: - from manga_ocr import MangaOcr as MangaOcrPackage - - super().__init__() - self.manga_ocr = MangaOcrPackage() - - async def do_ocr(self, text: numpy.ndarray): - return OcrResult(self.manga_ocr(cv2_to_pil(text)), "ja") - - @staticmethod - def get_name() -> str: - return "Manga Ocr" diff --git a/translator/ocr/tessaract_ocr.py b/translator/ocr/tessaract_ocr.py index 5f7c4a4..f670609 100644 --- a/translator/ocr/tessaract_ocr.py +++ b/translator/ocr/tessaract_ocr.py @@ -24,12 +24,12 @@ def __init__(self, language=default_language) -> None: self.tesseract = pytesseract self.language = language - async def do_ocr(self, text: numpy.ndarray): + async def do_ocr(self, texts: list[numpy.ndarray]): print("Using Tesseract OCR") - return OcrResult( - text=self.tesseract.image_to_string(text, lang=self.language), + return [OcrResult( + text=self.tesseract.image_to_string(x, lang=self.language), language=simplify_lang_code(self.language), - ) + ) for x in texts] @staticmethod def get_name() -> str: diff --git a/translator/pipelines.py b/translator/pipelines.py index f865134..0f68db3 100644 --- a/translator/pipelines.py +++ b/translator/pipelines.py @@ -8,9 +8,11 @@ mask_text_and_make_bubble_mask, get_bounds_for_text, TranslatorGlobals, + run_in_thread_decorator, transform_sample, has_white, get_model_path, + apply_mask ) import traceback import threading @@ -19,7 +21,7 @@ from typing import Union from concurrent.futures import ThreadPoolExecutor from translator.color_detect.models import get_color_detection_model -from translator.core.plugin import Translator, Ocr, Drawer, Cleaner +from translator.core.plugin import Drawable, Translator, Ocr, Drawer, Cleaner from translator.cleaners.deepfillv2 import DeepFillV2Cleaner from translator.drawers.horizontal import HorizontalDrawer @@ -87,21 +89,35 @@ def filter_results(self, results, min_confidence=0.1): confidence = np.array(results.boxes.conf.cpu(), dtype="float") - filtered: list[tuple[tuple[int, int, int, int], str, float]] = [] + raw_results: list[tuple[tuple[int, int, int, int], str, float]] = [] + # has_similar = False + # for item, _, __ in filtered: + # print(np.absolute(item - box)) + # diff = np.average(np.absolute(item - box)) + # if diff < 10: + # has_similar = True + # break + # if has_similar: + # continue + for box, obj_class, conf in zip(bounding_boxes, classes, confidence): if conf >= min_confidence: - has_similar = False - for item, _, __ in filtered: - diff = np.average(np.absolute(item - box)) - if diff < 10: - has_similar = True - break - if has_similar: - continue - filtered.append((box, results.names[obj_class], conf)) - - return filtered + raw_results.append((box, results.names[obj_class], conf)) + + raw_results.sort(key= lambda a: 1 - a[2]) + + results: list[tuple[tuple[int, int, int, int], str, float]] = [] + + # print(f"Starting with {len(raw_results)} results") + # while len(raw_results) > 0: + # results.append(raw_results[0]) + # raw_results = list(filter(lambda a: iou(raw_results[0][0],a[0]) < 0.5,raw_results)) + + results = raw_results + + # print(f"Ended with {len(results)} results") + return results async def process_ml_results(self, detect_result, seg_result, frame): text_mask = np.zeros_like(frame, dtype=frame.dtype) @@ -132,233 +148,251 @@ async def process_ml_results(self, detect_result, seg_result, frame): async def get_translation(self, image_with_text, extra_data): return (await self.translator(await self.ocr(image_with_text)), *extra_data) + @run_in_thread_decorator async def process_frame(self, detect_result, seg_result, frame): - frame, frame_clean, text_mask, detect_result = await self.process_ml_results( - detect_result, seg_result, frame - ) + try: + frame, frame_clean, text_mask, detect_result = await self.process_ml_results( + detect_result, seg_result, frame + ) - to_translate = [] - # First pass, mask all bubbles - for bbox, cls, conf in detect_result: - # if conf < 0.65: - # continue + to_translate = [] + # First pass, mask all bubbles + for bbox, cls, conf in detect_result: + try: + # if conf < 0.65: + # continue - # print(get_ocr(get_box_section(frame, box))) - color = (0, 0, 255) if cls == 1 else (0, 255, 0) + # print(get_ocr(get_box_section(frame, box))) + color = (0, 0, 255) if cls == 1 else (0, 255, 0) - (x1, y1, x2, y2) = bbox + (x1, y1, x2, y2) = bbox - class_name = cls + class_name = cls - bubble = frame[y1:y2, x1:x2] - bubble_clean = frame_clean[y1:y2, x1:x2] - bubble_text_mask = text_mask[y1:y2, x1:x2] + bubble = frame[y1:y2, x1:x2] + bubble_clean = frame_clean[y1:y2, x1:x2] + bubble_text_mask = text_mask[y1:y2, x1:x2] - if class_name == "text_bubble": - if has_white(bubble_text_mask): - text_only, bubble_mask = mask_text_and_make_bubble_mask( - bubble, bubble_text_mask, bubble_clean - ) + if class_name == "text_bubble": + if has_white(bubble_text_mask): + text_only, bubble_mask = mask_text_and_make_bubble_mask( + bubble, bubble_text_mask, bubble_clean + ) - frame[y1:y2, x1:x2] = bubble_clean - text_draw_bounds = get_bounds_for_text(bubble_mask) + frame[y1:y2, x1:x2] = bubble_clean + text_draw_bounds = get_bounds_for_text(bubble_mask) - pt1, pt2 = text_draw_bounds + pt1, pt2 = text_draw_bounds - pt1_x, pt1_y = pt1 - pt2_x, pt2_y = pt2 + pt1_x, pt1_y = pt1 + pt2_x, pt2_y = pt2 - pt1_x += x1 - pt2_x += x1 - pt1_y += y1 - pt2_y += y1 + pt1_x += x1 + pt2_x += x1 + pt1_y += y1 + pt2_y += y1 - to_translate.append([(pt1_x, pt1_y, pt2_x, pt2_y), text_only]) - # debug_image(text_only,"Text Only") - else: - if self.translate_free_text: - free_text = frame[y1:y2, x1:x2] - if has_white(free_text): - text_only, _ = mask_text_and_make_bubble_mask( - free_text, bubble_text_mask, bubble_clean + to_translate.append([(pt1_x, pt1_y, pt2_x, pt2_y), text_only]) + + # frame = cv2.rectangle(frame,(x1,y1),(x2,y2),color=(255,255,0),thickness=2) + # debug_image(text_only,"Text Only") + else: + if self.translate_free_text: + free_text = frame[y1:y2, x1:x2] + if has_white(free_text): + text_only, _ = mask_text_and_make_bubble_mask( + free_text, bubble_text_mask, bubble_clean + ) + + to_translate.append([(x1, y1, x2, y2), text_only]) + + frame[y1:y2, x1:x2] = frame_clean[y1:y2, x1:x2] + else: + frame[y1:y2, x1:x2] = frame_clean[y1:y2, x1:x2] + + if self.debug: + cv2.putText( + frame, + str(f"{cls} | {conf * 100:.1f}%"), + (x1, y1 - 20), + cv2.FONT_HERSHEY_PLAIN, + 1, + color, + 2, ) + except: + traceback.print_exc() + + # second pass, fix intersecting text areas + # for i in range(len(to_translate)): + # bbox_a = to_translate[i][0] + # text_bounds_a_local = to_translate[i][3] + # text_bounds_a = [ + # [ + # text_bounds_a_local[0][0] + bbox_a[0], + # text_bounds_a_local[0][1] + bbox_a[1], + # ], + # [ + # text_bounds_a_local[1][0] + bbox_a[0], + # text_bounds_a_local[1][1] + bbox_a[1], + # ], + # ] + # for x in range(len(to_translate)): + # if x == i: + # continue + + # bbox_b = to_translate[x][0] + # text_bounds_b_local = to_translate[x][3] + # text_bounds_b = [ + # [ + # text_bounds_b_local[0][0] + bbox_b[0], + # text_bounds_b_local[0][1] + bbox_b[1], + # ], + # [ + # text_bounds_b_local[1][0] + bbox_b[0], + # text_bounds_b_local[1][1] + bbox_b[1], + # ], + # ] + + # fix_result = fix_intersection( + # text_bounds_a[0], + # text_bounds_a[1], + # text_bounds_b[0], + # text_bounds_b[1], + # ) + # found_intersection = fix_result[4] + # if found_intersection: + # to_translate[i][3] = [ + # [ + # fix_result[0][0] - bbox_a[0], + # fix_result[0][1] - bbox_a[1], + # ], + # [torch.cuda.is_available() + # fix_result[1][0] - bbox_a[0], + # fix_result[1][1] - bbox_a[1], + # ], + # ] + # to_translate[x][3] = [ + # [ + # fix_result[2][0] - bbox_b[0], + # fix_result[2][1] - bbox_b[1], + # ], + # [ + # fix_result[3][0] - bbox_b[0], + # fix_result[3][1] - bbox_b[1], + # ], + # ] + # print( + # bbox_a, + # text_bounds_a_local, + # text_bounds_a, + # "\n", + # bbox_b, + # text_bounds_b_local, + # text_bounds_b, + # ) + # debug_image(to_translate[i][1]) + # debug_image(to_translate[x][1]) + # cv2.rectangle( + # frame, + # fix_result[0], + # fix_result[1], + # (0, 255, 255), + # 1, + # ) + # cv2.rectangle( + # frame, + # fix_result[2], + # fix_result[3], + # (0, 255, 255), + # 1, + # ) + # print("intersection found") + + # third pass, draw text + draw_colors = [TranslatorGlobals.COLOR_BLACK for x in to_translate] + + start = time.time() + if self.color_detect_model is not None and len(draw_colors) > 0: + with torch.no_grad(): # model needs work + with torch.inference_mode(): + with self.frame_process_mutex: # this may not be needed + + def fix_image(img): + # img = adjust_contrast_brightness(frame,contrast=2) + # size_dil = 3 + # returncv2.GaussianBlur(img, (size_dil, size_dil), 0) + + # final_mask_dilation = 6 + # kernel = np.ones((final_mask_dilation,final_mask_dilation),np.uint8) + # return cv2.dilate(img,kernel,iterations = 1) + return img + + images = [ + fix_image(frame_with_text.copy()) + for _, frame_with_text in to_translate + ] + # images = [x[2].copy() for x in to_translate] + # [display_image(x,"To Detect") for x in images] + + draw_colors = [ + (x.cpu().numpy() * 255).astype(np.uint8) + for x in self.color_detect_model( + torch.stack([transform_sample(y) for y in images]).to( + torch.device("cuda:0") + ) + ) + ] + else: + print("Using black since color detect model is 'None'") - to_translate.append([(x1, y1, x2, y2), text_only]) - - frame[y1:y2, x1:x2] = frame_clean[y1:y2, x1:x2] - else: - frame[y1:y2, x1:x2] = frame_clean[y1:y2, x1:x2] - - if self.debug: - cv2.putText( - frame, - str(f"{cls} | {conf * 100:.1f}%"), - (x1, y1 - 20), - cv2.FONT_HERSHEY_PLAIN, - 1, - color, - 2, - ) + print(f"Color Detection => {time.time() - start} seconds") - # second pass, fix intersecting text areas - # for i in range(len(to_translate)): - # bbox_a = to_translate[i][0] - # text_bounds_a_local = to_translate[i][3] - # text_bounds_a = [ - # [ - # text_bounds_a_local[0][0] + bbox_a[0], - # text_bounds_a_local[0][1] + bbox_a[1], - # ], - # [ - # text_bounds_a_local[1][0] + bbox_a[0], - # text_bounds_a_local[1][1] + bbox_a[1], - # ], - # ] - # for x in range(len(to_translate)): - # if x == i: - # continue + start = time.time() - # bbox_b = to_translate[x][0] - # text_bounds_b_local = to_translate[x][3] - # text_bounds_b = [ - # [ - # text_bounds_b_local[0][0] + bbox_b[0], - # text_bounds_b_local[0][1] + bbox_b[1], - # ], - # [ - # text_bounds_b_local[1][0] + bbox_b[0], - # text_bounds_b_local[1][1] + bbox_b[1], - # ], - # ] - - # fix_result = fix_intersection( - # text_bounds_a[0], - # text_bounds_a[1], - # text_bounds_b[0], - # text_bounds_b[1], - # ) - # found_intersection = fix_result[4] - # if found_intersection: - # to_translate[i][3] = [ - # [ - # fix_result[0][0] - bbox_a[0], - # fix_result[0][1] - bbox_a[1], - # ], - # [torch.cuda.is_available() - # fix_result[1][0] - bbox_a[0], - # fix_result[1][1] - bbox_a[1], - # ], - # ] - # to_translate[x][3] = [ - # [ - # fix_result[2][0] - bbox_b[0], - # fix_result[2][1] - bbox_b[1], - # ], - # [ - # fix_result[3][0] - bbox_b[0], - # fix_result[3][1] - bbox_b[1], - # ], - # ] - # print( - # bbox_a, - # text_bounds_a_local, - # text_bounds_a, - # "\n", - # bbox_b, - # text_bounds_b_local, - # text_bounds_b, - # ) - # debug_image(to_translate[i][1]) - # debug_image(to_translate[x][1]) - # cv2.rectangle( - # frame, - # fix_result[0], - # fix_result[1], - # (0, 255, 255), - # 1, - # ) - # cv2.rectangle( - # frame, - # fix_result[2], - # fix_result[3], - # (0, 255, 255), - # 1, - # ) - # print("intersection found") - - # third pass, draw text - draw_colors = [TranslatorGlobals.COLOR_BLACK for x in to_translate] + to_draw = [] - start = time.time() - if self.color_detect_model is not None and len(draw_colors) > 0: - with torch.no_grad(): # model needs work - with torch.inference_mode(): - with self.frame_process_mutex: # this may not be needed - - def fix_image(img): - # img = adjust_contrast_brightness(frame,contrast=2) - # size_dil = 3 - # returncv2.GaussianBlur(img, (size_dil, size_dil), 0) - - # final_mask_dilation = 6 - # kernel = np.ones((final_mask_dilation,final_mask_dilation),np.uint8) - # return cv2.dilate(img,kernel,iterations = 1) - return img - - images = [ - fix_image(frame_with_text.copy()) - for _, frame_with_text in to_translate - ] - # images = [x[2].copy() for x in to_translate] - # [display_image(x,"To Detect") for x in images] - - draw_colors = [ - (x.cpu().numpy() * 255).astype(np.uint8) - for x in self.color_detect_model( - torch.stack([transform_sample(y) for y in images]).to( - torch.device("cuda:0") - ) - ) - ] - else: - print("Using black since color detect model is 'None'") - print(f"Color Detection => {time.time() - start} seconds") + if self.translator and self.ocr and len(to_translate) > 0: + bboxes,images = zip(*to_translate) - start = time.time() + ocr_results = await self.ocr(images) - to_draw = [] + translation_results = await self.translator(ocr_results) + to_draw = [] + for bbox,translation,color in zip(bboxes,translation_results,draw_colors): - if self.translator and self.ocr and len(to_translate) > 0: - tasks = [] + (x1, y1, x2, y2) = bbox + draw_area = frame[y1:y2, x1:x2].copy() - for i in range(len(to_translate)): - bbox, frame_with_text = to_translate[i] - draw_color = draw_colors[i] + to_draw.append(Drawable(color=color,frame=draw_area,translation=translation)) - tasks.append(self.get_translation(frame_with_text,(bbox, draw_color))) + - to_draw = [x for x in await asyncio.gather(*tasks)] - + print(f"Ocr And Translation => {time.time() - start} seconds") - print(f"Ocr And Translation => {time.time() - start} seconds") + start = time.time() - start = time.time() + drawn_frames = await self.drawer(to_draw) - for translation, bbox, draw_color in to_draw: - (x1, y1, x2, y2) = bbox - draw_frame = frame[y1:y2, x1:x2] + for bbox, drawn_frame in zip(bboxes,drawn_frames): + (x1, y1, x2, y2) = bbox - outline_color = get_outline_color(draw_frame, draw_color) + # draw_frame = frame[y1:y2, x1:x2] - frame[y1:y2, x1:x2] = await self.drawer( - draw_color=draw_color, translation=translation, frame=draw_frame - ) + # outline_color = get_outline_color(draw_frame, draw_color) + + drawn_frame,drawn_frame_mask = drawn_frame - print(f"Drawing => {time.time() - start} seconds") - return frame + frame[y1:y2, x1:x2] = apply_mask(frame[y1:y2,x1:x2],drawn_frame,drawn_frame_mask) + + print(f"Drawing => {time.time() - start} seconds") + return frame + except: + traceback.print_exc() + return None async def __call__( self, diff --git a/translator/translators/debug.py b/translator/translators/debug.py index 520d5a8..a74ef35 100644 --- a/translator/translators/debug.py +++ b/translator/translators/debug.py @@ -14,8 +14,8 @@ def __init__(self, text="") -> None: super().__init__() self.to_write = text - async def translate(self, ocr_result: OcrResult): - return TranslatorResult(self.to_write) + async def translate(self, ocr_results: list[OcrResult]): + return [TranslatorResult(self.to_write) for _ in ocr_results] @staticmethod def get_name() -> str: diff --git a/translator/translators/deepl.py b/translator/translators/deepl.py index 234c1b6..c04eb21 100644 --- a/translator/translators/deepl.py +++ b/translator/translators/deepl.py @@ -1,5 +1,6 @@ import requests import traceback +import asyncio from requests.utils import requote_uri from translator.core.plugin import ( Translator, @@ -26,20 +27,20 @@ def get_arguments() -> list[PluginArgument]: id="auth_token", name="Auth Token", description="DeepL Api Auth Token" ) ] - - async def translate(self, ocr_result: OcrResult): + + async def do_api(self,result: OcrResult): if self.auth_token is None or len(self.auth_token.strip()) == 0: - return "Need DeepL Auth" + return TranslatorResult("Need DeepL Auth") - if len(ocr_result.text.strip()) == 0: - return TranslatorResult() + if len(result.text.strip()) == 0: + return TranslatorResult("") - if ocr_result.language == "ja": + if result.language == "ja": try: data = [ ("target_lang", "EN-US"), ("source_lang", "JA"), - ("text", ocr_result.text), + ("text", result.text), ] uri = f"https://api-free.deepl.com/v2/translate?{'&'.join([f'{data[i][0]}={data[i][1]}' for i in range(len(data))])}" uri = requote_uri(uri) @@ -59,6 +60,10 @@ async def translate(self, ocr_result: OcrResult): else: return TranslatorResult("Language not supported") + async def translate(self, ocr_results: list[OcrResult]): + return await asyncio.gather(*[self.do_api(x) for x in ocr_results]) + + @staticmethod def get_name() -> str: return "DeepL Translator" diff --git a/translator/translators/google.py b/translator/translators/google.py index d895f8c..704e1a2 100644 --- a/translator/translators/google.py +++ b/translator/translators/google.py @@ -27,20 +27,18 @@ def __init__(self, key_path="") -> None: else: self.trans = None - async def translate(self, ocr_result: OcrResult): + + async def translate(self, ocr_results: list[OcrResult]): if self.trans is None: - return TranslatorResult("Invalid Key Path") + return [TranslatorResult("Invalid Key Path") for _ in ocr_results] - if len(ocr_result.text.strip()) == 0: - return TranslatorResult() - - return TranslatorResult( + return [TranslatorResult( self.trans.translate( - ocr_result.text, - source_language=ocr_result.language, + x.text, + source_language=x.language, target_language="en", )["translatedText"] - ) + ) for x in ocr_results] @staticmethod def get_arguments() -> list[PluginArgument]: diff --git a/translator/translators/hugging_face.py b/translator/translators/hugging_face.py index aac1193..9c9f304 100644 --- a/translator/translators/hugging_face.py +++ b/translator/translators/hugging_face.py @@ -1,4 +1,6 @@ +import torch from transformers import pipeline +from translator.utils import get_torch_device from translator.core.plugin import ( Translator, TranslatorResult, @@ -13,13 +15,19 @@ class HuggingFace(Translator): def __init__(self, model_url: str = "Helsinki-NLP/opus-mt-ja-en") -> None: super().__init__() - self.pipeline = pipeline("translation", model=model_url) + print("Using model",model_url) + self.pipeline = pipeline("translation", model=model_url, device=get_torch_device()) - async def translate(self, ocr_result: OcrResult): - if len(ocr_result.text.strip()) == 0: - return TranslatorResult() + # if torch.cuda.is_available(): + # self.pipeline.cuda() + # elif torch.backends.mps.is_available(): + # self.pipeline.to('mps') - return TranslatorResult(self.pipeline(ocr_result.text)[0]["translation_text"]) + async def translate(self, ocr_results: list[OcrResult]): + #return [print(y) for y in self.pipeline([x.text for x in ocr_results])] + + + return [TranslatorResult(y["translation_text"]) for y in self.pipeline([x.text for x in ocr_results])] @staticmethod def get_name() -> str: @@ -32,6 +40,6 @@ def get_arguments() -> list[PluginArgument]: id="model_url", name="Model", description="The Hugging Face translation model to use", - default="staka/fugumt-ja-en", + default="Helsinki-NLP/opus-mt-ja-en", ) ] diff --git a/translator/translators/openai.py b/translator/translators/openai.py index 15a6a3a..1ce2279 100644 --- a/translator/translators/openai.py +++ b/translator/translators/openai.py @@ -1,3 +1,4 @@ +import asyncio from translator.utils import get_languages from translator.core.plugin import ( Translator, @@ -33,10 +34,7 @@ def __init__( self.model = model self.temp = float(temp) - async def translate(self, ocr_result: OcrResult): - if len(ocr_result.text.strip()) == 0: - return TranslatorResult(lang_code=self.target_lang) - + async def translate_one(self,ocr_result: OcrResult): message = f"{ocr_result.language.upper()} to {self.target_lang.upper()}\n{ocr_result.text}" result = self.openai.ChatCompletion.create( @@ -50,6 +48,15 @@ async def translate(self, ocr_result: OcrResult): return TranslatorResult( result["choices"][0].message["content"].strip(), self.target_lang ) + + async def translate(self, ocr_results: list[OcrResult]): + if len(ocr_results) == 0: + return [TranslatorResult(lang_code=self.target_lang) for _ in ocr_results] + + + return await asyncio.gather(*[self.translate_one(x) for x in ocr_results]) + + @staticmethod def get_name() -> str: diff --git a/translator/utils.py b/translator/utils.py index e513fe3..a675b8c 100644 --- a/translator/utils.py +++ b/translator/utils.py @@ -12,6 +12,7 @@ import PIL import PySimpleGUI as sg import asyncio +import inspect import largestinteriorrectangle as lir from torchvision import transforms from typing import Union, Callable @@ -19,12 +20,41 @@ from hyphen import Hyphenator from tqdm import tqdm from collections import deque +import traceback class TranslatorGlobals: COLOR_BLACK = np.array((0, 0, 0)) COLOR_WHITE = np.array((255, 255, 255)) +async def run_in_thread(func,*args,**kwargs): + loop = asyncio.get_event_loop() + task = asyncio.Future() + def run(): + nonlocal loop + nonlocal func + nonlocal task + + result = func(*args,**kwargs) + + if inspect.isawaitable(result): + result = asyncio.run(result) + loop.call_soon_threadsafe(task.set_result,result) + + task_thread = threading.Thread(group=None,daemon=True,target=run) + task_thread.start() + return await task + +def run_in_thread_decorator(func): + async def wrapper(*args,**kwargs): + return await run_in_thread(func,*args,**kwargs) + + return wrapper + + + +def get_torch_device() -> torch.device: + return torch.device('cuda') if torch.cuda.is_available() else (torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')) def simplify_lang_code(code: str) -> Union[str, None]: try: @@ -404,7 +434,7 @@ def mask_text_for_in_painting(frame: np.ndarray, mask: np.ndarray): return new_mask - +@run_in_thread_decorator async def in_paint_optimized( frame: np.ndarray, mask: np.ndarray, @@ -412,7 +442,7 @@ async def in_paint_optimized( max_height: int = 256, max_width: int = 256, mask_dilation_kernel_size: int = 9, - inpaint_fun: Callable[[np.ndarray, np.ndarray], asyncio.Task[np.ndarray]] = lambda a, b: a, + inpaint_fun: Callable[[np.ndarray, np.ndarray], np.ndarray] = lambda a, b: a, ) -> tuple[np.ndarray, np.ndarray]: h, w, c = frame.shape max_height = int(math.floor(max_height / 8) * 8) @@ -426,109 +456,113 @@ async def in_paint_optimized( half_width = int(max_width / 2) for bbox, cls, conf in filtered: - bx1, by1, bx2, by2 = bbox - bx1, by1, bx2, by2 = round(bx1), round(by1), round(bx2), round(by2) + try: + bx1, by1, bx2, by2 = bbox + bx1, by1, bx2, by2 = round(bx1), round(by1), round(bx2), round(by2) - half_bx = round((bx2 - bx1) / 2) - half_by = round((by2 - by1) / 2) - midpoint_x, midpoint_y = round(bx1 + half_bx), round(by1 + half_by) + half_bx = round((bx2 - bx1) / 2) + half_by = round((by2 - by1) / 2) + midpoint_x, midpoint_y = round(bx1 + half_bx), round(by1 + half_by) - x1, y1 = max(0, midpoint_x - half_width), max(0, midpoint_y - half_height) + x1, y1 = max(0, midpoint_x - half_width), max(0, midpoint_y - half_height) - x2, y2 = min(w, midpoint_x + half_width), min(h, midpoint_y + half_height) + x2, y2 = min(w, midpoint_x + half_width), min(h, midpoint_y + half_height) - if y2 < by2: - y2 = by2 + if y2 < by2: + y2 = by2 - if y1 > by1: - y1 = by1 + if y1 > by1: + y1 = by1 - if x2 < bx2: - x2 = bx2 + if x2 < bx2: + x2 = bx2 - if x1 > bx1: - x1 = bx1 - - overflow_x = (x2 - x1) % 8 - x1_adjust = 0 - if overflow_x != 0: - if x2 > x1: - x2 -= overflow_x - else: - x1 += overflow_x - x1_adjust = overflow_x + if x1 > bx1: + x1 = bx1 - overflow_y = (y2 - y1) % 8 + overflow_x = (x2 - x1) % 8 + x1_adjust = 0 + if overflow_x != 0: + if x2 > x1: + x2 -= overflow_x + else: + x1 += overflow_x + x1_adjust = overflow_x - y1_adjust = 0 - if overflow_y != 0: - if y2 > y1: - y2 -= overflow_y - else: - y1 += overflow_y - y1_adjust = overflow_y - - bx1 = bx1 - (x1 + x1_adjust) - bx2 = bx2 - (x1 + x1_adjust) - by1 = by1 - (y1 + y1_adjust) - by2 = by2 - (y1 + y1_adjust) - - region_mask = mask[y1:y2, x1:x2].copy() - - focus_mask = cv2.rectangle( - np.zeros_like(region_mask), - (bx1, by1), - (bx2, by2), - (255, 255, 255), - -1, - ) + overflow_y = (y2 - y1) % 8 - region_mask = apply_mask( - region_mask, np.zeros_like(region_mask), focus_mask, True - ) + y1_adjust = 0 + if overflow_y != 0: + if y2 > y1: + y2 -= overflow_y + else: + y1 += overflow_y + y1_adjust = overflow_y + + bx1 = bx1 - (x1 + x1_adjust) + bx2 = bx2 - (x1 + x1_adjust) + by1 = by1 - (y1 + y1_adjust) + by2 = by2 - (y1 + y1_adjust) + + region_mask = mask[y1:y2, x1:x2].copy() + + focus_mask = cv2.rectangle( + np.zeros_like(region_mask), + (bx1, by1), + (bx2, by2), + (255, 255, 255), + -1, + ) - if has_white(region_mask): - ( - target_region_x1, - target_region_y1, - target_region_x2, - target_region_y2, - ) = get_masked_bounds(region_mask) - - section_to_in_paint = final[y1:y2, x1:x2] - - section_to_refine = section_to_in_paint[ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] - section_to_refine_mask = region_mask[ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] - - # Generate a mask of the actual characters/text - refined_mask = np.zeros_like(region_mask) - refined_mask[ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] = mask_text_for_in_painting(section_to_refine, section_to_refine_mask) - - # The text mask is used for other stuff so we set it here before we dilate for inpainting - text_mask[y1:y2, x1:x2][ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] = refined_mask[ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ].copy() - - # Dilate the text mask for inpainting - kernel = np.ones( - (mask_dilation_kernel_size, mask_dilation_kernel_size), np.uint8 + region_mask = apply_mask( + region_mask, np.zeros_like(region_mask), focus_mask, True ) - refined_mask = cv2.dilate(refined_mask, kernel, iterations=1) - # Inpaint using the dilated text mask - final[y1:y2, x1:x2][ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] = (await inpaint_fun(final[y1:y2, x1:x2], refined_mask))[ - target_region_y1:target_region_y2, target_region_x1:target_region_x2 - ] + if has_white(region_mask): + ( + target_region_x1, + target_region_y1, + target_region_x2, + target_region_y2, + ) = get_masked_bounds(region_mask) + + section_to_in_paint = final[y1:y2, x1:x2] + + section_to_refine = section_to_in_paint[ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] + section_to_refine_mask = region_mask[ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] + + # Generate a mask of the actual characters/text + refined_mask = np.zeros_like(region_mask) + refined_mask[ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] = mask_text_for_in_painting(section_to_refine, section_to_refine_mask) + + # The text mask is used for other stuff so we set it here before we dilate for inpainting + text_mask[y1:y2, x1:x2][ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] = refined_mask[ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ].copy() + + # Dilate the text mask for inpainting + kernel = np.ones( + (mask_dilation_kernel_size, mask_dilation_kernel_size), np.uint8 + ) + refined_mask = cv2.dilate(refined_mask, kernel, iterations=1) + + # Inpaint using the dilated text mask + final[y1:y2, x1:x2][ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] = inpaint_fun(final[y1:y2, x1:x2], refined_mask)[ + target_region_y1:target_region_y2, target_region_x1:target_region_x2 + ] + except: + traceback.print_exc() + continue return final, text_mask @@ -1329,6 +1363,45 @@ def roboflow_coco_to_yolo(dataset_dir): os.remove(annotations_path) +def union(box1: tuple[int,int,int,int], box2: tuple[int,int,int,int]) -> float: + box1_x1, box1_y1, box1_x2, box1_y2 = box1 + box2_x1, box2_y1, box2_x2, box2_y2 = box2 + + box1_area = (box1_x2 - box1_x1) * (box1_y2 - box1_y1) + box2_area = (box2_x2 - box2_x1) * (box2_y2 - box2_y1) + return box1_area + box2_area - intersection(box1, box2) + +def intersection(box1: tuple[int,int,int,int], box2: tuple[int,int,int,int]) -> float: + box1_x1, box1_y1, box1_x2, box1_y2 = box1 + box2_x1, box2_y1, box2_x2, box2_y2 = box2 + x1 = max(box1_x1, box2_x1) + y1 = max(box1_y1, box2_y1) + x2 = min(box1_x2, box2_x2) + y2 = min(box1_y2, box2_y2) + return (x2 - x1) * (y2 - y1) + + +def overlap_area(box1: tuple[int,int,int,int], box2: tuple[int,int,int,int]): + box1_x1, box1_y1, box1_x2, box1_y2 = box1 + box2_x1, box2_y1, box2_x2, box2_y2 = box2 + + if box1_x2 < box2_x1 or (not (box1_y1 <= box2_y1 <= box1_y2) and not (box2_y1 <= box1_y1 <= box2_y2) and not (box1_y1 <= box2_y2 <= box1_y2) and not (box2_y1 <= box1_y2 <= box2_y2)): + return 0 + + return 1 + +def overlap_percent(box1: tuple[int,int,int,int], box2: tuple[int,int,int,int]) -> float: + if (box2_x1 - box1_x1) < 0: + area_overlaped = overlap_area(box2,box1) + else: + area_overlaped = overlap_area(box1,box2) + + if area_overlaped == 0: + return 0 + + + box1_x1, box1_y1, box1_x2, box1_y2 = box1 + box2_x1, box2_y1, box2_x2, box2_y2 = box2 def is_cuda_available(): return torch.cuda.is_available() and torch.cuda.device_count() > 0 diff --git a/.eslintrc.json b/ui/.eslintrc.json similarity index 100% rename from .eslintrc.json rename to ui/.eslintrc.json diff --git a/.prettierignore b/ui/.prettierignore similarity index 100% rename from .prettierignore rename to ui/.prettierignore diff --git a/.prettierrc.json b/ui/.prettierrc.json similarity index 100% rename from .prettierrc.json rename to ui/.prettierrc.json diff --git a/ui/build/asset-manifest.json b/ui/build/asset-manifest.json index 4820754..89c6dda 100644 --- a/ui/build/asset-manifest.json +++ b/ui/build/asset-manifest.json @@ -1,15 +1,15 @@ { "files": { "main.css": "/static/css/main.bd8e5710.css", - "main.js": "/static/js/main.a6910cdf.js", + "main.js": "/static/js/main.774d98f5.js", "static/js/787.776d5c80.chunk.js": "/static/js/787.776d5c80.chunk.js", "index.html": "/index.html", "main.bd8e5710.css.map": "/static/css/main.bd8e5710.css.map", - "main.a6910cdf.js.map": "/static/js/main.a6910cdf.js.map", + "main.774d98f5.js.map": "/static/js/main.774d98f5.js.map", "787.776d5c80.chunk.js.map": "/static/js/787.776d5c80.chunk.js.map" }, "entrypoints": [ "static/css/main.bd8e5710.css", - "static/js/main.a6910cdf.js" + "static/js/main.774d98f5.js" ] } \ No newline at end of file diff --git a/ui/build/index.html b/ui/build/index.html index 75bd852..f9d4f80 100644 --- a/ui/build/index.html +++ b/ui/build/index.html @@ -1 +1 @@ -Manga Translator Sample
\ No newline at end of file +Manga Translator Sample
\ No newline at end of file diff --git a/ui/build/static/js/main.a6910cdf.js b/ui/build/static/js/main.774d98f5.js similarity index 92% rename from ui/build/static/js/main.a6910cdf.js rename to ui/build/static/js/main.774d98f5.js index 775227a..2636d3a 100644 --- a/ui/build/static/js/main.a6910cdf.js +++ b/ui/build/static/js/main.774d98f5.js @@ -1,3 +1,3 @@ -/*! For license information please see main.a6910cdf.js.LICENSE.txt */ -!function(){"use strict";var e={110:function(e,t,n){var r=n(309),a={childContextTypes:!0,contextType:!0,contextTypes:!0,defaultProps:!0,displayName:!0,getDefaultProps:!0,getDerivedStateFromError:!0,getDerivedStateFromProps:!0,mixins:!0,propTypes:!0,type:!0},o={name:!0,length:!0,prototype:!0,caller:!0,callee:!0,arguments:!0,arity:!0},l={$$typeof:!0,compare:!0,defaultProps:!0,displayName:!0,propTypes:!0,type:!0},i={};function u(e){return r.isMemo(e)?l:i[e.$$typeof]||a}i[r.ForwardRef]={$$typeof:!0,render:!0,defaultProps:!0,displayName:!0,propTypes:!0},i[r.Memo]=l;var c=Object.defineProperty,s=Object.getOwnPropertyNames,f=Object.getOwnPropertySymbols,d=Object.getOwnPropertyDescriptor,p=Object.getPrototypeOf,h=Object.prototype;e.exports=function e(t,n,r){if("string"!==typeof n){if(h){var a=p(n);a&&a!==h&&e(t,a,r)}var l=s(n);f&&(l=l.concat(f(n)));for(var i=u(t),v=u(n),m=0;m