Skip to content

Commit

Permalink
add docling-models utility
Browse files Browse the repository at this point in the history
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
  • Loading branch information
dolfim-ibm committed Feb 4, 2025
1 parent 18aad34 commit dc9e759
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
for file in docs/examples/*.py; do
# Skip the examples listed in the pattern below
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment).py ]]; then
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|offline_convert).py ]]; then
echo "Skipping $file"
continue
fi
Expand Down
160 changes: 160 additions & 0 deletions docling/cli/models_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import logging
import warnings
from pathlib import Path
from typing import Annotated

import typer

from docling.datamodel.settings import settings
from docling.models.code_formula_model import CodeFormulaModel
from docling.models.document_picture_classifier import DocumentPictureClassifier
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.table_structure_model import TableStructureModel

# Silence noisy warnings: UserWarning for modules matching "pydantic|torch"
# and FutureWarning for modules matching "easyocr".
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")

# Module-level logger named after this module.
_log = logging.getLogger(__name__)
from rich.console import Console

# Rich consoles: the default one and one bound to stderr.
console = Console()
err_console = Console(stderr=True)


# Typer application object; the commands defined in this module are
# registered on it. Completion options and Typer's pretty exception
# rendering are both switched off.
app = typer.Typer(
    name="Docling model helper",
    add_completion=False,
    pretty_exceptions_enable=False,
)


# Entry point of the `download` sub-command.
#
# NOTE: documented with comments rather than a docstring on purpose: Typer
# shows a command's docstring as its `--help` text, so adding one would change
# the CLI output.
#
# Behaviour, in order:
#   1. create `output_dir` (including parents) if it does not exist;
#   2. for each enabled model family, call its download helper, targeting the
#      sub-folder of `output_dir` named by the model's `_model_repo_folder`;
#   3. print either only the directory (`-q`) or a summary with usage hints.
#
# NOTE(review): the `rapidocr` option is accepted but this function performs
# no RapidOCR download — confirm whether that is intentional.
# NOTE(review): the parameter is spelled `quite` (presumably "quiet"); it is
# kept unchanged because renaming it would alter the function's interface.
@app.command("download")
def download(
    output_dir: Annotated[
        Path,
        typer.Option(
            ...,
            "-o",
            "--output-dir",
            help="The directory where all the models are downloaded.",
        ),
    ] = settings.cache_dir / "models",
    force: Annotated[
        bool, typer.Option(..., help="If true, the download will be forced")
    ] = False,
    quite: Annotated[
        bool,
        typer.Option(
            ...,
            "-q",
            help="No extra output is generated, the CLI print only the directory with the cached models.",
        ),
    ] = False,
    layout: Annotated[
        bool,
        typer.Option(..., help="If true, the layout model weights are downloaded."),
    ] = True,
    tableformer: Annotated[
        bool,
        typer.Option(
            ..., help="If true, the tableformer model weights are downloaded."
        ),
    ] = True,
    code_formula: Annotated[
        bool,
        typer.Option(
            ..., help="If true, the code formula model weights are downloaded."
        ),
    ] = True,
    picture_classifier: Annotated[
        bool,
        typer.Option(
            ..., help="If true, the picture classifier model weights are downloaded."
        ),
    ] = True,
    easyocr: Annotated[
        bool,
        typer.Option(..., help="If true, the easyocr model weights are downloaded."),
    ] = True,
    rapidocr: Annotated[
        bool,
        typer.Option(..., help="If true, the rapidocr model weights are downloaded."),
    ] = True,
):
    # Make sure the folder exists
    output_dir.mkdir(exist_ok=True, parents=True)

    # Progress bars are only shown when the user did not ask for quiet output.
    show_progress = not quite

    if layout:
        if not quite:
            typer.secho("Downloading layout model...", fg="blue")
        LayoutModel.download_models_hf(
            local_dir=output_dir / LayoutModel._model_repo_folder,
            force=force,
            progress=show_progress,
        )

    if tableformer:
        if not quite:
            typer.secho("Downloading tableformer model...", fg="blue")
        TableStructureModel.download_models_hf(
            local_dir=output_dir / TableStructureModel._model_repo_folder,
            force=force,
            progress=show_progress,
        )

    if picture_classifier:
        if not quite:
            typer.secho("Downloading picture classifier model...", fg="blue")
        DocumentPictureClassifier.download_models_hf(
            local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
            force=force,
            progress=show_progress,
        )

    if code_formula:
        if not quite:
            typer.secho("Downloading code formula model...", fg="blue")
        CodeFormulaModel.download_models_hf(
            local_dir=output_dir / CodeFormulaModel._model_repo_folder,
            force=force,
            progress=show_progress,
        )

    if easyocr:
        if not quite:
            typer.secho("Downloading easyocr models...", fg="blue")
        EasyOcrModel.download_models(
            local_dir=output_dir / EasyOcrModel._model_repo_folder,
            force=force,
            progress=show_progress,
        )

    if quite:
        # Quiet mode: emit only the target directory.
        typer.echo(output_dir)
    else:
        typer.secho(f"All models downloaded in the directory {output_dir}.", fg="green")

        console.print(
            "\n",
            "Docling can now be configured for running offline using the local artifacts.\n\n",
            "Using the CLI:",
            # Must be an f-string, otherwise the literal text "{output_dir}"
            # is printed instead of the actual download location.
            f"`docling --artifacts-path={output_dir} FILE`",
            "\n",
            "Using Python: see the documentation at <https://ds4sd.github.io/docling/usage>.",
        )


# Hidden placeholder command: it has no implementation and always raises.
# NOTE(review): presumably registered so that Typer treats the app as
# multi-command and keeps `download` as an explicit sub-command — confirm.
# (Kept as comments: a docstring would become the command's help text.)
@app.command(hidden=True)
def other():
    raise NotImplementedError


# Click command object built from the Typer application above.
click_app = typer.main.get_command(app)

# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    app()
7 changes: 5 additions & 2 deletions docling/models/code_formula_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,15 @@ def __init__(

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/CodeFormula",
force_download=force,
Expand Down
5 changes: 3 additions & 2 deletions docling/models/document_picture_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,13 @@ def __init__(

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
Expand Down
42 changes: 41 additions & 1 deletion docling/models/easyocr_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
import warnings
from typing import Iterable
import zipfile
from pathlib import Path
from typing import Iterable, List, Optional

import httpx
import numpy
import torch
from docling_core.types.doc import BoundingBox, CoordOrigin
Expand All @@ -17,11 +20,14 @@
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
from docling.utils.utils import download_url_with_progress

_log = logging.getLogger(__name__)


class EasyOcrModel(BaseOcrModel):
_model_repo_folder = "EasyOcr"

def __init__(
self,
enabled: bool,
Expand Down Expand Up @@ -71,6 +77,40 @@ def __init__(
verbose=False,
)

@staticmethod
def download_models(
    detection_models: Optional[List[str]] = None,
    recognition_models: Optional[List[str]] = None,
    local_dir: Optional[Path] = None,
    force: bool = False,
    progress: bool = False,
) -> Path:
    """Download EasyOCR model archives and unpack them into ``local_dir``.

    Args:
        detection_models: Keys of ``easyocr.config.detection_models`` to
            fetch. Defaults to ``["craft"]`` when ``None``.
        recognition_models: Keys of the ``"gen2"`` section of
            ``easyocr.config.recognition_models`` to fetch. Defaults to
            ``["english_g2", "latin_g2"]`` when ``None``.
        local_dir: Destination directory. When ``None``, the ``models``
            folder under ``settings.cache_dir`` is used, in a sub-folder
            named by ``EasyOcrModel._model_repo_folder``.
        force: Accepted for signature parity with the other download
            helpers. NOTE(review): it is not read here — every call
            downloads and extracts all selected archives again.
        progress: Forwarded to ``download_url_with_progress``.

    Returns:
        The directory the archives were extracted into.
    """
    # Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
    from easyocr.config import detection_models as det_models_dict
    from easyocr.config import recognition_models as rec_models_dict

    # Defaults are materialised here instead of in the signature so that no
    # mutable list object is shared between calls.
    if detection_models is None:
        detection_models = ["craft"]
    if recognition_models is None:
        recognition_models = ["english_g2", "latin_g2"]

    if local_dir is None:
        local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder

    local_dir.mkdir(parents=True, exist_ok=True)

    # Collect models to download. Unknown names are still skipped, but the
    # skip is now logged so a typo does not go unnoticed.
    download_list = []
    for model_name in detection_models:
        if model_name in det_models_dict:
            download_list.append(det_models_dict[model_name])
        else:
            _log.warning(
                "Unknown EasyOCR detection model '%s', skipping.", model_name
            )
    for model_name in recognition_models:
        if model_name in rec_models_dict["gen2"]:
            download_list.append(rec_models_dict["gen2"][model_name])
        else:
            _log.warning(
                "Unknown EasyOCR recognition model '%s', skipping.", model_name
            )

    # Download models: each URL is fetched into memory and unpacked as a zip
    # archive directly into the destination directory.
    for model_details in download_list:
        buf = download_url_with_progress(model_details["url"], progress=progress)
        with zipfile.ZipFile(buf, "r") as zip_ref:
            zip_ref.extractall(local_dir)

    return local_dir

def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
Expand Down
7 changes: 5 additions & 2 deletions docling/models/layout_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,15 @@ def __init__(

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
Expand Down
5 changes: 3 additions & 2 deletions docling/models/table_structure_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ def __init__(

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
Expand Down
24 changes: 24 additions & 0 deletions docling/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from pathlib import Path
from typing import List, Union

import requests
from tqdm import tqdm


def chunkify(iterator, chunk_size):
"""Yield successive chunks of chunk_size from the iterable."""
Expand Down Expand Up @@ -39,3 +42,24 @@ def create_hash(string: str):
hasher.update(string.encode("utf-8"))

return hasher.hexdigest()


def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
    """Download ``url`` into memory and return the content as a buffer.

    The response is streamed in 10 KiB chunks into a ``BytesIO``; the buffer
    is rewound to position 0 before it is returned.

    Args:
        url: Address to fetch; redirects are followed.
        progress: When true, a ``tqdm`` progress bar is displayed. Its total
            comes from the ``content-length`` header (0 when absent).

    Returns:
        A ``BytesIO`` positioned at the start of the downloaded payload.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    buf = BytesIO()
    with requests.get(url, stream=True, allow_redirects=True) as response:
        # Fail fast on HTTP errors; otherwise the body of an error page would
        # be handed back to the caller as if it were the requested file.
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        # The context manager closes the bar even if streaming raises.
        with tqdm(
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            disable=(not progress),
        ) as progress_bar:
            for chunk in response.iter_content(10 * 1024):
                buf.write(chunk)
                progress_bar.update(len(chunk))

    buf.seek(0)
    return buf
19 changes: 19 additions & 0 deletions docs/examples/offline_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import EasyOcrOptions, PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

# The location of the local artifacts, e.g. from the `docling-models download` command
artifacts_path = Path("PATH TO MODELS")  # <-- fill me
# Pipeline options that read the model artifacts from the local path.
pipeline_options = PdfPipelineOptions(artifacts_path=artifacts_path)
# Configure EasyOCR to use the `EasyOcr` sub-folder of the artifacts directory
# and switch off its own downloads.
pipeline_options.ocr_options = EasyOcrOptions(
    download_enabled=False, model_storage_directory=str(artifacts_path / "EasyOcr")
)

# Converter that applies the pipeline options above to PDF inputs.
doc_converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)

# Convert the document and print its Markdown export.
result = doc_converter.convert("FILE TO CONVERT")  # <-- fill me
print(result.document.export_to_markdown())
Loading

0 comments on commit dc9e759

Please sign in to comment.