Skip to content

Commit

Permalink
Add hardware detection to choose between Torch and MLX
Browse files Browse the repository at this point in the history
  • Loading branch information
perryzjc committed Aug 31, 2024
1 parent 3442454 commit e234c34
Show file tree
Hide file tree
Showing 11 changed files with 233 additions and 137 deletions.
25 changes: 10 additions & 15 deletions rag/file_conversion_router/conversion/pdf_converter.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
import os
import re
import subprocess
from pathlib import Path
import os

import fitz
import re
from pix2text import Pix2Text

from rag.file_conversion_router.conversion.base_converter import BaseConverter
from rag.file_conversion_router.utils.hardware_detection import detect_gpu_setup
from rag.file_conversion_router.services.tai_nougat_service import TAINougatConfig
from rag.file_conversion_router.services.tai_nougat_service.api import convert_pdf_to_mmd
from rag.file_conversion_router.services.tai_nougat_service.nougat_config import NougatConfig


class PdfConverter(BaseConverter):
def __init__(self):
    """Set up the base converter, then record the detected device type and log it."""
    super().__init__()
    # detect_gpu_setup() returns a pair; only the first element is kept here.
    detected_device, _ = detect_gpu_setup()
    self.device_type = detected_device

    self._logger.info(f"Using {self.device_type} on Torch")

def convert_pdf_to_markdown(self, pdf_file_path, output_file_path, page_numbers=None):
"""
Expand Down Expand Up @@ -129,6 +126,7 @@ def _to_markdown(self, input_path: Path, output_path: Path) -> Path:
# Convert the PDF to Markdown using Nougat.
# self._to_markdown_using_native_nougat_cli(pdf_without_images_path, output_path)
self._to_markdown_using_tai_nougat(pdf_without_images_path, output_path)
# self._to_markdown_using_mlx_nougat(pdf_without_images_path, output_path)

# Now change the file name of generated mmd file to align with the expected md file path from base converter
output_mmd_path = output_path.with_suffix(".mmd")
Expand All @@ -139,7 +137,6 @@ def _to_markdown(self, input_path: Path, output_path: Path) -> Path:
print(output_mmd_path)
return target


def _to_markdown_using_native_nougat_cli(self, input_pdf_path: Path, output_path: Path) -> None:
"""
Perform PDF to Markdown conversion using Native Nougat CLI.
Expand Down Expand Up @@ -167,15 +164,13 @@ def _to_markdown_using_native_nougat_cli(self, input_pdf_path: Path, output_path
self._logger.error(f"An error occurred: {str(e)}")
raise

@staticmethod
def _to_markdown_using_tai_nougat(input_pdf_path: Path, output_path: Path) -> None:
"""Perform PDF to Markdown conversion using TAI Nougat.
def _to_markdown_using_tai_nougat(self, input_pdf_path: Path, output_path: Path) -> None:
"""
Perform PDF to Markdown conversion using TAI Nougat.
TAI nougat is our custom implementation of the Nougat API, with better abstraction
and especially optimization on avoiding loading nougat model repetitively.
TAI nougat is our custom implementation of the Nougat API, with better performance and abstraction.
"""
config = NougatConfig(
config = TAINougatConfig(
pdf_paths=[input_pdf_path],
output_dir=output_path.parent,
)
Expand Down
14 changes: 14 additions & 0 deletions rag/file_conversion_router/services/tai_nougat_service/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# TAI Nougat Service

The TAI Nougat Service is an optimized and refined version of the original [Meta Nougat](https://github.com/facebookresearch/nougat) service, developed to enhance performance and maintainability.

## Key Enhancements

- **Performance Optimization**
- Implemented **Dependency Injection** to reduce repeated model loading, resulting in significant time savings.
- Testing on developer laptops within our team shows a reduction of 4 to 8 seconds per PDF, depending on the specific device.
- **Better Support for Apple Silicon**. The service detects the hardware it runs on and can use MLX instead of Torch on Apple Silicon, which significantly reduces conversion time on those devices.
- In @perryzjc's testing, an M1 Pro MacBook Pro takes less than 30 seconds to convert a 10-page PDF, compared with the original 10 minutes on @perryzjc's device.

- **Code Refactoring**
- Improved the readability and maintainability of the codebase to facilitate easier future development and collaboration.
127 changes: 13 additions & 114 deletions rag/file_conversion_router/services/tai_nougat_service/__init__.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,18 @@
import logging
import re
from pathlib import Path
from typing import List
from .config_nougat.tai_nougat_config import TAINougatConfig
from .mlx_nougat_service import api as mlx_nougat_api
from .torch_nougat_service import api as torch_nougat_api

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from nougat import NougatModel
from nougat.postprocessing import markdown_compatible
from nougat.utils.checkpoint import get_checkpoint
from nougat.utils.dataset import LazyDataset
from nougat.utils.device import move_to_device
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from .nougat_config import NougatConfig

logging.basicConfig(level=logging.INFO)


def create_model(config: NougatConfig) -> NougatModel:
    """Load the Nougat model described by ``config`` and prepare it for inference.

    If ``config.checkpoint`` is unset or does not exist on disk, it is resolved
    through ``get_checkpoint`` using ``config.model_tag`` and written back onto
    ``config`` (the passed-in object is mutated).

    Args:
        config: Settings read here: ``checkpoint``, ``model_tag``,
            ``full_precision`` and ``batch_size``.

    Returns:
        The model returned by ``move_to_device``, switched to eval mode.
    """
    if config.checkpoint is None or not config.checkpoint.exists():
        config.checkpoint = get_checkpoint(config.checkpoint, model_tag=config.model_tag)

    model = NougatModel.from_pretrained(config.checkpoint)
    # bf16 unless full precision was requested; CUDA is requested only for a positive batch size.
    model = move_to_device(model, bf16=not config.full_precision, cuda=config.batch_size > 0)
    model.eval()
    # Report through logging (as the rest of this module does) instead of a debug print.
    logging.info("model loaded")
    return model


class NougatContainer(containers.DeclarativeContainer):
    """Dependency-injection container exposing the Nougat model as a singleton provider."""

    # The provider is bound to a default-constructed NougatConfig, created when
    # this class body is executed.
    model = providers.Singleton(create_model, config=NougatConfig())


def load_datasets(config: NougatConfig, model: NougatModel) -> List[LazyDataset]:
    """Build one ``LazyDataset`` per PDF in ``config.pdf_paths`` that should be converted.

    A PDF is skipped (and the skip is logged) when:
      * the file does not exist;
      * ``config.output_dir`` is set, the matching ``.mmd`` file already exists
        there, and ``config.recompute`` is false;
      * constructing the ``LazyDataset`` raises.

    Args:
        config: Settings read here: ``pdf_paths``, ``output_dir``, ``recompute``
            and ``pages``.
        model: Model whose ``encoder.prepare_input`` is handed to each dataset.

    Returns:
        The datasets that were created, in the order of ``config.pdf_paths``.
    """
    datasets = []
    for pdf in config.pdf_paths:
        if not pdf.exists():
            # Report missing inputs instead of dropping them silently.
            logging.warning(f"Skipping {pdf}, file does not exist.")
            continue
        if config.output_dir:
            out_path = config.output_dir / pdf.with_suffix(".mmd").name
            if out_path.exists() and not config.recompute:
                logging.info(f"Skipping {pdf.name}, already computed. Run with recompute=True to convert again.")
                continue
        try:
            dataset = LazyDataset(
                pdf,
                model.encoder.prepare_input,
                config.pages,
            )
        except Exception as e:
            # Deliberately best-effort: one unreadable PDF must not abort the
            # whole batch, but the failure is surfaced as a warning.
            logging.warning(f"Could not load file {str(pdf)}: {e}")
            continue
        datasets.append(dataset)
    return datasets


def process_output(output: str, page_num: int, config: NougatConfig) -> str:
    """Post-process the predicted text of a single page.

    Args:
        output: Raw prediction for the page.
        page_num: Page number embedded in the placeholder for a missing page.
        config: Only ``markdown_compatible`` is read.

    Returns:
        A ``[MISSING_PAGE_EMPTY:<page_num>]`` placeholder when the prediction is
        the ``[MISSING_PAGE_POST]`` marker; otherwise the text, passed through
        ``markdown_compatible`` when the config asks for it.
    """
    is_missing_page = output.strip() == "[MISSING_PAGE_POST]"
    if is_missing_page:
        return f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n"
    return markdown_compatible(output) if config.markdown_compatible else output


@inject
def main(config: NougatConfig, model: NougatModel = Provide[NougatContainer.model]):
    """Run Nougat inference over the PDFs in ``config`` and emit one document per PDF.

    Pages are predicted batch by batch. The processed text of each page is
    accumulated until a truthy ``is_last_page`` entry is seen; the accumulated
    pages are then joined, runs of three or more newlines are collapsed to two,
    and the result is written to ``config.output_dir`` as ``<name>.mmd`` (or
    printed when no output directory is configured).

    Args:
        config: Settings read here: ``batch_size``, ``skipping``, ``output_dir``,
            plus whatever ``load_datasets`` and ``process_output`` read.
        model: Nougat model; injected from ``NougatContainer.model`` by default.
    """
    datasets = load_datasets(config, model)
    if not datasets:
        logging.info("No valid datasets found.")
        return

    loader = DataLoader(
        ConcatDataset(datasets),
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )

    page_texts = []
    file_index = 0
    page_num = 0

    for sample, is_last_page in tqdm(loader):
        batch_result = model.inference(image_tensors=sample, early_stopping=config.skipping)

        for idx, raw_page in enumerate(batch_result["predictions"]):
            if page_num == 0:
                # First page of a new file: announce it once.
                logging.info(f"Processing file {datasets[file_index].name} with {datasets[file_index].size} pages")
            page_num += 1
            page_texts.append(process_output(raw_page, page_num, config))

            if not is_last_page[idx]:
                continue

            # A truthy entry marks the end of a file; it is also used as the path
            # from which the output file name is derived.
            document = "".join(page_texts).strip()
            document = re.sub(r"\n{3,}", "\n\n", document).strip()
            if config.output_dir:
                target = config.output_dir / Path(is_last_page[idx]).with_suffix(".mmd").name
                target.parent.mkdir(parents=True, exist_ok=True)
                target.write_text(document, encoding="utf-8")
            else:
                print(document, "\n\n")

            # Reset the per-file state and move on to the next dataset.
            page_texts = []
            page_num = 0
            file_index += 1


def run_nougat(config: NougatConfig):
def run_nougat(config: TAINougatConfig):
"""Run Nougat with the provided configuration.
model is initialized using config only on the first time run_nougat is called .
"""
if not hasattr(run_nougat, "container"):
run_nougat.container = NougatContainer()
run_nougat.container.wire(modules=[__name__])
main(config)
input_path = config.pdf_paths[0]
output_dir = config.output_dir

using_torch = config.using_torch

if using_torch:
torch_nougat_api.convert_pdf_to_mmd(input_path, output_dir)
else:
mlx_nougat_api.convert_pdf_to_mmd(input_path, output_dir)
9 changes: 3 additions & 6 deletions rag/file_conversion_router/services/tai_nougat_service/api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import logging

from .__init__ import run_nougat
from .nougat_config import NougatConfig
from .config_nougat.tai_nougat_config import TAINougatConfig

logging.basicConfig(level=logging.INFO)


def convert_pdf_to_mmd(config: NougatConfig) -> None:
def convert_pdf_to_mmd(config: TAINougatConfig) -> None:
"""Converts a PDF file to MMD format using TAI Nougat.
"""
logging.info(
"Initialized NougatConfig",
extra={"config": config}
)
logging.info(f"Initialized NougatConfig. Config: {config}")

try:
logging.info("Executing conversion process.")
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Literal

from nougat.utils.device import default_batch_size

Expand All @@ -9,7 +9,7 @@
class NougatConfig:
batch_size: int = default_batch_size()
checkpoint: Optional[Path] = None
model_tag: str = "0.1.0-base"
model_tag: Literal["0.1.0-base", "1.0.0-small"] = "0.1.0-base"
output_dir: Optional[Path] = None
recompute: bool = True
full_precision: bool = False
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Compared to the default NougatConfig, TAINougatConfig includes an additional configurable field, using_torch.
Currently, other TAI Nougat functionalities, such as Dependency Injection, are not configurable.
"""
import logging
from dataclasses import dataclass, field
from typing import Callable, Union

from .nougat_config import NougatConfig

from rag.file_conversion_router.utils.hardware_detection import detect_is_apple_silicon

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def determine_using_torch() -> bool:
    """Pick the default backend from the detected hardware.

    Returns:
        ``False`` when Apple Silicon is detected, ``True`` otherwise.
    """
    on_apple_silicon = detect_is_apple_silicon()
    logger.info(f"Apple Silicon detected: {on_apple_silicon}")
    # Torch is the default everywhere except on Apple Silicon.
    if on_apple_silicon:
        return False
    return True


@dataclass
class TAINougatConfig(NougatConfig):
    """NougatConfig extended with a ``using_torch`` backend switch.

    ``using_torch`` may be given as a boolean or as a zero-argument callable
    returning one. A callable is invoked once in ``__post_init__`` and replaced
    by its result. When omitted, the value comes from ``determine_using_torch``.
    """

    # Either a boolean, or a callable that yields the boolean.
    using_torch: Union[bool, Callable[[], bool]] = field(default_factory=determine_using_torch)

    def __post_init__(self):
        """Resolve a callable ``using_torch`` to its value and log the chosen backend."""
        backend_choice = self.using_torch
        if callable(backend_choice):
            backend_choice = backend_choice()
            self.using_torch = backend_choice

        message = (
            "Using Torch for Nougat PDF conversion."
            if backend_choice
            else "Using MLX for Nougat PDF conversion for Apple Silicon for better performance."
        )
        logger.info(message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Torch Nougat Service

We have migrated code from the Meta Nougat repository to enable more flexible development of Nougat services, including support for **Dependency Injection**.

The original Meta Nougat is built on Torch.
Loading

0 comments on commit e234c34

Please sign in to comment.