feat: new vlm-models support (#1570)

* feat: adding new vlm-models support

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* got microsoft/Phi-4-multimodal-instruct to work

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* working on VLMs

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring the VLM part

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* all working, now serious refactoring necessary

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring the download_model

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the formulate_prompt

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* pixtral 12b runs via MLX and native transformers

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the VlmPredictionToken

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* refactoring minimal_vlm_pipeline

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the MyPy

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added pipeline_model_specializations file

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* need to get Phi4 working again ...

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* finalising last points for vlms support

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the pipeline for Phi4

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* streamlining all code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* reformatted the code

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixing the tests

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* added the html backend to the VLM pipeline

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* fixed the static load_from_doctags

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

* restore stable imports

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use AutoModelForVision2Seq for Pixtral and review example (including rename)

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove unused value

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* refactor instances of VLM models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* skip compare example in CI

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use lowercase and uppercase only

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add new minimal_vlm example and refactor pipeline_options_vlm_model for cleaner import

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename pipeline_vlm_model_spec

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* move more arguments to options and simplify model init

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add supported_devices

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove not-needed function

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* exclude minimal_vlm

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* missing file

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add message for transformers version

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* rename to specs

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use module import and remove MLX from non-darwin

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove hf_vlm_model and add extra_generation_args

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* use single HF VLM model class

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove torch type

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* add docs for vision models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Co-authored-by: Michele Dolfi <dol@zurich.ibm.com>
Peter W. J. Staar 2025-06-02 17:01:06 +02:00 committed by GitHub
parent 08dcacc5cb
commit cfdf4cea25
46 changed files with 1968 additions and 1902 deletions

View File

@ -51,7 +51,7 @@ jobs:
run: |
for file in docs/examples/*.py; do
# Skip batch_convert.py
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
if [[ "$(basename "$file")" =~ ^(batch_convert|compare_vlm_models|minimal|minimal_vlm_pipeline|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api|vlm_pipeline_api_model).py ]]; then
echo "Skipping $file"
continue
fi

View File

@ -36,7 +36,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
* 🔒 Local execution capabilities for sensitive data and air-gapped environments
* 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
* 🔍 Extensive OCR support for scanned PDFs and images
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕
* 🥚 Support of several Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview))
* 💻 Simple and convenient CLI
### Coming soon

View File

@ -28,6 +28,7 @@ from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBacke
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import (
ConversionStatus,
FormatToExtensions,
@ -36,8 +37,6 @@ from docling.datamodel.base_models import (
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
PaginatedPipelineOptions,
@ -45,14 +44,16 @@ from docling.datamodel.pipeline_options import (
PdfPipeline,
PdfPipelineOptions,
TableFormerMode,
VlmModelType,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
granite_vision_vlm_ollama_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.datamodel.vlm_model_specs import (
GRANITE_VISION_OLLAMA,
GRANITE_VISION_TRANSFORMERS,
SMOLDOCLING_MLX,
SMOLDOCLING_TRANSFORMERS,
VlmModelType,
)
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.models.factories import get_ocr_factory
from docling.pipeline.vlm_pipeline import VlmPipeline
@ -579,20 +580,16 @@ def convert( # noqa: C901
)
if vlm_model == VlmModelType.GRANITE_VISION:
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
pipeline_options.vlm_options = GRANITE_VISION_TRANSFORMERS
elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
pipeline_options.vlm_options = (
granite_vision_vlm_ollama_conversion_options
)
pipeline_options.vlm_options = GRANITE_VISION_OLLAMA
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
pipeline_options.vlm_options = SMOLDOCLING_TRANSFORMERS
if sys.platform == "darwin":
try:
import mlx_vlm
pipeline_options.vlm_options = (
smoldocling_vlm_mlx_conversion_options
)
pipeline_options.vlm_options = SMOLDOCLING_MLX
except ImportError:
_log.warning(
"To run SmolDocling faster, please install mlx-vlm:\n"

View File

@ -0,0 +1,68 @@
import logging
import os
import re
from enum import Enum
from typing import Any, Union
from pydantic import field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
_log = logging.getLogger(__name__)
class AcceleratorDevice(str, Enum):
"""Devices to run model inference"""
AUTO = "auto"
CPU = "cpu"
CUDA = "cuda"
MPS = "mps"
class AcceleratorOptions(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
)
num_threads: int = 4
device: Union[str, AcceleratorDevice] = "auto"
cuda_use_flash_attention2: bool = False
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value
):
return value
raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
)
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the envvar value for that parameter.

"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS")
if docling_num_threads is None and omp_num_threads is not None:
try:
data["num_threads"] = int(omp_num_threads)
except ValueError:
_log.error(
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
omp_num_threads,
)
return data
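
Usage sketch (illustrative, not part of this commit): AcceleratorOptions is a pydantic BaseSettings with env_prefix "DOCLING_", so values can be passed explicitly or picked up from the environment; the env variable names below follow that prefix convention.

from docling.datamodel.accelerator_options import AcceleratorOptions

# Explicit configuration, including an indexed CUDA device accepted by validate_device.
opts = AcceleratorOptions(num_threads=8, device="cuda:1")

# Environment-driven configuration, e.g. DOCLING_DEVICE=mps, DOCLING_NUM_THREADS=8,
# or the OMP_NUM_THREADS fallback handled by check_alternative_envvars.
opts = AcceleratorOptions()
print(opts.device, opts.num_threads, opts.cuda_use_flash_attention2)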

View File

@ -13,11 +13,11 @@ from docling_core.types.doc import (
TableCell,
)
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
# DO NOT REMOVE; explicitly exposed from this location
from docling_core.types.io import (
DocumentStream,
)
# DO NOT REMOVE; explicitly exposed from this location
from PIL.Image import Image
from pydantic import BaseModel, ConfigDict, Field, computed_field
@ -131,12 +131,6 @@ class ErrorItem(BaseModel):
error_message: str
# class Cell(BaseModel):
# id: int
# text: str
# bbox: BoundingBox
class Cluster(BaseModel):
id: int
label: DocItemLabel
@ -158,8 +152,16 @@ class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []
class VlmPredictionToken(BaseModel):
text: str = ""
token: int = -1
logprob: float = -1
class VlmPrediction(BaseModel):
text: str = ""
generated_tokens: list[VlmPredictionToken] = []
generation_time: float = -1
class ContainerElement(
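
Illustrative sketch (not part of this commit) of inspecting the new token-level output; conv_res is assumed to be a ConversionResult produced by a converter running the VLM pipeline.

for page in conv_res.pages:
    pred = page.predictions.vlm_response
    if pred is None:
        continue
    print(
        f"page generated in {pred.generation_time:.2f}s, "
        f"{len(pred.generated_tokens)} tokens"
    )
    for tok in pred.generated_tokens[:5]:
        # each VlmPredictionToken carries the decoded text, token id and logprob
        print(tok.token, repr(tok.text), tok.logprob)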

View File

@ -1,6 +1,4 @@
import logging
import os
import re
from enum import Enum
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
@ -10,73 +8,28 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated
# Import the following for backwards compatibility
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
InlineVlmOptions,
ResponseFormat,
)
from docling.datamodel.vlm_model_specs import (
GRANITE_VISION_OLLAMA as granite_vision_vlm_ollama_conversion_options,
GRANITE_VISION_TRANSFORMERS as granite_vision_vlm_conversion_options,
SMOLDOCLING_MLX as smoldocling_vlm_mlx_conversion_options,
SMOLDOCLING_TRANSFORMERS as smoldocling_vlm_conversion_options,
VlmModelType,
)
_log = logging.getLogger(__name__)
class AcceleratorDevice(str, Enum):
"""Devices to run model inference"""
AUTO = "auto"
CPU = "cpu"
CUDA = "cuda"
MPS = "mps"
class AcceleratorOptions(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="DOCLING_", env_nested_delimiter="_", populate_by_name=True
)
num_threads: int = 4
device: Union[str, AcceleratorDevice] = "auto"
cuda_use_flash_attention2: bool = False
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value
):
return value
raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
)
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
as an extra input instead of simply overwriting the evvar value for that parameter.
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
omp_num_threads = os.getenv("OMP_NUM_THREADS")
if docling_num_threads is None and omp_num_threads is not None:
try:
data["num_threads"] = int(omp_num_threads)
except ValueError:
_log.error(
"Ignoring misformatted envvar OMP_NUM_THREADS '%s'",
omp_num_threads,
)
return data
class BaseOptions(BaseModel):
"""Base class for options."""
@ -121,24 +74,22 @@ class RapidOcrOptions(OcrOptions):
lang: List[str] = [
"english",
"chinese",
] # However, language as a parameter is not supported by rapidocr yet and hence changing this options doesn't affect anything.
# For more details on supported languages by RapidOCR visit https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
]
# However, language as a parameter is not supported by rapidocr yet
# and hence changing this option doesn't affect anything.
# For more details on supported languages by RapidOCR visit
# https://rapidai.github.io/RapidOCRDocs/blog/2022/09/28/%E6%94%AF%E6%8C%81%E8%AF%86%E5%88%AB%E8%AF%AD%E8%A8%80/
# For more details on the following options visit
# https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
# For more details on the following options visit https://rapidai.github.io/RapidOCRDocs/install_usage/api/RapidOCR/
text_score: float = 0.5 # same default as rapidocr
use_det: Optional[bool] = None # same default as rapidocr
use_cls: Optional[bool] = None # same default as rapidocr
use_rec: Optional[bool] = None # same default as rapidocr
# class Device(Enum):
# CPU = "CPU"
# CUDA = "CUDA"
# DIRECTML = "DIRECTML"
# AUTO = "AUTO"
# device: Device = Device.AUTO # Default value is AUTO
print_verbose: bool = False # same default as rapidocr
det_model_path: Optional[str] = None # same default as rapidocr
@ -244,101 +195,18 @@ class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
return self.repo_id.replace("/", "--")
# SmolVLM
smolvlm_picture_description = PictureDescriptionVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
)
# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
# GraniteVision
granite_picture_description = PictureDescriptionVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="What is shown in this image?",
)
class BaseVlmOptions(BaseModel):
kind: str
prompt: str
class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
MARKDOWN = "markdown"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
OPENAI = "openai"
class HuggingFaceVlmOptions(BaseVlmOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
repo_id: str
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
response_format: ResponseFormat
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
class ApiVlmOptions(BaseVlmOptions):
kind: Literal["api_model_options"] = "api_model_options"
url: AnyUrl = AnyUrl(
"http://localhost:11434/v1/chat/completions"
) # Default to ollama
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
scale: float = 2.0
timeout: float = 60
concurrency: int = 1
response_format: ResponseFormat
smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
)
smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
)
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
# prompt="OCR the full page to markdown.",
prompt="OCR this image.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
)
granite_vision_vlm_ollama_conversion_options = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
prompt="OCR the full page to markdown.",
scale=1.0,
timeout=120,
response_format=ResponseFormat.MARKDOWN,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
# Define an enum for the backend options
class PdfBackend(str, Enum):
"""Enum of valid PDF backends."""
@ -387,7 +255,7 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
False # (To be used with vlms, or other generative models)
)
# If True, text from backend will be used instead of generated text
vlm_options: Union[HuggingFaceVlmOptions, ApiVlmOptions] = (
vlm_options: Union[InlineVlmOptions, ApiVlmOptions] = (
smoldocling_vlm_conversion_options
)

View File

@ -0,0 +1,81 @@
from enum import Enum
from typing import Any, Dict, List, Literal
from pydantic import AnyUrl, BaseModel
from typing_extensions import deprecated
from docling.datamodel.accelerator_options import AcceleratorDevice
class BaseVlmOptions(BaseModel):
kind: str
prompt: str
class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
MARKDOWN = "markdown"
HTML = "html"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
class TransformersModelType(str, Enum):
AUTOMODEL = "automodel"
AUTOMODEL_VISION2SEQ = "automodel-vision2seq"
AUTOMODEL_CAUSALLM = "automodel-causallm"
class InlineVlmOptions(BaseVlmOptions):
kind: Literal["inline_model_options"] = "inline_model_options"
repo_id: str
trust_remote_code: bool = False
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
response_format: ResponseFormat
supported_devices: List[AcceleratorDevice] = [
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
]
scale: float = 2.0
temperature: float = 0.0
stop_strings: List[str] = []
extra_generation_config: Dict[str, Any] = {}
use_kv_cache: bool = True
max_new_tokens: int = 4096
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
@deprecated("Use InlineVlmOptions instead.")
class HuggingFaceVlmOptions(InlineVlmOptions):
pass
class ApiVlmOptions(BaseVlmOptions):
kind: Literal["api_model_options"] = "api_model_options"
url: AnyUrl = AnyUrl(
"http://localhost:11434/v1/chat/completions"
) # Default to ollama
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
scale: float = 2.0
timeout: float = 60
concurrency: int = 1
response_format: ResponseFormat
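
Illustrative sketch (not part of this commit): defining options for an arbitrary vision-language model. The repo id and prompt are hypothetical; the fields mirror the InlineVlmOptions definition above.

from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
    TransformersModelType,
)

my_vlm = InlineVlmOptions(
    repo_id="example-org/my-vlm",  # hypothetical repository
    prompt="Convert this page to markdown.",
    response_format=ResponseFormat.MARKDOWN,
    inference_framework=InferenceFramework.TRANSFORMERS,
    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
    supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
    scale=2.0,
    temperature=0.0,
)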

View File

@ -0,0 +1,144 @@
import logging
from enum import Enum
from pydantic import (
AnyUrl,
)
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
InlineVlmOptions,
ResponseFormat,
TransformersModelType,
)
_log = logging.getLogger(__name__)
# SmolDocling
SMOLDOCLING_MLX = InlineVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
)
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
# GraniteVision
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
GRANITE_VISION_OLLAMA = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
scale=1.0,
timeout=120,
response_format=ResponseFormat.MARKDOWN,
temperature=0.0,
)
# Pixtral
PIXTRAL_12B_TRANSFORMERS = InlineVlmOptions(
repo_id="mistral-community/pixtral-12b",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
scale=2.0,
temperature=0.0,
)
PIXTRAL_12B_MLX = InlineVlmOptions(
repo_id="mlx-community/pixtral-12b-bf16",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.MLX,
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
)
# Phi4
PHI4_TRANSFORMERS = InlineVlmOptions(
repo_id="microsoft/Phi-4-multimodal-instruct",
prompt="Convert this page to MarkDown. Do not miss any text and only output the bare markdown",
trust_remote_code=True,
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_CAUSALLM,
supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
scale=2.0,
temperature=0.0,
extra_generation_config=dict(num_logits_to_keep=0),
)
# Qwen
QWEN25_VL_3B_MLX = InlineVlmOptions(
repo_id="mlx-community/Qwen2.5-VL-3B-Instruct-bf16",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.MLX,
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
)
# Gemma-3
GEMMA3_12B_MLX = InlineVlmOptions(
repo_id="mlx-community/gemma-3-12b-it-bf16",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.MLX,
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
)
GEMMA3_27B_MLX = InlineVlmOptions(
repo_id="mlx-community/gemma-3-27b-it-bf16",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.MLX,
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
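
Illustrative sketch (not part of this commit): running the VLM pipeline with one of the predefined specs, assuming the converter API imported in the CLI hunk above (DocumentConverter, PdfFormatOption, VlmPipeline); the input path is a placeholder.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.vlm_model_specs import GRANITE_VISION_TRANSFORMERS
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(vlm_options=GRANITE_VISION_TRANSFORMERS)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("document.pdf")  # placeholder path
print(result.document.export_to_markdown())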

View File

@ -186,6 +186,11 @@ class DocumentConverter:
Tuple[Type[BasePipeline], str], BasePipeline
] = {}
def _get_initialized_pipelines(
self,
) -> dict[tuple[Type[BasePipeline], str], BasePipeline]:
return self.initialized_pipelines
def _get_pipeline_options_hash(self, pipeline_options: PipelineOptions) -> str:
"""Generate a hash of pipeline options to use as part of the cache key."""
options_str = str(pipeline_options.model_dump())

View File

@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ApiVlmOptions
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BasePageModel
from docling.utils.api_image_request import api_image_request

View File

@ -11,9 +11,10 @@ from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import binary_dilation, find_objects, label
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseModelWithOptions, BasePageModel

View File

@ -16,9 +16,10 @@ from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image, ImageOps
from pydantic import BaseModel
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
@ -117,20 +118,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.2",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if a given element in a document can be processed by the model.

View File

@ -13,8 +13,9 @@ from docling_core.types.doc import (
from PIL import Image
from pydantic import BaseModel
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.models.base_model import BaseEnrichmentModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
@ -105,20 +106,14 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
local_dir=local_dir,
revision="v1.0.1",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if the given element can be processed by the classifier.

View File

@ -9,11 +9,10 @@ import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
)

View File

@ -1,182 +0,0 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceVlmModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
import torch
from transformers import ( # type: ignore
AutoModelForVision2Seq,
AutoProcessor,
BitsAndBytesConfig,
)
device = decide_device(accelerator_options.device)
self.device = device
_log.debug(f"Available device for HuggingFace VLM: {device}")
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
)
self.param_quantized = vlm_options.quantized # False
self.processor = AutoProcessor.from_pretrained(artifacts_path)
if not self.param_quantized:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
else:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.param_question},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs, max_new_tokens=4096, use_cache=True
)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
num_tokens = len(generated_ids[0])
page_tags = generated_texts
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page

View File

@ -10,11 +10,12 @@ from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
@ -83,20 +84,14 @@ class LayoutModel(BasePageModel):
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/docling-models",
force_download=force,
revision="v2.2.0",
local_dir=local_dir,
revision="v2.1.0",
force=force,
progress=progress,
)
return Path(download_path)
def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
):

View File

@ -8,10 +8,10 @@ from typing import Optional, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrMacOptions,
OcrOptions,
)

View File

@ -5,8 +5,8 @@ from typing import Optional, Type, Union
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionApiOptions,
PictureDescriptionBaseOptions,
)

View File

@ -13,8 +13,8 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co
)
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
)
from docling.models.base_model import (

View File

@ -4,16 +4,21 @@ from typing import Optional, Type, Union
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
PictureDescriptionVlmOptions,
)
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
class PictureDescriptionVlmModel(
PictureDescriptionBaseModel, HuggingFaceModelDownloadMixin
):
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
return PictureDescriptionVlmOptions
@ -66,26 +71,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
self.provenance = f"{self.options.repo_id}"
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
)
return Path(download_path)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
from transformers import GenerationConfig

View File

@ -7,11 +7,10 @@ import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
OcrOptions,
RapidOcrOptions,
)

View File

@ -13,16 +13,16 @@ from docling_core.types.doc.page import (
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
TableFormerMode,
TableStructureOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
@ -90,20 +90,14 @@ class TableStructureModel(BasePageModel):
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
return download_hf_model(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.2.0",
local_dir=local_dir,
force=force,
progress=progress,
)
return Path(download_path)
def draw_table_and_cells(
self,
conv_res: ConversionResult,

View File

@ -13,10 +13,10 @@ import pandas as pd
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractCliOcrOptions,
)

View File

@ -7,10 +7,10 @@ from typing import Iterable, Optional, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractOcrOptions,
)

View File

View File

@ -0,0 +1,40 @@
import logging
from pathlib import Path
from typing import Optional
_log = logging.getLogger(__name__)
def download_hf_model(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
revision: Optional[str] = None,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
revision=revision,
)
return Path(download_path)
class HuggingFaceModelDownloadMixin:
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
return download_hf_model(
repo_id=repo_id, local_dir=local_dir, force=force, progress=progress
)
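
Usage sketch (illustrative, not part of this commit): pre-fetching model weights into a local artifacts directory with the shared helper; the target directory is a placeholder that follows the repo_cache_folder naming convention (repo id with "/" replaced by "--").

from pathlib import Path

from docling.models.utils.hf_model_download import download_hf_model

artifacts_path = download_hf_model(
    repo_id="ds4sd/SmolDocling-256M-preview",
    local_dir=Path("./artifacts/ds4sd--SmolDocling-256M-preview"),
    progress=True,
)
print(artifacts_path)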

View File

@ -0,0 +1,194 @@
import importlib.metadata
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
TransformersModelType,
)
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
import torch
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
BitsAndBytesConfig,
GenerationConfig,
)
transformers_version = importlib.metadata.version("transformers")
if (
self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct"
and transformers_version >= "4.52.0"
):
raise NotImplementedError(
f"Phi 4 only works with transformers<4.52.0 but you have {transformers_version=}. Please downgrage running pip install -U 'transformers<4.52.0'."
)
self.device = decide_device(
accelerator_options.device,
supported_devices=vlm_options.supported_devices,
)
_log.debug(f"Available device for VLM: {self.device}")
self.use_cache = vlm_options.use_kv_cache
self.max_new_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_quantization_config: Optional[BitsAndBytesConfig] = None
if vlm_options.quantized:
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit,
llm_int8_threshold=vlm_options.llm_int8_threshold,
)
model_cls: Any = AutoModel
if (
self.vlm_options.transformers_model_type
== TransformersModelType.AUTOMODEL_CAUSALLM
):
model_cls = AutoModelForCausalLM
elif (
self.vlm_options.transformers_model_type
== TransformersModelType.AUTOMODEL_VISION2SEQ
):
model_cls = AutoModelForVision2Seq
self.processor = AutoProcessor.from_pretrained(
artifacts_path,
trust_remote_code=vlm_options.trust_remote_code,
)
self.vlm_model = model_cls.from_pretrained(
artifacts_path,
device_map=self.device,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
trust_remote_code=vlm_options.trust_remote_code,
)
# Load generation config
self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=self.vlm_options.scale)
# Define prompt structure
prompt = self.formulate_prompt()
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
).to(self.device)
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
num_tokens = len(generated_ids[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
page.predictions.vlm_response = VlmPrediction(
text=generated_texts,
generation_time=generation_time,
)
yield page
def formulate_prompt(self) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.vlm_options.prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
return prompt

View File

@ -4,29 +4,34 @@ from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel):
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
self.max_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
if self.enabled:
try:
@ -39,42 +44,24 @@ class HuggingFaceMlxModel(BasePageModel):
)
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
self.apply_chat_template = apply_chat_template
self.stream_generate = stream_generate
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
artifacts_path = self.download_models(
self.vlm_options.repo_id,
)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_question = vlm_options.prompt
## Load the model
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@ -83,12 +70,10 @@ class HuggingFaceMlxModel(BasePageModel):
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
hi_res_image = page.get_image(scale=self.vlm_options.scale)
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
@ -104,16 +89,45 @@ class HuggingFaceMlxModel(BasePageModel):
)
start_time = time.time()
_log.debug("start generating ...")
# Call model to generate:
tokens: list[VlmPredictionToken] = []
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=4096,
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
if "</doctag>" in token.text:
break
@ -121,15 +135,13 @@ class HuggingFaceMlxModel(BasePageModel):
generation_time = time.time() - start_time
page_tags = output
_log.debug(f"Generation time {generation_time:.2f} seconds.")
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
generated_tokens=tokens,
)
yield page

View File

@ -1,29 +1,46 @@
import logging
import re
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Union, cast
from docling_core.types import DoclingDocument
from docling_core.types.doc import BoundingBox, DocItem, ImageRef, PictureItem, TextItem
from docling_core.types.doc import (
BoundingBox,
DocItem,
DoclingDocument,
ImageRef,
PictureItem,
ProvenanceItem,
TextItem,
)
from docling_core.types.doc.base import (
BoundingBox,
Size,
)
from docling_core.types.doc.document import DocTagsDocument
from PIL import Image as PILImage
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
ApiVlmOptions,
HuggingFaceVlmOptions,
InferenceFramework,
ResponseFormat,
VlmPipelineOptions,
)
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
InlineVlmOptions,
ResponseFormat,
)
from docling.datamodel.settings import settings
from docling.models.api_vlm_model import ApiVlmModel
from docling.models.hf_mlx_model import HuggingFaceMlxModel
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.models.vlm_models_inline.hf_transformers_model import (
HuggingFaceTransformersVlmModel,
)
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
@ -66,8 +83,8 @@ class VlmPipeline(PaginatedPipeline):
vlm_options=cast(ApiVlmOptions, self.pipeline_options.vlm_options),
),
]
elif isinstance(self.pipeline_options.vlm_options, HuggingFaceVlmOptions):
vlm_options = cast(HuggingFaceVlmOptions, self.pipeline_options.vlm_options)
elif isinstance(self.pipeline_options.vlm_options, InlineVlmOptions):
vlm_options = cast(InlineVlmOptions, self.pipeline_options.vlm_options)
if vlm_options.inference_framework == InferenceFramework.MLX:
self.build_pipe = [
HuggingFaceMlxModel(
@ -77,15 +94,19 @@ class VlmPipeline(PaginatedPipeline):
vlm_options=vlm_options,
),
]
else:
elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
self.build_pipe = [
HuggingFaceVlmModel(
HuggingFaceTransformersVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=vlm_options,
),
]
else:
raise ValueError(
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
)
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
@ -116,49 +137,19 @@ class VlmPipeline(PaginatedPipeline):
self.pipeline_options.vlm_options.response_format
== ResponseFormat.DOCTAGS
):
doctags_list = []
image_list = []
for page in conv_res.pages:
predicted_doctags = ""
img = PILImage.new("RGB", (1, 1), "rgb(255,255,255)")
if page.predictions.vlm_response:
predicted_doctags = page.predictions.vlm_response.text
if page.image:
img = page.image
image_list.append(img)
doctags_list.append(predicted_doctags)
conv_res.document = self._turn_dt_into_doc(conv_res)
doctags_list_c = cast(List[Union[Path, str]], doctags_list)
image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
doctags_list_c, image_list_c
)
conv_res.document = DoclingDocument.load_from_doctags(doctags_doc)
# If forced backend text, replace model predicted text with backend one
if self.force_backend_text:
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if not isinstance(element, TextItem) or len(element.prov) == 0:
continue
page_ix = element.prov[0].page_no - 1
page = conv_res.pages[page_ix]
if not page.size:
continue
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(page_height=page.size.height * scale)
)
txt = self.extract_text_from_backend(page, crop_bbox)
element.text = txt
element.orig = txt
elif (
self.pipeline_options.vlm_options.response_format
== ResponseFormat.MARKDOWN
):
conv_res.document = self._turn_md_into_doc(conv_res)
elif (
self.pipeline_options.vlm_options.response_format == ResponseFormat.HTML
):
conv_res.document = self._turn_html_into_doc(conv_res)
else:
raise RuntimeError(
f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}"
@ -192,23 +183,199 @@ class VlmPipeline(PaginatedPipeline):
return conv_res
def _turn_md_into_doc(self, conv_res):
predicted_text = ""
for pg_idx, page in enumerate(conv_res.pages):
def _turn_dt_into_doc(self, conv_res) -> DoclingDocument:
doctags_list = []
image_list = []
for page in conv_res.pages:
predicted_doctags = ""
img = PILImage.new("RGB", (1, 1), "rgb(255,255,255)")
if page.predictions.vlm_response:
predicted_text += page.predictions.vlm_response.text + "\n\n"
response_bytes = BytesIO(predicted_text.encode("utf8"))
out_doc = InputDocument(
path_or_stream=response_bytes,
filename=conv_res.input.file.name,
format=InputFormat.MD,
backend=MarkdownDocumentBackend,
predicted_doctags = page.predictions.vlm_response.text
if page.image:
img = page.image
image_list.append(img)
doctags_list.append(predicted_doctags)
doctags_list_c = cast(List[Union[Path, str]], doctags_list)
image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
doctags_list_c, image_list_c
)
backend = MarkdownDocumentBackend(
in_doc=out_doc,
path_or_stream=response_bytes,
conv_res.document = DoclingDocument.load_from_doctags(
doctag_document=doctags_doc
)
return backend.convert()
# If forced backend text, replace model predicted text with backend one
if page.size:
if self.force_backend_text:
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if not isinstance(element, TextItem) or len(element.prov) == 0:
continue
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(page_height=page.size.height * scale)
)
txt = self.extract_text_from_backend(page, crop_bbox)
element.text = txt
element.orig = txt
return conv_res.document
def _turn_md_into_doc(self, conv_res):
def _extract_markdown_code(text):
"""
Extracts text from markdown code blocks (enclosed in triple backticks).
If no code blocks are found, returns the original text.
Args:
text (str): Input text that may contain markdown code blocks
Returns:
str: Extracted code if code blocks exist, otherwise original text
"""
# Regex pattern to match content between triple backticks
# This handles multiline content and optional language specifier
pattern = r"^```(?:\w*\n)?(.*?)```(\n)*$"
# Search with DOTALL flag to match across multiple lines
mtch = re.search(pattern, text, re.DOTALL)
if mtch:
# Return only the content of the first capturing group
return mtch.group(1)
else:
# No code blocks found, return original text
return text
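# Illustrative behavior of the helper above (example inputs assumed, not from the diff):
#   a response wrapped in a fenced block such as "```markdown\n# Title\n```\n" yields "# Title\n"
#   a response without any fenced code block is returned unchanged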
for pg_idx, page in enumerate(conv_res.pages):
page_no = pg_idx + 1 # FIXME: might be incorrect
predicted_text = ""
if page.predictions.vlm_response:
predicted_text = page.predictions.vlm_response.text + "\n\n"
predicted_text = _extract_markdown_code(text=predicted_text)
response_bytes = BytesIO(predicted_text.encode("utf8"))
out_doc = InputDocument(
path_or_stream=response_bytes,
filename=conv_res.input.file.name,
format=InputFormat.MD,
backend=MarkdownDocumentBackend,
)
backend = MarkdownDocumentBackend(
in_doc=out_doc,
path_or_stream=response_bytes,
)
page_doc = backend.convert()
if page.image is not None:
pg_width = page.image.width
pg_height = page.image.height
else:
pg_width = 1
pg_height = 1
conv_res.document.add_page(
page_no=page_no,
size=Size(width=pg_width, height=pg_height),
image=ImageRef.from_pil(image=page.image, dpi=72)
if page.image
else None,
)
for item, level in page_doc.iterate_items():
item.prov = [
ProvenanceItem(
page_no=pg_idx + 1,
bbox=BoundingBox(
t=0.0, b=0.0, l=0.0, r=0.0
), # FIXME: would be nice not to have to "fake" it
charspan=[0, 0],
)
]
conv_res.document.append_child_item(child=item)
return conv_res.document
def _turn_html_into_doc(self, conv_res):
def _extract_html_code(text):
"""
Extracts HTML from markdown-style code blocks (enclosed in triple backticks).
If no code blocks are found, returns the original text.
Args:
text (str): Input text that may contain markdown code blocks
Returns:
str: Extracted code if code blocks exist, otherwise original text
"""
# Regex pattern to match content between triple backticks
# This handles multiline content and optional language specifier
pattern = r"^```(?:\w*\n)?(.*?)```(\n)*$"
# Search with DOTALL flag to match across multiple lines
mtch = re.search(pattern, text, re.DOTALL)
if mtch:
# Return only the content of the first capturing group
return mtch.group(1)
else:
# No code blocks found, return original text
return text
for pg_idx, page in enumerate(conv_res.pages):
page_no = pg_idx + 1 # FIXME: might be incorrect
predicted_text = ""
if page.predictions.vlm_response:
predicted_text = page.predictions.vlm_response.text + "\n\n"
predicted_text = _extract_html_code(text=predicted_text)
response_bytes = BytesIO(predicted_text.encode("utf8"))
out_doc = InputDocument(
path_or_stream=response_bytes,
filename=conv_res.input.file.name,
format=InputFormat.HTML,
backend=HTMLDocumentBackend,
)
backend = HTMLDocumentBackend(
in_doc=out_doc,
path_or_stream=response_bytes,
)
page_doc = backend.convert()
if page.image is not None:
pg_width = page.image.width
pg_height = page.image.height
else:
pg_width = 1
pg_height = 1
conv_res.document.add_page(
page_no=page_no,
size=Size(width=pg_width, height=pg_height),
image=ImageRef.from_pil(image=page.image, dpi=72)
if page.image
else None,
)
for item, level in page_doc.iterate_items():
item.prov = [
ProvenanceItem(
page_no=pg_idx + 1,
bbox=BoundingBox(
t=0.0, b=0.0, l=0.0, r=0.0
), # FIXME: would be nice not to have to "fake" it
charspan=[0, 0],
)
]
conv_res.document.append_child_item(child=item)
return conv_res.document
@classmethod
def get_default_options(cls) -> VlmPipelineOptions:


@ -1,13 +1,16 @@
import logging
from typing import List, Optional
import torch
from docling.datamodel.pipeline_options import AcceleratorDevice
from docling.datamodel.accelerator_options import AcceleratorDevice
_log = logging.getLogger(__name__)
def decide_device(accelerator_device: str) -> str:
def decide_device(
accelerator_device: str, supported_devices: Optional[List[AcceleratorDevice]] = None
) -> str:
r"""
Resolve the device based on the acceleration options and the available devices in the system.
@ -20,6 +23,18 @@ def decide_device(accelerator_device: str) -> str:
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
if supported_devices is not None:
if has_cuda and AcceleratorDevice.CUDA not in supported_devices:
_log.info(
f"Removing CUDA from available devices because it is not in {supported_devices=}"
)
has_cuda = False
if has_mps and AcceleratorDevice.MPS not in supported_devices:
_log.info(
f"Removing MPS from available devices because it is not in {supported_devices=}"
)
has_mps = False
if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
if has_cuda:
device = "cuda:0"


@ -4,18 +4,20 @@ from typing import Optional
from docling.datamodel.pipeline_options import (
granite_picture_description,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
smolvlm_picture_description,
)
from docling.datamodel.settings import settings
from docling.datamodel.vlm_model_specs import (
SMOLDOCLING_MLX,
SMOLDOCLING_TRANSFORMERS,
)
from docling.models.code_formula_model import CodeFormulaModel
from docling.models.document_picture_classifier import DocumentPictureClassifier
from docling.models.easyocr_model import EasyOcrModel
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.models.layout_model import LayoutModel
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
from docling.models.table_structure_model import TableStructureModel
from docling.models.utils.hf_model_download import download_hf_model
_log = logging.getLogger(__name__)
@ -75,7 +77,7 @@ def download_models(
if with_smolvlm:
_log.info("Downloading SmolVlm model...")
PictureDescriptionVlmModel.download_models(
download_hf_model(
repo_id=smolvlm_picture_description.repo_id,
local_dir=output_dir / smolvlm_picture_description.repo_cache_folder,
force=force,
@ -84,26 +86,25 @@ def download_models(
if with_smoldocling:
_log.info("Downloading SmolDocling model...")
HuggingFaceVlmModel.download_models(
repo_id=smoldocling_vlm_conversion_options.repo_id,
local_dir=output_dir / smoldocling_vlm_conversion_options.repo_cache_folder,
download_hf_model(
repo_id=SMOLDOCLING_TRANSFORMERS.repo_id,
local_dir=output_dir / SMOLDOCLING_TRANSFORMERS.repo_cache_folder,
force=force,
progress=progress,
)
if with_smoldocling_mlx:
_log.info("Downloading SmolDocling MLX model...")
HuggingFaceVlmModel.download_models(
repo_id=smoldocling_vlm_mlx_conversion_options.repo_id,
local_dir=output_dir
/ smoldocling_vlm_mlx_conversion_options.repo_cache_folder,
download_hf_model(
repo_id=SMOLDOCLING_MLX.repo_id,
local_dir=output_dir / SMOLDOCLING_MLX.repo_cache_folder,
force=force,
progress=progress,
)
if with_granite_vision:
_log.info("Downloading Granite Vision model...")
PictureDescriptionVlmModel.download_models(
download_hf_model(
repo_id=granite_picture_description.repo_id,
local_dir=output_dir / granite_picture_description.repo_cache_folder,
force=force,
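For reference, a hedged sketch of pre-fetching a single model with the new `download_hf_model` helper (the output directory below is an arbitrary example, not a docling default):

```python
from pathlib import Path

from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS
from docling.models.utils.hf_model_download import download_hf_model

output_dir = Path("./models")  # example location
download_hf_model(
    repo_id=SMOLDOCLING_TRANSFORMERS.repo_id,
    local_dir=output_dir / SMOLDOCLING_TRANSFORMERS.repo_cache_folder,
    force=False,
    progress=True,
)
```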

docs/examples/compare_vlm_models.py

@ -0,0 +1,160 @@
# Compare VLM models
# ==================
#
# This example runs the VLM pipeline with different vision-language models.
# Their runtime as well as their output quality is compared.
import json
import sys
import time
from pathlib import Path
from docling_core.types.doc import DocItemLabel, ImageRefMode
from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
from tabulate import tabulate
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
)
from docling.datamodel.pipeline_options_vlm_model import InferenceFramework
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
def convert(sources: list[Path], converter: DocumentConverter):
model_id = pipeline_options.vlm_options.repo_id.replace("/", "_")
framework = pipeline_options.vlm_options.inference_framework
for source in sources:
print("================================================")
print("Processing...")
print(f"Source: {source}")
print("---")
print(f"Model: {model_id}")
print(f"Framework: {framework}")
print("================================================")
print("")
res = converter.convert(source)
print("")
fname = f"{res.input.file.stem}-{model_id}-{framework}"
inference_time = 0.0
for i, page in enumerate(res.pages):
inference_time += page.predictions.vlm_response.generation_time
print("")
print(
f" ---------- Predicted page {i} in {pipeline_options.vlm_options.response_format} in {page.predictions.vlm_response.generation_time} [sec]:"
)
print(page.predictions.vlm_response.text)
print(" ---------- ")
print("===== Final output of the converted document =======")
with (out_path / f"{fname}.json").open("w") as fp:
fp.write(json.dumps(res.document.export_to_dict()))
res.document.save_as_json(
out_path / f"{fname}.json",
image_mode=ImageRefMode.PLACEHOLDER,
)
print(f" => produced {out_path / fname}.json")
res.document.save_as_markdown(
out_path / f"{fname}.md",
image_mode=ImageRefMode.PLACEHOLDER,
)
print(f" => produced {out_path / fname}.md")
res.document.save_as_html(
out_path / f"{fname}.html",
image_mode=ImageRefMode.EMBEDDED,
labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
split_page_view=True,
)
print(f" => produced {out_path / fname}.html")
pg_num = res.document.num_pages()
print("")
print(
f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}"
)
print("====================================================")
return [
source,
model_id,
str(framework),
pg_num,
inference_time,
]
if __name__ == "__main__":
sources = [
"tests/data/pdf/2305.03393v1-pg9.pdf",
]
out_path = Path("scratch")
out_path.mkdir(parents=True, exist_ok=True)
## Use VlmPipeline
pipeline_options = VlmPipelineOptions()
pipeline_options.generate_page_images = True
## On GPU systems, enable flash_attention_2 with CUDA:
# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
# pipeline_options.accelerator_options.cuda_use_flash_attention2 = True
vlm_models = [
## DocTags / SmolDocling models
vlm_model_specs.SMOLDOCLING_MLX,
vlm_model_specs.SMOLDOCLING_TRANSFORMERS,
## Markdown models (using MLX framework)
vlm_model_specs.QWEN25_VL_3B_MLX,
vlm_model_specs.PIXTRAL_12B_MLX,
vlm_model_specs.GEMMA3_12B_MLX,
## Markdown models (using Transformers framework)
vlm_model_specs.GRANITE_VISION_TRANSFORMERS,
vlm_model_specs.PHI4_TRANSFORMERS,
vlm_model_specs.PIXTRAL_12B_TRANSFORMERS,
]
# Remove MLX models if not on Mac
if sys.platform != "darwin":
vlm_models = [
m for m in vlm_models if m.inference_framework != InferenceFramework.MLX
]
rows = []
for vlm_options in vlm_models:
pipeline_options.vlm_options = vlm_options
## Set up pipeline for PDF or image inputs
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
InputFormat.IMAGE: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
},
)
row = convert(sources=sources, converter=converter)
rows.append(row)
print(
tabulate(
rows, headers=["source", "model_id", "framework", "num_pages", "time"]
)
)
print("see if memory gets released ...")
time.sleep(10)


@ -3,10 +3,9 @@ import logging
import time
from pathlib import Path
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption


@ -1,101 +1,46 @@
import json
import time
from pathlib import Path
from docling_core.types.doc import DocItemLabel, ImageRefMode
from docling_core.types.doc.document import DEFAULT_EXPORT_LABELS
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
smoldocling_vlm_mlx_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
sources = [
# "tests/data/2305.03393v1-pg9-img.png",
"tests/data/pdf/2305.03393v1-pg9.pdf",
]
source = "https://arxiv.org/pdf/2501.17887"
## Use experimental VlmPipeline
pipeline_options = VlmPipelineOptions()
# If force_backend_text = True, text from backend will be used instead of generated text
pipeline_options.force_backend_text = False
###### USING SIMPLE DEFAULT VALUES
# - SmolDocling model
# - Using the transformers framework
## On GPU systems, enable flash_attention_2 with CUDA:
# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
# pipeline_options.accelerator_options.cuda_use_flash_attention2 = True
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=VlmPipeline,
),
}
)
## Pick a VLM model. We choose SmolDocling-256M by default
# pipeline_options.vlm_options = smoldocling_vlm_conversion_options
doc = converter.convert(source=source).document
## Pick a VLM model. Fast Apple Silicon friendly implementation for SmolDocling-256M via MLX
pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
print(doc.export_to_markdown())
## Alternative VLM models:
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options
## Set up pipeline for PDF or image inputs
###### USING MACOS MPS ACCELERATOR
# For more options see the compare_vlm_models.py example.
pipeline_options = VlmPipelineOptions(
vlm_options=vlm_model_specs.SMOLDOCLING_MLX,
)
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
InputFormat.IMAGE: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
}
)
out_path = Path("scratch")
out_path.mkdir(parents=True, exist_ok=True)
doc = converter.convert(source=source).document
for source in sources:
start_time = time.time()
print("================================================")
print(f"Processing... {source}")
print("================================================")
print("")
res = converter.convert(source)
print("")
print(res.document.export_to_markdown())
for page in res.pages:
print("")
print("Predicted page in DOCTAGS:")
print(page.predictions.vlm_response.text)
res.document.save_as_html(
filename=Path(f"{out_path}/{res.input.file.stem}.html"),
image_mode=ImageRefMode.REFERENCED,
labels=[*DEFAULT_EXPORT_LABELS, DocItemLabel.FOOTNOTE],
)
with (out_path / f"{res.input.file.stem}.json").open("w") as fp:
fp.write(json.dumps(res.document.export_to_dict()))
res.document.save_as_json(
out_path / f"{res.input.file.stem}.json",
image_mode=ImageRefMode.PLACEHOLDER,
)
res.document.save_as_markdown(
out_path / f"{res.input.file.stem}.md",
image_mode=ImageRefMode.PLACEHOLDER,
)
pg_num = res.document.num_pages()
print("")
inference_time = time.time() - start_time
print(
f"Total document prediction time: {inference_time:.2f} seconds, pages: {pg_num}"
)
print("================================================")
print("done!")
print("================================================")
print(doc.export_to_markdown())


@ -1,9 +1,8 @@
from pathlib import Path
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
)
from docling.datamodel.settings import settings


@ -1,5 +1,4 @@
import logging
import time
from pathlib import Path
from docling_core.types.doc import ImageRefMode, TableItem, TextItem


@ -7,10 +7,9 @@ from dotenv import load_dotenv
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
ApiVlmOptions,
ResponseFormat,
VlmPipelineOptions,
)
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

docs/index.md

@ -27,7 +27,7 @@ Docling simplifies document processing, parsing diverse formats — including ad
* 🔒 Local execution capabilities for sensitive data and air-gapped environments
* 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
* 🔍 Extensive OCR support for scanned PDFs and images
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕🔥
* 🥚 Support of several Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🔥
* 💻 Simple and convenient CLI
### Coming soon

docs/usage/vision_models.md

@ -0,0 +1,121 @@
The `VlmPipeline` in Docling allows you to convert documents end-to-end using a vision-language model.
Docling supports vision-language models that output:
- DocTags (e.g. [SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)), the preferred choice
- Markdown
- HTML
To run Docling with local models through the `VlmPipeline`:
=== "CLI"
```bash
docling --pipeline vlm FILE
```
=== "Python"
See also the example [minimal_vlm_pipeline.py](./../examples/minimal_vlm_pipeline.py).
```python
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=VlmPipeline,
),
}
)
doc = converter.convert(source="FILE").document
```
## Available local models
By default, the vision-language models run locally.
Docling lets you choose between the Hugging Face [Transformers](https://github.com/huggingface/transformers) framework and [MLX](https://github.com/Blaizzy/mlx-vlm) (for Apple devices with MPS acceleration).
The following table reports the models currently available out-of-the-box.
| Model instance | Model | Framework | Device | Num pages | Inference time (sec) |
| ---------------|------ | --------- | ------ | --------- | ---------------------|
| `vlm_model_specs.SMOLDOCLING_TRANSFORMERS` | [ds4sd/SmolDocling-256M-preview](https://huggingface.co/ds4sd/SmolDocling-256M-preview) | `Transformers/AutoModelForVision2Seq` | MPS | 1 | 102.212 |
| `vlm_model_specs.SMOLDOCLING_MLX` | [ds4sd/SmolDocling-256M-preview-mlx-bf16](https://huggingface.co/ds4sd/SmolDocling-256M-preview-mlx-bf16) | `MLX`| MPS | 1 | 6.15453 |
| `vlm_model_specs.QWEN25_VL_3B_MLX` | [mlx-community/Qwen2.5-VL-3B-Instruct-bf16](https://huggingface.co/mlx-community/Qwen2.5-VL-3B-Instruct-bf16) | `MLX`| MPS | 1 | 23.4951 |
| `vlm_model_specs.PIXTRAL_12B_MLX` | [mlx-community/pixtral-12b-bf16](https://huggingface.co/mlx-community/pixtral-12b-bf16) | `MLX` | MPS | 1 | 308.856 |
| `vlm_model_specs.GEMMA3_12B_MLX` | [mlx-community/gemma-3-12b-it-bf16](https://huggingface.co/mlx-community/gemma-3-12b-it-bf16) | `MLX` | MPS | 1 | 378.486 |
| `vlm_model_specs.GRANITE_VISION_TRANSFORMERS` | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b) | `Transformers/AutoModelForVision2Seq` | MPS | 1 | 104.75 |
| `vlm_model_specs.PHI4_TRANSFORMERS` | [microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) | `Transformers/AutoModelForCausalLM` | CPU | 1 | 1175.67 |
| `vlm_model_specs.PIXTRAL_12B_TRANSFORMERS` | [mistral-community/pixtral-12b](https://huggingface.co/mistral-community/pixtral-12b) | `Transformers/AutoModelForVision2Seq` | CPU | 1 | 1828.21 |
_Inference time is computed on a MacBook M3 Max using the example page `tests/data/pdf/2305.03393v1-pg9.pdf`. The comparison is done with the example [compare_vlm_models.py](./../examples/compare_vlm_models.py)._
To choose the model, the code snippet above can be extended as follows:
```python
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
)
from docling.datamodel import vlm_model_specs
pipeline_options = VlmPipelineOptions(
vlm_options=vlm_model_specs.SMOLDOCLING_MLX, # <-- change the model here
)
converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=VlmPipeline,
pipeline_options=pipeline_options,
),
}
)
doc = converter.convert(source="FILE").document
```
### Other models
Other models can be configured by directly providing the Hugging Face `repo_id`, the prompt and a few more options.
For example:
```python
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
    TransformersModelType,
)
pipeline_options = VlmPipelineOptions(
vlm_options=InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
)
```
## Remote models
In addition to local models, the `VlmPipeline` can offload inference to a remote service hosting the models.
Any remote inference service exposing an OpenAI-compatible API can be used, e.g. vLLM, Ollama, and others.
More details on connecting to remote inference services can be found in the following example; a minimal configuration sketch is also shown below.
- [vlm_pipeline_api_model.py](./../examples/vlm_pipeline_api_model.py)
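As a minimal sketch (the endpoint URL and model name are placeholders for a local Ollama deployment, and `enable_remote_services` is assumed to be required for any remote call):

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(
    enable_remote_services=True,  # assumed flag to allow calls to a remote endpoint
    vlm_options=ApiVlmOptions(
        url="http://localhost:11434/v1/chat/completions",  # e.g. a local Ollama server
        params={"model": "granite3.2-vision:2b"},  # placeholder model name
        prompt="Convert this page to markdown.",
        response_format=ResponseFormat.MARKDOWN,
        timeout=90,
    ),
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        ),
    }
)
doc = converter.convert(source="FILE").document
```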


@ -60,6 +60,7 @@ nav:
- Usage: usage/index.md
- Supported formats: usage/supported_formats.md
- Enrichment features: usage/enrichments.md
- Vision models: usage/vision_models.md
- FAQ:
- FAQ: faq/index.md
- Concepts:
@ -78,6 +79,7 @@ nav:
- "Multi-format conversion": examples/run_with_formats.py
- "VLM pipeline with SmolDocling": examples/minimal_vlm_pipeline.py
- "VLM pipeline with remote model": examples/vlm_pipeline_api_model.py
- "VLM comparison": examples/compare_vlm_models.py
- "Figure export": examples/export_figures.py
- "Table export": examples/export_tables.py
- "Multimodal export": examples/export_multimodal.py

poetry.lock (generated)

File diff suppressed because it is too large


@ -151,6 +151,11 @@ torchvision = [
{ markers = "sys_platform == 'darwin' and platform_machine == 'x86_64'", version = "~0.17.2" },
]
[tool.poetry.group.lm.dependencies]
peft = "^0.15.2"
backoff = "^2.2.1"
[tool.poetry.extras]
tesserocr = ["tesserocr"]
ocrmac = ["ocrmac"]


@ -1,9 +1,10 @@
from pathlib import Path
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorDevice, PdfPipelineOptions
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from .test_data_gen_flag import GEN_TEST_DATA


@ -3,10 +3,10 @@ from pathlib import Path
from typing import List, Tuple
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
EasyOcrOptions,
OcrMacOptions,
OcrOptions,


@ -7,11 +7,10 @@ from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.base_models import ConversionStatus, InputFormat, QualityGrade
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
TableFormerMode,
)