docling/docling/datamodel/base_models.py
Michele Dolfi e45dc5d1a5
ci: Add GitHub Actions (#4)
* add GitHub Actions

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* apply styling

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* Update .github/actions/setup-poetry/action.yml

Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>

* add semantic-release config

Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
2024-07-16 13:05:04 +02:00

247 lines
6.0 KiB
Python

from enum import Enum, auto
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
from PIL.Image import Image
from pydantic import BaseModel, ConfigDict, model_validator
from docling.backend.abstract_backend import PdfPageBackend
class ConversionStatus(str, Enum):
    """Status of a document conversion.

    Values are explicit, readable strings. With ``auto()`` on a ``str``-mixin
    Enum every member became a digit string ("1", "2", ...), which collides
    with the members of other ``str``-mixin enums and serializes opaquely.
    """

    PENDING = "pending"
    STARTED = "started"
    FAILURE = "failure"
    SUCCESS = "success"
    SUCCESS_WITH_ERRORS = "success_with_errors"

    @classmethod
    def _missing_(cls, value):
        # Backward compatibility: accept the positional digit strings ("1", "2", ...)
        # that the former auto()-based definition produced, so previously
        # serialized data still loads. Anything else falls through to ValueError.
        if not isinstance(value, str):
            return None
        legacy = {str(pos): member for pos, member in enumerate(cls, start=1)}
        return legacy.get(value)
class DocInputType(str, Enum):
    """How an input document is provided: as a path or as a stream.

    Values are explicit strings. ``auto()`` on a ``str``-mixin Enum would yield
    "1"/"2", which compare equal to members of other such enums.
    """

    PATH = "path"
    STREAM = "stream"

    @classmethod
    def _missing_(cls, value):
        # Backward compatibility with the former auto()-generated values "1", "2".
        if not isinstance(value, str):
            return None
        legacy = {str(pos): member for pos, member in enumerate(cls, start=1)}
        return legacy.get(value)
class CoordOrigin(str, Enum):
    """Origin of the y axis for page coordinates (top-left or bottom-left corner).

    Values are explicit strings. ``auto()`` on a ``str``-mixin Enum would yield
    "1"/"2", which compare equal to members of other such enums.
    """

    TOPLEFT = "TOPLEFT"
    BOTTOMLEFT = "BOTTOMLEFT"

    @classmethod
    def _missing_(cls, value):
        # Backward compatibility with the former auto()-generated values "1", "2".
        if not isinstance(value, str):
            return None
        legacy = {str(pos): member for pos, member in enumerate(cls, start=1)}
        return legacy.get(value)
class PageSize(BaseModel):
    """Dimensions of a page; width and height both default to 0.0."""

    width: float = 0.0  # horizontal extent
    height: float = 0.0  # vertical extent
class BoundingBox(BaseModel):
    """Axis-aligned rectangle with left/top/right/bottom edges.

    The y values ``t`` and ``b`` are expressed in the frame selected by
    ``coord_origin``. With ``TOPLEFT`` the y axis points down, so a box
    normally has ``t <= b``. With ``BOTTOMLEFT`` it points up, so ``t >= b``.
    """

    l: float  # left
    t: float  # top
    r: float  # right
    b: float  # bottom

    coord_origin: CoordOrigin = CoordOrigin.TOPLEFT

    @property
    def width(self):
        """Horizontal extent, ``r - l``."""
        return self.r - self.l

    @property
    def height(self):
        """Vertical extent, ``|t - b|``, independent of the coordinate origin."""
        return abs(self.t - self.b)

    def as_tuple(self):
        """Return the edges as a 4-tuple in the layout read by :meth:`from_tuple`.

        ``(l, t, r, b)`` for ``TOPLEFT`` boxes, ``(l, b, r, t)`` for ``BOTTOMLEFT`` boxes.
        """
        if self.coord_origin == CoordOrigin.TOPLEFT:
            return (self.l, self.t, self.r, self.b)
        elif self.coord_origin == CoordOrigin.BOTTOMLEFT:
            return (self.l, self.b, self.r, self.t)

    @classmethod
    def from_tuple(cls, coord: Tuple[float, ...], origin: CoordOrigin):
        """Build a box from a 4-tuple laid out as produced by :meth:`as_tuple`.

        Args:
            coord: four numbers, ``(l, t, r, b)`` for ``TOPLEFT`` or
                ``(l, b, r, t)`` for ``BOTTOMLEFT``.
            origin: coordinate origin of ``coord``; stored on the new box.
        """
        if origin == CoordOrigin.TOPLEFT:
            return cls(
                l=coord[0], t=coord[1], r=coord[2], b=coord[3], coord_origin=origin
            )
        elif origin == CoordOrigin.BOTTOMLEFT:
            return cls(
                l=coord[0], b=coord[1], r=coord[2], t=coord[3], coord_origin=origin
            )

    def area(self) -> float:
        """Return ``width * height``.

        The result has the same sign for both coordinate origins.
        """
        # `height` is abs(t - b). The former (r - l) * (b - t) went negative
        # for BOTTOMLEFT boxes, where t > b.
        return self.width * self.height

    def intersection_area_with(self, other: "BoundingBox") -> float:
        """Return the overlapping area of this box and ``other``.

        Assumes both boxes use the same ``coord_origin``. Converting between
        origins needs the page height, which is not available here. Returns
        0.0 when the boxes do not overlap; boxes that only touch also give 0.0.
        """
        # Horizontal overlap.
        left = max(self.l, other.l)
        right = min(self.r, other.r)

        # Vertical overlap on the [min(t, b), max(t, b)] interval of each box, so
        # the computation is valid for TOPLEFT (t < b) and BOTTOMLEFT (t > b) boxes.
        # For TOPLEFT boxes this equals min(b) - max(t), as before.
        low = max(min(self.t, self.b), min(other.t, other.b))
        high = min(max(self.t, self.b), max(other.t, other.b))

        width = right - left
        height = high - low

        # Non-positive extent on either axis means the boxes do not overlap.
        if width <= 0 or height <= 0:
            return 0.0

        return width * height

    def to_bottom_left_origin(self, page_height: float) -> "BoundingBox":
        """Return this box expressed with a bottom-left origin.

        A box that already uses ``BOTTOMLEFT`` is returned as-is (not copied).
        Otherwise both y values are mirrored: ``y' = page_height - y``.
        """
        if self.coord_origin == CoordOrigin.BOTTOMLEFT:
            return self
        elif self.coord_origin == CoordOrigin.TOPLEFT:
            return BoundingBox(
                l=self.l,
                r=self.r,
                t=page_height - self.t,
                b=page_height - self.b,
                coord_origin=CoordOrigin.BOTTOMLEFT,
            )

    def to_top_left_origin(self, page_height: float) -> "BoundingBox":
        """Return this box expressed with a top-left origin.

        A box that already uses ``TOPLEFT`` is returned as-is (not copied).
        Otherwise both y values are mirrored: ``y' = page_height - y``.
        """
        if self.coord_origin == CoordOrigin.TOPLEFT:
            return self
        elif self.coord_origin == CoordOrigin.BOTTOMLEFT:
            return BoundingBox(
                l=self.l,
                r=self.r,
                # Mirroring keeps `t` as the top edge and `b` as the bottom edge.
                t=page_height - self.t,
                b=page_height - self.b,
                coord_origin=CoordOrigin.TOPLEFT,
            )
class Cell(BaseModel):
    """A piece of text with an integer id and the bounding box it occupies."""

    id: int  # identifier of the cell
    text: str  # textual content
    bbox: BoundingBox  # where the text sits
class OcrCell(Cell):
    """A ``Cell`` that also carries a confidence score (presumably from OCR)."""

    confidence: float  # score attached to the recognized text
class Cluster(BaseModel):
    """A labelled region of a page, optionally grouping the cells inside it.

    ``confidence`` defaults to 1.0 and ``cells`` to an empty list.
    """

    id: int
    label: str
    bbox: BoundingBox
    confidence: float = 1.0
    # pydantic copies this default per instance, so the list is not shared.
    cells: List[Cell] = []
class BasePageElement(BaseModel):
    """Fields shared by every element type placed on a page (text, table, figure)."""

    label: str
    id: int
    page_no: int  # page the element belongs to
    cluster: Cluster  # region the element was built from
    text: Optional[str] = None  # textual content, if any
class LayoutPrediction(BaseModel):
    """Layout prediction for a page: the clusters found on it (empty by default)."""

    clusters: List[Cluster] = []
class TableCell(BaseModel):
    """One cell of a table: bounding box, spans, row/column offsets, text and header flags."""

    bbox: BoundingBox
    row_span: int
    col_span: int
    start_row_offset_idx: int
    end_row_offset_idx: int
    start_col_offset_idx: int
    end_col_offset_idx: int
    text: str
    column_header: bool = False
    row_header: bool = False
    row_section: bool = False

    @model_validator(mode="before")
    @classmethod
    def from_dict_format(cls, data: Any) -> Any:
        """Derive ``text`` from the raw dict format before field validation.

        When ``data`` is a dict whose ``bbox`` is also a dict:

        - ``text`` is taken from ``bbox["token"]``.
        - If that is empty, the ``token`` of every entry in ``text_cell_bboxes``
          is joined with single spaces and stripped instead.

        Any other input, including a dict whose ``bbox`` is already a
        ``BoundingBox``, is returned unchanged for normal field validation.
        The caller's dict is never modified.
        """
        if not isinstance(data, dict):
            return data

        bbox = data.get("bbox")
        if not isinstance(bbox, dict):
            # Nothing to derive from. A missing bbox is reported by regular
            # validation instead of raising KeyError/AttributeError here.
            return data

        # Work on a shallow copy: this validator should not mutate its input.
        data = dict(data)

        # `or ""` also covers an explicit None token.
        text = bbox.get("token") or ""
        if not text:
            text_cells = data.pop("text_cell_bboxes", None)
            if text_cells:
                text = " ".join(el["token"] for el in text_cells)
            text = text.strip()
        data["text"] = text

        return data
class TableElement(BasePageElement):
    """A table on a page, with its structure sequence, dimensions and cells."""

    otsl_seq: List[str]  # table-structure token sequence (OTSL, per the name)
    num_rows: int = 0
    num_cols: int = 0
    table_cells: List[TableCell]
class TableStructurePrediction(BaseModel):
    """Table elements predicted for a page, keyed by an integer id (empty by default)."""

    table_map: Dict[int, TableElement] = {}
class TextElement(BasePageElement):
    """A text element on a page; it adds no fields to ``BasePageElement``."""
class FigureData(BaseModel):
    """Container for figure data; it currently declares no fields."""
class FigureElement(BasePageElement):
    """A figure on a page; every figure-specific field is optional and defaults to None."""

    data: Optional[FigureData] = None
    provenance: Optional[str] = None
    predicted_class: Optional[str] = None  # class label assigned to the figure
    confidence: Optional[float] = None  # score for `predicted_class`
class FigureClassificationPrediction(BaseModel):
    """Figure classification result for a page: a count plus figures keyed by integer id."""

    figure_count: int = 0
    figure_map: Dict[int, FigureElement] = {}
class EquationPrediction(BaseModel):
    """Equation result for a page: a count plus text elements keyed by integer id."""

    equation_count: int = 0
    equation_map: Dict[int, TextElement] = {}
class PagePredictions(BaseModel):
    """Per-page prediction results, one field per stage; each defaults to None.

    The fields are declared ``Optional`` because their default is ``None``.
    With the bare types, passing an explicit ``None`` failed validation and
    type checkers flagged the implicit Optional.
    """

    layout: Optional[LayoutPrediction] = None
    tablestructure: Optional[TableStructurePrediction] = None
    figures_classification: Optional[FigureClassificationPrediction] = None
    equations_prediction: Optional[EquationPrediction] = None
# Any element kind that can appear in an assembled page.
# Keep the member order: pydantic consults it when validating the union.
PageElement = Union[
    TextElement,
    TableElement,
    FigureElement,
]
class AssembledUnit(BaseModel):
    """Output of page assembly: all elements, plus separate body and header lists."""

    elements: List[PageElement]  # every assembled element
    body: List[PageElement]  # elements assigned to the body
    headers: List[PageElement]  # elements assigned to headers
class Page(BaseModel):
    """State collected for one page.

    Holds the page number, hash, size, image, text cells, model predictions,
    assembled output and the private backend handle.

    Fields that default to ``None`` are declared ``Optional`` so that an
    explicit ``None`` validates and the annotations match the defaults.
    """

    # Needed for non-pydantic field types such as PIL's Image.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    page_no: int
    page_hash: Optional[str] = None
    size: Optional[PageSize] = None
    image: Optional[Image] = None
    cells: Optional[List[Cell]] = None
    # pydantic copies this default per instance, so predictions are not shared.
    predictions: PagePredictions = PagePredictions()
    assembled: Optional[AssembledUnit] = None

    _backend: Optional[PdfPageBackend] = None  # Internal PDF backend
class DocumentStream(BaseModel):
    """An in-memory input document: a filename together with its byte stream."""

    # BytesIO is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    filename: str
    stream: BytesIO
class PipelineOptions(BaseModel):
    """Switches for optional pipeline stages."""

    do_table_structure: bool = True  # table structure stage, on by default
    do_ocr: bool = False  # OCR stage, off by default