feat: partition_pdf() support language specification for PaddleOCR (#3400)

Closes #3159.

This PR extends language specification capability to `PaddleOCR` in
addition to `TesseractOCR`. Users can now specify OCR languages for both
OCR engines when using `partition_pdf()`.

### Testing

```
os.environ["OCR_AGENT"] = "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle"

elements = partition_pdf(
    filename=<file_path>,
    strategy=strategy,
    languages=["chi_sim"], # chinese - simplified
    infer_table_structure=True,
)
```
This commit is contained in:
Christine Straub 2024-07-16 15:19:25 -07:00 committed by GitHub
parent 6b1d5f28bb
commit 48bdf94656
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 186 additions and 108 deletions

View File

@ -9,6 +9,7 @@
### Features
* **Add support for specifying OCR language to `partition_pdf()`.** Extend language specification capability to `PaddleOCR` in addition to `TesseractOCR`. Users can now specify OCR languages for both OCR engines when using `partition_pdf()`.
* **Add AstraDB source connector** Adds support for ingesting documents from AstraDB.
### Fixes

View File

@ -21,6 +21,7 @@ from unstructured.partition.utils.constants import (
Source,
)
from unstructured.partition.utils.ocr_models.google_vision_ocr import OCRAgentGoogleVision
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
from unstructured.partition.utils.ocr_models.tesseract_ocr import (
OCRAgentTesseract,
@ -85,10 +86,7 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))
ocr_agent = OCRAgentTesseract()
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = ocr_agent.get_layout_from_image(image)
expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_TESSERACT),
@ -128,7 +126,7 @@ def mock_ocr(*args, **kwargs):
]
def monkeypatch_load_agent(language: str):
def monkeypatch_load_agent(*args):
class MockAgent:
def __init__(self):
self.ocr = mock_ocr
@ -145,10 +143,7 @@ def test_get_ocr_layout_from_image_paddle(monkeypatch):
image = Image.new("RGB", (100, 100))
ocr_layout = OCRAgentPaddle().get_layout_from_image(
image,
ocr_languages="eng",
)
ocr_layout = OCRAgentPaddle().get_layout_from_image(image)
expected_layout = [
TextRegion.from_coords(10, 5, 25, 15, "Hello", source=Source.OCR_PADDLE),
@ -168,10 +163,7 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
image = Image.new("RGB", (100, 100))
ocr_agent = OCRAgentTesseract()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)
assert ocr_text == "Hello World"
@ -186,10 +178,7 @@ def test_get_ocr_text_from_image_paddle(monkeypatch):
image = Image.new("RGB", (100, 100))
ocr_agent = OCRAgentPaddle()
ocr_text = ocr_agent.get_text_from_image(
image,
ocr_languages="eng",
)
ocr_text = ocr_agent.get_text_from_image(image)
assert ocr_text == "Hello\n\nWorld\n\n!"
@ -251,7 +240,7 @@ def test_get_ocr_from_image_google_vision(google_vision_client):
image = Image.new("RGB", (100, 100))
ocr_agent = google_vision_client
ocr_text = ocr_agent.get_text_from_image(image, ocr_languages="eng")
ocr_text = ocr_agent.get_text_from_image(image)
assert ocr_text == "Hello World!"
@ -428,7 +417,8 @@ def mock_ocr_layout():
def test_get_table_tokens(mock_ocr_layout):
with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
table_tokens = ocr.get_table_tokens(table_element_image=None)
ocr_agent = OCRAgent.get_agent(language="eng")
table_tokens = ocr.get_table_tokens(table_element_image=None, ocr_agent=ocr_agent)
expected_tokens = [
{
"bbox": [15, 25, 35, 45],

View File

@ -21,6 +21,7 @@ from unstructured.partition.lang import (
check_language_args,
detect_languages,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)
DIRECTORY = pathlib.Path(__file__).parent.resolve()
@ -84,6 +85,39 @@ def test_prepare_languages_for_tesseract_no_valid_languages(caplog):
assert "Failed to find any valid standard language code from languages" in caplog.text
@pytest.mark.parametrize(
("tesseract_lang", "expected_lang"),
[
("eng", "en"),
("chi_sim", "ch"),
("chi_tra", "chinese_cht"),
("deu", "german"),
("jpn", "japan"),
("kor", "korean"),
],
)
def test_tesseract_to_paddle_language_valid_codes(tesseract_lang, expected_lang):
assert expected_lang == tesseract_to_paddle_language(tesseract_lang)
def test_tesseract_to_paddle_language_invalid_codes(caplog):
tesseract_lang = "unsupported_lang"
assert tesseract_to_paddle_language(tesseract_lang) == "en"
assert "unsupported_lang is not a language code supported by PaddleOCR," in caplog.text
@pytest.mark.parametrize(
("tesseract_lang", "expected_lang"),
[
("ENG", "en"),
("Fra", "fr"),
("DEU", "german"),
],
)
def test_tesseract_to_paddle_language_case_sensitivity(tesseract_lang, expected_lang):
assert expected_lang == tesseract_to_paddle_language(tesseract_lang)
def test_detect_languages_english_auto():
text = "This is a short sentence."
assert detect_languages(text) == ["eng"]

View File

@ -35,10 +35,10 @@ class DescribeOCRAgent:
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
get_instance_.return_value = ocr_agent_
ocr_agent = OCRAgent.get_agent()
ocr_agent = OCRAgent.get_agent(language="eng")
_get_ocr_agent_cls_qname_.assert_called_once_with()
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT)
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
assert ocr_agent is ocr_agent_
def but_it_raises_when_the_requested_agent_is_not_whitelisted(
@ -46,7 +46,7 @@ class DescribeOCRAgent:
):
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")
@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
@ -57,7 +57,7 @@ class DescribeOCRAgent:
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
side_effect=exception_cls,
), pytest.raises(RuntimeError, match="Could not get the OCRAgent instance"):
OCRAgent.get_agent()
OCRAgent.get_agent(language="eng")
@pytest.mark.parametrize(
("OCR_AGENT", "expected_value"),

View File

@ -4,6 +4,7 @@ from PIL import Image
from unstructured.partition.pdf import convert_pdf_to_images
from unstructured.partition.pdf_image.ocr import get_table_tokens
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies
@ -21,8 +22,10 @@ def image_or_pdf_to_dataframe(filename: str) -> pd.DataFrame:
else:
image = Image.open(filename).convert("RGB")
ocr_agent = OCRAgent.get_agent(language="eng")
return tables_agent.run_prediction(
image, ocr_tokens=get_table_tokens(image), result_format="dataframe"
image, ocr_tokens=get_table_tokens(image, ocr_agent), result_format="dataframe"
)

View File

@ -98,7 +98,7 @@ def partition_image(
"""
exactly_one(filename=filename, file=file)
languages = check_language_args(languages or [], ocr_languages) or ["eng"]
languages = check_language_args(languages or [], ocr_languages)
return partition_pdf_or_image(
filename=filename,

View File

@ -144,6 +144,63 @@ PYTESSERACT_LANG_CODES = [
"yor",
]
PYTESSERACT_TO_PADDLE_LANG_CODE_MAP = {
"afr": "af", # Afrikaans
"ara": "ar", # Arabic
"aze": "az", # Azerbaijani
"bel": "be", # Belarusian
"bos": "bs", # Bosnian
"bul": "bg", # Bulgarian
"ces": "cs", # Czech
"chi_sim": "ch", # Simplified Chinese
"chi_tra": "chinese_cht", # Traditional Chinese
"cym": "cy", # Welsh
"dan": "da", # Danish
"deu": "german", # German
"eng": "en", # English
"est": "et", # Estonian
"fas": "fa", # Persian
"fra": "fr", # French
"gle": "ga", # Irish
"hin": "hi", # Hindi
"hrv": "hr", # Croatian
"hun": "hu", # Hungarian
"ind": "id", # Indonesian
"isl": "is", # Icelandic
"ita": "it", # Italian
"jpn": "japan", # Japanese
"kor": "korean", # Korean
"kmr": "ku", # Kurdish
"lat": "rs_latin", # Latin
"lav": "lv", # Latvian
"lit": "lt", # Lithuanian
"mar": "mr", # Marathi
"mlt": "mt", # Maltese
"msa": "ms", # Malay
"nep": "ne", # Nepali
"nld": "nl", # Dutch
"nor": "no", # Norwegian
"pol": "pl", # Polish
"por": "pt", # Portuguese
"ron": "ro", # Romanian
"rus": "ru", # Russian
"slk": "sk", # Slovak
"slv": "sl", # Slovenian
"spa": "es", # Spanish
"sqi": "sq", # Albanian
"srp": "rs_cyrillic", # Serbian
"swa": "sw", # Swahili
"swe": "sv", # Swedish
"tam": "ta", # Tamil
"tel": "te", # Telugu
"tur": "tr", # Turkish
"uig": "ug", # Uyghur
"ukr": "uk", # Ukrainian
"urd": "ur", # Urdu
"uzb": "uz", # Uzbek
"vie": "vi", # Vietnamese
}
def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) -> str:
"""
@ -169,6 +226,25 @@ def prepare_languages_for_tesseract(languages: Optional[list[str]] = ["eng"]) ->
return TESSERACT_LANGUAGES_SPLITTER.join(converted_languages)
def tesseract_to_paddle_language(tesseract_language: str) -> str:
"""
Convert TesseractOCR language code to PaddleOCR language code.
:param tesseract_language: str, language code used in TesseractOCR
:return: str, corresponding language code for PaddleOCR or None if not found
"""
lang = PYTESSERACT_TO_PADDLE_LANG_CODE_MAP.get(tesseract_language.lower())
if not lang:
logger.warning(
f"{tesseract_language} is not a language code supported by PaddleOCR, "
f"proceeding with `en` instead."
)
return "en"
return lang
def check_language_args(languages: list[str], ocr_languages: Optional[str]) -> Optional[list[str]]:
"""Handle users defining both `ocr_languages` and `languages`, giving preference to `languages`
and converting `ocr_languages` if needed, but defaulting to `None.

View File

@ -45,7 +45,11 @@ from unstructured.partition.common import (
ocr_data_to_elements,
spooled_to_bytes_io_if_needed,
)
from unstructured.partition.lang import check_language_args, prepare_languages_for_tesseract
from unstructured.partition.lang import (
check_language_args,
prepare_languages_for_tesseract,
tesseract_to_paddle_language,
)
from unstructured.partition.pdf_image.analysis.bbox_visualisation import (
AnalysisDrawer,
FinalLayoutDrawer,
@ -77,6 +81,7 @@ from unstructured.partition.strategies import determine_pdf_or_image_strategy, v
from unstructured.partition.text import element_from_text
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import (
OCR_AGENT_PADDLE,
SORT_MODE_BASIC,
SORT_MODE_DONT,
SORT_MODE_XY_CUT,
@ -197,7 +202,7 @@ def partition_pdf(
exactly_one(filename=filename, file=file)
languages = check_language_args(languages or [], ocr_languages) or ["eng"]
languages = check_language_args(languages or [], ocr_languages)
return partition_pdf_or_image(
filename=filename,
@ -227,7 +232,6 @@ def partition_pdf_or_image(
include_page_breaks: bool = False,
strategy: str = PartitionStrategy.AUTO,
infer_table_structure: bool = False,
ocr_languages: Optional[str] = None,
languages: Optional[list[str]] = None,
metadata_last_modified: Optional[str] = None,
hi_res_model_name: Optional[str] = None,
@ -247,6 +251,9 @@ def partition_pdf_or_image(
# that task so as routing design changes, those changes are implemented in a single
# function.
if languages is None:
languages = ["eng"]
# init ability to process .heic files
register_heif_opener()
@ -291,6 +298,10 @@ def partition_pdf_or_image(
if file is not None:
file.seek(0)
ocr_languages = prepare_languages_for_tesseract(languages)
if env_config.OCR_AGENT == OCR_AGENT_PADDLE:
ocr_languages = tesseract_to_paddle_language(ocr_languages)
if strategy == PartitionStrategy.HI_RES:
# NOTE(robinson): Catches a UserWarning that occurs when detection is called
with warnings.catch_warnings():
@ -302,6 +313,7 @@ def partition_pdf_or_image(
infer_table_structure=infer_table_structure,
include_page_breaks=include_page_breaks,
languages=languages,
ocr_languages=ocr_languages,
metadata_last_modified=metadata_last_modified or last_modification_date,
hi_res_model_name=hi_res_model_name,
pdf_text_extractable=pdf_text_extractable,
@ -333,6 +345,7 @@ def partition_pdf_or_image(
file=file,
include_page_breaks=include_page_breaks,
languages=languages,
ocr_languages=ocr_languages,
is_image=is_image,
metadata_last_modified=metadata_last_modified or last_modification_date,
starting_page_number=starting_page_number,
@ -500,6 +513,7 @@ def _partition_pdf_or_image_local(
infer_table_structure: bool = False,
include_page_breaks: bool = False,
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
ocr_mode: str = OCRMode.FULL_PAGE.value,
model_name: Optional[str] = None, # to be deprecated in favor of `hi_res_model_name`
hi_res_model_name: Optional[str] = None,
@ -529,11 +543,6 @@ def _partition_pdf_or_image_local(
process_file_with_pdfminer,
)
if languages is None:
languages = ["eng"]
ocr_languages = prepare_languages_for_tesseract(languages)
hi_res_model_name = hi_res_model_name or model_name or default_hi_res_model()
if pdf_image_dpi is None:
pdf_image_dpi = 300 if hi_res_model_name.startswith("chipper") else 200
@ -819,7 +828,8 @@ def _partition_pdf_or_image_with_ocr(
filename: str = "",
file: Optional[bytes | IO[bytes]] = None,
include_page_breaks: bool = False,
languages: Optional[list[str]] = ["eng"],
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
is_image: bool = False,
metadata_last_modified: Optional[str] = None,
starting_page_number: int = 1,
@ -838,6 +848,7 @@ def _partition_pdf_or_image_with_ocr(
page_elements = _partition_pdf_or_image_with_ocr_from_image(
image=image,
languages=languages,
ocr_languages=ocr_languages,
page_number=page_number,
include_page_breaks=include_page_breaks,
metadata_last_modified=metadata_last_modified,
@ -851,6 +862,7 @@ def _partition_pdf_or_image_with_ocr(
page_elements = _partition_pdf_or_image_with_ocr_from_image(
image=image,
languages=languages,
ocr_languages=ocr_languages,
page_number=page_number,
include_page_breaks=include_page_breaks,
metadata_last_modified=metadata_last_modified,
@ -864,6 +876,7 @@ def _partition_pdf_or_image_with_ocr(
def _partition_pdf_or_image_with_ocr_from_image(
image: PILImage.Image,
languages: Optional[list[str]] = None,
ocr_languages: Optional[str] = None,
page_number: int = 1,
include_page_breaks: bool = False,
metadata_last_modified: Optional[str] = None,
@ -874,17 +887,13 @@ def _partition_pdf_or_image_with_ocr_from_image(
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
ocr_agent = OCRAgent.get_agent()
ocr_languages = prepare_languages_for_tesseract(languages)
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
# NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text
if ocr_agent.is_text_sorted():
sort_mode = SORT_MODE_DONT
ocr_data = ocr_agent.get_layout_elements_from_image(
image=image,
ocr_languages=ocr_languages,
)
ocr_data = ocr_agent.get_layout_elements_from_image(image=image)
metadata = ElementMetadata(
last_modified=metadata_last_modified,

View File

@ -16,7 +16,7 @@ from unstructured.metrics.table.table_formats import SimpleTableCell
from unstructured.partition.pdf_image.analysis.bbox_visualisation import OCRLayoutDrawer
from unstructured.partition.pdf_image.pdf_image_utils import pad_element_bboxes, valid_text
from unstructured.partition.utils.config import env_config
from unstructured.partition.utils.constants import OCR_AGENT_TESSERACT, OCRMode
from unstructured.partition.utils.constants import OCRMode
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies
@ -200,12 +200,9 @@ def supplement_page_layout_with_ocr(
with no text and add text from OCR to each element.
"""
ocr_agent = OCRAgent.get_agent()
ocr_agent = OCRAgent.get_agent(language=ocr_languages)
if ocr_mode == OCRMode.FULL_PAGE.value:
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages=ocr_languages,
)
ocr_layout = ocr_agent.get_layout_from_image(image)
if ocr_drawer:
ocr_drawer.add_ocred_page(ocr_layout)
page_layout.elements[:] = merge_out_layout_with_ocr_layout(
@ -227,10 +224,7 @@ def supplement_page_layout_with_ocr(
)
# Note(yuming): instead of getting OCR layout, we just need
# the text extraced from OCR for individual elements
text_from_ocr = ocr_agent.get_text_from_image(
cropped_image,
ocr_languages=ocr_languages,
)
text_from_ocr = ocr_agent.get_text_from_image(cropped_image)
element.text = text_from_ocr
else:
raise ValueError(
@ -250,7 +244,6 @@ def supplement_page_layout_with_ocr(
elements=cast(List["LayoutElement"], page_layout.elements),
image=image,
tables_agent=tables.tables_agent,
ocr_languages=ocr_languages,
ocr_agent=ocr_agent,
extracted_regions=extracted_regions,
)
@ -263,8 +256,7 @@ def supplement_element_with_table_extraction(
elements: List["LayoutElement"],
image: PILImage.Image,
tables_agent: "UnstructuredTableTransformerModel",
ocr_languages: str = "eng",
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
ocr_agent,
extracted_regions: Optional[List["TextRegion"]] = None,
) -> List["LayoutElement"]:
"""Supplement the existing layout with table extraction. Any Table elements
@ -288,7 +280,6 @@ def supplement_element_with_table_extraction(
)
table_tokens = get_table_tokens(
table_element_image=cropped_image,
ocr_languages=ocr_languages,
ocr_agent=ocr_agent,
extracted_regions=extracted_regions,
table_element=padded_element,
@ -312,17 +303,13 @@ def supplement_element_with_table_extraction(
def get_table_tokens(
table_element_image: PILImage.Image,
ocr_languages: str = "eng",
ocr_agent: OCRAgent = OCRAgent.get_instance(OCR_AGENT_TESSERACT),
ocr_agent: OCRAgent,
extracted_regions: Optional[List["TextRegion"]] = None,
table_element: Optional["LayoutElement"] = None,
) -> List[dict[str, Any]]:
"""Get OCR tokens from either paddleocr or tesseract"""
ocr_layout = ocr_agent.get_layout_from_image(
image=table_element_image,
ocr_languages=ocr_languages,
)
ocr_layout = ocr_agent.get_layout_from_image(image=table_element_image)
table_tokens = []
for ocr_region in ocr_layout:
table_tokens.append(

View File

@ -43,10 +43,6 @@ OCR_AGENT_MODULES_WHITELIST = os.getenv(
UNSTRUCTURED_INCLUDE_DEBUG_METADATA = os.getenv("UNSTRUCTURED_INCLUDE_DEBUG_METADATA", False)
# Note(yuming): Default language for paddle OCR
# soon will be able to specify the language down through partition() as well
DEFAULT_PADDLE_LANG = os.getenv("DEFAULT_PADDLE_LANG", "en")
# this field is defined by pytesseract/unstructured.pytesseract
TESSERACT_TEXT_HEIGHT = "height"

View File

@ -25,17 +25,17 @@ class OCRAgent(ABC):
"""Defines the interface for an Optical Character Recognition (OCR) service."""
@classmethod
def get_agent(cls) -> OCRAgent:
def get_agent(cls, language: str) -> OCRAgent:
"""Get the configured OCRAgent instance.
The OCR package used by the agent is determined by the `OCR_AGENT` environment variable.
"""
ocr_agent_cls_qname = cls._get_ocr_agent_cls_qname()
return cls.get_instance(ocr_agent_cls_qname)
return cls.get_instance(ocr_agent_cls_qname, language)
@staticmethod
@functools.lru_cache(maxsize=None)
def get_instance(ocr_agent_module: str) -> "OCRAgent":
def get_instance(ocr_agent_module: str, language: str) -> "OCRAgent":
module_name, class_name = ocr_agent_module.rsplit(".", 1)
if module_name not in OCR_AGENT_MODULES_WHITELIST:
raise ValueError(
@ -46,7 +46,7 @@ class OCRAgent(ABC):
try:
module = importlib.import_module(module_name)
loaded_class = getattr(module, class_name)
return loaded_class()
return loaded_class(language)
except (ImportError, AttributeError) as e:
logger.error(f"Failed to get OCRAgent instance: {e}")
raise RuntimeError(
@ -55,19 +55,15 @@ class OCRAgent(ABC):
)
@abstractmethod
def get_layout_elements_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> list[LayoutElement]:
def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]:
pass
@abstractmethod
def get_layout_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> list[TextRegion]:
def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]:
pass
@abstractmethod
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
def get_text_from_image(self, image: PILImage.Image) -> str:
pass
@abstractmethod

View File

@ -7,7 +7,7 @@ from PIL import Image as PILImage
from unstructured.documents.elements import ElementType
from unstructured.logger import logger, trace_logger
from unstructured.partition.utils.constants import DEFAULT_PADDLE_LANG, Source
from unstructured.partition.utils.constants import Source
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent
from unstructured.utils import requires_dependencies
@ -19,10 +19,10 @@ if TYPE_CHECKING:
class OCRAgentPaddle(OCRAgent):
"""OCR service implementation for PaddleOCR."""
def __init__(self):
self.agent = self.load_agent()
def __init__(self, language: str = "en"):
self.agent = self.load_agent(language)
def load_agent(self, language: str = DEFAULT_PADDLE_LANG):
def load_agent(self, language: str):
"""Loads the PaddleOCR agent as a global variable to ensure that we only load it once."""
import paddle
@ -59,16 +59,14 @@ class OCRAgentPaddle(OCRAgent):
)
return paddle_ocr
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
def get_text_from_image(self, image: PILImage.Image) -> str:
ocr_regions = self.get_layout_from_image(image)
return "\n\n".join([r.text for r in ocr_regions])
def is_text_sorted(self):
return False
def get_layout_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> list[TextRegion]:
def get_layout_from_image(self, image: PILImage.Image) -> list[TextRegion]:
"""Get the OCR regions from image as a list of text regions with paddle."""
trace_logger.detail("Processing entire page OCR with paddle...")
@ -82,15 +80,10 @@ class OCRAgentPaddle(OCRAgent):
return ocr_regions
@requires_dependencies("unstructured_inference")
def get_layout_elements_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> list[LayoutElement]:
def get_layout_elements_from_image(self, image: PILImage.Image) -> list[LayoutElement]:
from unstructured.partition.pdf_image.inference_utils import build_layout_element
ocr_regions = self.get_layout_from_image(
image,
ocr_languages=ocr_languages,
)
ocr_regions = self.get_layout_from_image(image)
# NOTE(christine): For paddle, there is no difference in `ocr_layout` and `ocr_text` in
# terms of grouping because we get ocr_text from `ocr_layout, so the first two grouping

View File

@ -33,22 +33,23 @@ if "OMP_THREAD_LIMIT" not in os.environ:
class OCRAgentTesseract(OCRAgent):
"""OCR service implementation for Tesseract."""
def __init__(self, language: str = "eng"):
self.language = language
def is_text_sorted(self):
return True
def get_text_from_image(self, image: PILImage.Image, ocr_languages: str = "eng") -> str:
return unstructured_pytesseract.image_to_string(np.array(image), lang=ocr_languages)
def get_text_from_image(self, image: PILImage.Image) -> str:
return unstructured_pytesseract.image_to_string(np.array(image), lang=self.language)
def get_layout_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List[TextRegion]:
def get_layout_from_image(self, image: PILImage.Image) -> List[TextRegion]:
"""Get the OCR regions from image as a list of text regions with tesseract."""
trace_logger.detail("Processing entire page OCR with tesseract...")
zoom = 1
ocr_df: pd.DataFrame = unstructured_pytesseract.image_to_data(
np.array(image),
lang=ocr_languages,
lang=self.language,
output_type=Output.DATAFRAME,
)
ocr_df = ocr_df.dropna()
@ -77,7 +78,7 @@ class OCRAgentTesseract(OCRAgent):
)
ocr_df = unstructured_pytesseract.image_to_data(
np.array(zoom_image(image, zoom)),
lang=ocr_languages,
lang=self.language,
output_type=Output.DATAFRAME,
)
ocr_df = ocr_df.dropna()
@ -87,17 +88,12 @@ class OCRAgentTesseract(OCRAgent):
return ocr_regions
@requires_dependencies("unstructured_inference")
def get_layout_elements_from_image(
self, image: PILImage.Image, ocr_languages: str = "eng"
) -> List["LayoutElement"]:
def get_layout_elements_from_image(self, image: PILImage.Image) -> List["LayoutElement"]:
from unstructured.partition.pdf_image.inference_utils import (
build_layout_elements_from_ocr_regions,
)
ocr_regions = self.get_layout_from_image(
image,
ocr_languages=ocr_languages,
)
ocr_regions = self.get_layout_from_image(image)
# NOTE(christine): For tesseract, the ocr_text returned by
# `unstructured_pytesseract.image_to_string()` doesn't contain bounding box data but is
@ -106,10 +102,7 @@ class OCRAgentTesseract(OCRAgent):
# grouped. Therefore, we need to first group the `ocr_layout` by `ocr_text` and then merge
# the text regions in each group to create a list of layout elements.
ocr_text = self.get_text_from_image(
image,
ocr_languages=ocr_languages,
)
ocr_text = self.get_text_from_image(image)
return build_layout_elements_from_ocr_regions(
ocr_regions=ocr_regions,