-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add hardware detection to choose between Torch and MLX
- Loading branch information
Showing
11 changed files
with
233 additions
and
137 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
14 changes: 14 additions & 0 deletions
14
rag/file_conversion_router/services/tai_nougat_service/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# TAI Nougat Service | ||
|
||
The TAI Nougat Service is an optimized and refined version of the original [Meta Nougat](https://github.com/facebookresearch/nougat) service, developed to enhance performance and maintainability. | ||
|
||
## Key Enhancements | ||
|
||
- **Performance Optimization** | ||
- Implemented **Dependency Injection** to reduce repeated model loading, resulting in significant time savings. | ||
- Testing on developer laptops within our team shows a reduction of 4 to 8 seconds per PDF, depending on the specific device. | ||
- **Better Support for Apple Chips**. It supports hardware detection and can use MLX instead of Torch on Apple chips. This significantly reduces conversion time on Apple-chip devices. | ||
- In @perryzjc's testing, an M1 Pro MacBook Pro takes less than 30 seconds to convert a 10-page PDF, compared to the original 10 minutes on the same device. | ||
|
||
- **Code Refactoring** | ||
- Improved the readability and maintainability of the codebase to facilitate easier future development and collaboration. |
127 changes: 13 additions & 114 deletions
127
rag/file_conversion_router/services/tai_nougat_service/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,119 +1,18 @@ | ||
import logging | ||
import re | ||
from pathlib import Path | ||
from typing import List | ||
from .config_nougat.tai_nougat_config import TAINougatConfig | ||
from .mlx_nougat_service import api as mlx_nougat_api | ||
from .torch_nougat_service import api as torch_nougat_api | ||
|
||
from dependency_injector import containers, providers | ||
from dependency_injector.wiring import Provide, inject | ||
from nougat import NougatModel | ||
from nougat.postprocessing import markdown_compatible | ||
from nougat.utils.checkpoint import get_checkpoint | ||
from nougat.utils.dataset import LazyDataset | ||
from nougat.utils.device import move_to_device | ||
from torch.utils.data import ConcatDataset, DataLoader | ||
from tqdm import tqdm | ||
|
||
from .nougat_config import NougatConfig | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
def create_model(config: NougatConfig) -> NougatModel:
    """Build a Nougat model from ``config`` and prepare it for inference.

    If ``config.checkpoint`` is ``None`` or does not exist on disk, it is
    overwritten in place with the value ``get_checkpoint`` returns for
    ``config.model_tag``. The model is loaded from that checkpoint, passed
    through ``move_to_device`` (bf16 unless ``config.full_precision``; the
    ``cuda`` flag is set when ``config.batch_size`` is positive) and put
    into eval mode.

    Args:
        config: Nougat settings; ``checkpoint`` may be mutated as described.

    Returns:
        The loaded model, in eval mode.
    """
    checkpoint = config.checkpoint
    checkpoint_usable = checkpoint is not None and checkpoint.exists()
    if not checkpoint_usable:
        config.checkpoint = get_checkpoint(checkpoint, model_tag=config.model_tag)

    loaded = NougatModel.from_pretrained(config.checkpoint)
    loaded = move_to_device(
        loaded,
        bf16=not config.full_precision,
        cuda=config.batch_size > 0,
    )
    loaded.eval()
    print("model loaded")  # debug
    return loaded
|
||
|
||
class NougatContainer(containers.DeclarativeContainer):
    """Dependency-injection container that provides the Nougat model.

    ``model`` is a ``Singleton`` provider wrapping ``create_model`` with a
    default-constructed ``NougatConfig``.
    """

    model = providers.Singleton(create_model, config=NougatConfig())
|
||
|
||
def load_datasets(config: NougatConfig, model: NougatModel) -> List[LazyDataset]:
    """Create one ``LazyDataset`` per PDF in ``config.pdf_paths`` that needs converting.

    A PDF is skipped when:
      * its path does not exist;
      * ``config.output_dir`` is set, the matching ``.mmd`` file is already
        there and ``config.recompute`` is false (an info message is logged);
      * constructing the dataset raises (the error is logged at info level).

    Args:
        config: Supplies ``pdf_paths``, ``output_dir``, ``recompute`` and ``pages``.
        model: Its ``encoder.prepare_input`` is handed to each dataset.

    Returns:
        The datasets that were created, in input order (possibly empty).
    """
    loaded = []
    for pdf in config.pdf_paths:
        if not pdf.exists():
            continue

        if config.output_dir:
            target = config.output_dir / pdf.with_suffix(".mmd").name
            already_converted = target.exists() and not config.recompute
            if already_converted:
                logging.info(f"Skipping {pdf.name}, already computed. Run with recompute=True to convert again.")
                continue

        # Best effort: one unreadable PDF must not abort the whole batch.
        try:
            dataset = LazyDataset(pdf, model.encoder.prepare_input, config.pages)
        except Exception as e:
            logging.info(f"Could not load file {str(pdf)}: {e}")
        else:
            loaded.append(dataset)
    return loaded
|
||
|
||
def process_output(output: str, page_num: int, config: NougatConfig) -> str:
    """Post-process the raw model prediction for one page.

    Args:
        output: Text predicted for the page.
        page_num: Page number used in the placeholder for a missing page.
        config: ``markdown_compatible`` decides whether the text is passed
            through ``markdown_compatible``.

    Returns:
        A ``[MISSING_PAGE_EMPTY:<page_num>]`` placeholder when the stripped
        prediction equals ``[MISSING_PAGE_POST]``; otherwise the prediction,
        optionally made markdown compatible.
    """
    is_missing_page = output.strip() == "[MISSING_PAGE_POST]"
    if is_missing_page:
        return f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n"
    return markdown_compatible(output) if config.markdown_compatible else output
|
||
|
||
@inject
def main(config: NougatConfig, model: NougatModel = Provide[NougatContainer.model]):
    """Run inference over every loadable PDF in ``config`` and emit one document per file.

    Pages are batched through a ``DataLoader``. Predictions are accumulated
    until the page flagged as a file's last one is reached; the joined text
    is then written to ``<output_dir>/<name>.mmd``, or printed when
    ``config.output_dir`` is not set.

    Args:
        config: Supplies ``batch_size``, ``skipping``, ``output_dir`` and the
            fields read by ``load_datasets`` / ``process_output``.
        model: Nougat model; by default injected from ``NougatContainer.model``.
    """
    datasets = load_datasets(config, model)
    if not datasets:
        logging.info("No valid datasets found.")
        return

    loader = DataLoader(
        ConcatDataset(datasets),
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )

    page_texts = []
    file_index = 0
    page_num = 0

    for sample, is_last_page in tqdm(loader):
        batch_result = model.inference(image_tensors=sample, early_stopping=config.skipping)

        for j, output in enumerate(batch_result["predictions"]):
            # page_num is reset to 0 after each file, so this fires once per file.
            if page_num == 0:
                logging.info(f"Processing file {datasets[file_index].name} with {datasets[file_index].size} pages")
            page_num += 1
            page_texts.append(process_output(output, page_num, config))

            # Truthy only on a file's final page; the value is also used as the
            # basis of the output file name (presumably the source PDF path).
            last_page_marker = is_last_page[j]
            if not last_page_marker:
                continue

            document = "".join(page_texts).strip()
            # Collapse runs of three or more newlines into a single blank line.
            document = re.sub(r"\n{3,}", "\n\n", document).strip()
            if config.output_dir:
                destination = config.output_dir / Path(last_page_marker).with_suffix(".mmd").name
                destination.parent.mkdir(parents=True, exist_ok=True)
                destination.write_text(document, encoding="utf-8")
            else:
                print(document, "\n\n")

            # Start accumulating the next file.
            page_texts = []
            page_num = 0
            file_index += 1
|
||
|
||
def run_nougat(config: TAINougatConfig):
    """Convert the PDFs listed in ``config`` to ``.mmd`` with the Torch or MLX backend.

    Every path in ``config.pdf_paths`` is passed, together with
    ``config.output_dir``, to ``convert_pdf_to_mmd`` of the selected backend:
    ``torch_nougat_api`` when ``config.using_torch`` is truthy,
    ``mlx_nougat_api`` otherwise.

    Args:
        config: Supplies ``pdf_paths``, ``output_dir`` and ``using_torch``.

    Returns:
        None. An empty ``config.pdf_paths`` is a no-op.
    """
    pdf_paths = list(config.pdf_paths)
    if not pdf_paths:
        # Nothing to convert; indexing the first element here would raise IndexError.
        return

    backend = torch_nougat_api if config.using_torch else mlx_nougat_api
    # Convert every requested PDF, not just the first one.
    for input_path in pdf_paths:
        backend.convert_pdf_to_mmd(input_path, config.output_dir)
9 changes: 3 additions & 6 deletions
9
rag/file_conversion_router/services/tai_nougat_service/api.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
rag/file_conversion_router/services/tai_nougat_service/config_nougat/tai_nougat_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Compared to the default NougatConfig, TAINougatConfig includes an additional configurable field, using_torch. | ||
Currently, other TAI Nougat functionalities, such as Dependency Injection, are not configurable. | ||
""" | ||
import logging | ||
from dataclasses import dataclass, field | ||
from typing import Callable, Union | ||
|
||
from .nougat_config import NougatConfig | ||
|
||
from rag.file_conversion_router.utils.hardware_detection import detect_is_apple_silicon | ||
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def determine_using_torch() -> bool:
    """Pick the default backend flag from the detected hardware.

    Logs the result of ``detect_is_apple_silicon()`` at info level.

    Returns:
        ``False`` when Apple Silicon is detected, ``True`` otherwise.
    """
    is_apple_silicon = detect_is_apple_silicon()
    logger.info(f"Apple Silicon detected: {is_apple_silicon}")
    if is_apple_silicon:
        # Torch is not the default on Apple Silicon.
        return False
    return True
|
||
|
||
@dataclass
class TAINougatConfig(NougatConfig):
    """``NougatConfig`` extended with a ``using_torch`` backend switch.

    ``using_torch`` may be a bool or a zero-argument callable. A callable is
    invoked once in ``__post_init__`` and replaced by its return value. When
    the field is omitted it defaults to ``determine_using_torch()``.
    """

    # Backend flag, or a zero-argument callable that produces it.
    using_torch: Union[bool, Callable[[], bool]] = field(default_factory=determine_using_torch)

    def __post_init__(self):
        """Resolve a callable ``using_torch`` and log which backend was chosen."""
        backend_choice = self.using_torch
        if callable(backend_choice):
            backend_choice = backend_choice()
            self.using_torch = backend_choice

        message = (
            "Using Torch for Nougat PDF conversion."
            if backend_choice
            else "Using MLX for Nougat PDF conversion for Apple Silicon for better performance."
        )
        logger.info(message)
5 changes: 5 additions & 0 deletions
5
...le_conversion_router/services/tai_nougat_service/torch_nougat_service/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Torch Nougat Service | ||
|
||
We have migrated code from the Meta Nougat repository to enable more flexible development of Nougat services, including support for **Dependency Injection**. | ||
|
||
The original Meta Nougat is built on Torch. |
Oops, something went wrong.