feat: per-process ingest connections (#1058)

* adds per process connections for Google Drive connector
This commit is contained in:
ryannikolaidis 2023-08-17 10:34:08 -07:00 committed by GitHub
parent dd0f582585
commit 668d0f1b01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 19 deletions

View File

@ -1,9 +1,10 @@
## 0.10.2 ## 0.10.3
### Enhancements ### Enhancements
* Bump unstructured-inference==0.5.13: * Bump unstructured-inference==0.5.13:
- Fix extracted image elements being included in layout merge, addresses the issue - Fix extracted image elements being included in layout merge, addresses the issue
where an entire-page image in a PDF was not passed to the layout model when using hi_res. where an entire-page image in a PDF was not passed to the layout model when using hi_res.
* Adds ability to reuse connections per process in unstructured-ingest
### Features ### Features

View File

@ -0,0 +1,39 @@
from dataclasses import dataclass
import pytest
from unstructured.ingest.doc_processor.generalized import (
process_document,
)
from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin
@dataclass
class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc):
pass
def test_process_document_with_session_handle(mocker):
"""Test that the process_document function calls the doc_processor_fn with the correct
arguments, assigns the session handle, and returns the correct results."""
mock_session_handle = mocker.MagicMock()
mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mock_session_handle)
mock_doc = mocker.MagicMock(spec=(IngestDocWithSessionHandle))
result = process_document(mock_doc)
mock_doc.get_file.assert_called_once_with()
mock_doc.write_result.assert_called_with()
mock_doc.cleanup_file.assert_called_once_with()
assert result == mock_doc.process_file.return_value
assert mock_doc.session_handle == mock_session_handle
def test_process_document_no_session_handle(mocker):
"""Test that the process_document function calls does not assign session handle the IngestDoc
does not have the session handle mixin."""
mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mocker.MagicMock())
mock_doc = mocker.MagicMock(spec=(BaseIngestDoc))
process_document(mock_doc)
assert not hasattr(mock_doc, "session_handle")

View File

@ -144,7 +144,6 @@ def test_partition_file():
assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED
@freeze_time(TEST_DATE_PROCESSSED)
def test_process_file_fields_include_default(mocker, partition_test_results): def test_process_file_fields_include_default(mocker, partition_test_results):
"""Validate when metadata_include and metadata_exclude are not set, all fields: """Validate when metadata_include and metadata_exclude are not set, all fields:
("element_id", "text", "type", "metadata") are included""" ("element_id", "text", "type", "metadata") are included"""
@ -162,10 +161,6 @@ def test_process_file_fields_include_default(mocker, partition_test_results):
isd_elems = test_ingest_doc.process_file() isd_elems = test_ingest_doc.process_file()
assert len(isd_elems) assert len(isd_elems)
assert mock_partition.call_count == 1 assert mock_partition.call_count == 1
assert (
mock_partition.call_args.kwargs["data_source_metadata"].date_processed
== TEST_DATE_PROCESSSED
)
for elem in isd_elems: for elem in isd_elems:
assert {"element_id", "text", "type", "metadata"} == set(elem.keys()) assert {"element_id", "text", "type", "metadata"} == set(elem.keys())
data_source_metadata = elem["metadata"]["data_source"] data_source_metadata = elem["metadata"]["data_source"]

View File

@ -1 +1 @@
__version__ = "0.10.2" # pragma: no cover __version__ = "0.10.3" # pragma: no cover

View File

@ -4,7 +4,7 @@ import os
from dataclasses import dataclass from dataclasses import dataclass
from mimetypes import guess_extension from mimetypes import guess_extension
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import TYPE_CHECKING, Dict, Optional
from unstructured.file_utils.filetype import EXT_TO_FILETYPE from unstructured.file_utils.filetype import EXT_TO_FILETYPE
from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES
@ -12,17 +12,28 @@ from unstructured.ingest.interfaces import (
BaseConnector, BaseConnector,
BaseConnectorConfig, BaseConnectorConfig,
BaseIngestDoc, BaseIngestDoc,
BaseSessionHandle,
ConfigSessionHandleMixin,
ConnectorCleanupMixin, ConnectorCleanupMixin,
IngestDocCleanupMixin, IngestDocCleanupMixin,
IngestDocSessionHandleMixin,
StandardConnectorConfig, StandardConnectorConfig,
) )
from unstructured.ingest.logger import logger from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies from unstructured.utils import requires_dependencies
if TYPE_CHECKING:
from googleapiclient.discovery import Resource as GoogleAPIResource
FILE_FORMAT = "{id}-{name}{ext}" FILE_FORMAT = "{id}-{name}{ext}"
DIRECTORY_FORMAT = "{id}-{name}" DIRECTORY_FORMAT = "{id}-{name}"
@dataclass
class GoogleDriveSessionHandle(BaseSessionHandle):
service: "GoogleAPIResource"
@requires_dependencies(["googleapiclient"], extras="google-drive") @requires_dependencies(["googleapiclient"], extras="google-drive")
def create_service_account_object(key_path, id=None): def create_service_account_object(key_path, id=None):
""" """
@ -65,7 +76,7 @@ def create_service_account_object(key_path, id=None):
@dataclass @dataclass
class SimpleGoogleDriveConfig(BaseConnectorConfig): class SimpleGoogleDriveConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
"""Connector config where drive_id is the id of the document to process or """Connector config where drive_id is the id of the document to process or
the folder to process all documents from.""" the folder to process all documents from."""
@ -81,11 +92,16 @@ class SimpleGoogleDriveConfig(BaseConnectorConfig):
f"Extension not supported. " f"Extension not supported. "
f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.", f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.",
) )
self.service = create_service_account_object(self.service_account_key, self.drive_id)
def create_session_handle(
self,
) -> GoogleDriveSessionHandle:
service = create_service_account_object(self.service_account_key)
return GoogleDriveSessionHandle(service=service)
@dataclass @dataclass
class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc): class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc):
config: SimpleGoogleDriveConfig config: SimpleGoogleDriveConfig
file_meta: Dict file_meta: Dict
@ -103,8 +119,6 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload from googleapiclient.http import MediaIoBaseDownload
self.config.service = create_service_account_object(self.config.service_account_key)
if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"): if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"):
export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get( export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get(
self.file_meta.get("mimeType"), # type: ignore self.file_meta.get("mimeType"), # type: ignore
@ -117,12 +131,12 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
) )
return return
request = self.config.service.files().export_media( request = self.session_handle.service.files().export_media(
fileId=self.file_meta.get("id"), fileId=self.file_meta.get("id"),
mimeType=export_mime, mimeType=export_mime,
) )
else: else:
request = self.config.service.files().get_media(fileId=self.file_meta.get("id")) request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id"))
file = io.BytesIO() file = io.BytesIO()
downloader = MediaIoBaseDownload(file, request) downloader = MediaIoBaseDownload(file, request)
downloaded = False downloaded = False
@ -170,12 +184,13 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
def _list_objects(self, drive_id, recursive=False): def _list_objects(self, drive_id, recursive=False):
files = [] files = []
service = self.config.create_session_handle().service
def traverse(drive_id, download_dir, output_dir, recursive=False): def traverse(drive_id, download_dir, output_dir, recursive=False):
page_token = None page_token = None
while True: while True:
response = ( response = (
self.config.service.files() service.files()
.list( .list(
spaces="drive", spaces="drive",
fields="nextPageToken, files(id, name, mimeType)", fields="nextPageToken, files(id, name, mimeType)",
@ -244,6 +259,4 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
def get_ingest_docs(self): def get_ingest_docs(self):
files = self._list_objects(self.config.drive_id, self.config.recursive) files = self._list_objects(self.config.drive_id, self.config.recursive)
# Setting to None because service object can't be pickled for multiprocessing.
self.config.service = None
return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files] return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files]

View File

@ -6,8 +6,15 @@ from typing import Any, Dict, List, Optional
from unstructured_inference.models.base import get_model from unstructured_inference.models.base import get_model
from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc
from unstructured.ingest.interfaces import (
BaseSessionHandle,
IngestDocSessionHandleMixin,
)
from unstructured.ingest.logger import logger from unstructured.ingest.logger import logger
# module-level variable to store session handle
session_handle: Optional[BaseSessionHandle] = None
def initialize(): def initialize():
"""Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment """Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment
@ -30,8 +37,16 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
partition_kwargs partition_kwargs
ultimately the parameters passed to partition() ultimately the parameters passed to partition()
""" """
global session_handle
isd_elems_no_filename = None isd_elems_no_filename = None
try: try:
if isinstance(doc, IngestDocSessionHandleMixin):
if session_handle is None:
# create via doc.session_handle, which is a property that creates a
# session handle if one is not already defined
session_handle = doc.session_handle
else:
doc.session_handle = session_handle
# does the work necessary to load file into filesystem # does the work necessary to load file into filesystem
# in the future, get_file_handle() could also be supported # in the future, get_file_handle() could also be supported
doc.get_file() doc.get_file()
@ -39,7 +54,7 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
isd_elems_no_filename = doc.process_file(**partition_kwargs) isd_elems_no_filename = doc.process_file(**partition_kwargs)
# Note, this may be a no-op if the IngestDoc doesn't do anything to persist # Note, this may be a no-op if the IngestDoc doesn't do anything to persist
# the results. Instead, the MainProcess (caller) may work with the aggregate # the results. Instead, the Processor (caller) may work with the aggregate
# results across all docs in memory. # results across all docs in memory.
doc.write_result() doc.write_result()
except Exception: except Exception:

View File

@ -18,6 +18,12 @@ from unstructured.partition.auto import partition
from unstructured.staging.base import convert_to_dict from unstructured.staging.base import convert_to_dict
@dataclass
class BaseSessionHandle(ABC):
"""Abstract Base Class for sharing resources that are local to an individual process.
e.g., a connection for making a request for fetching documents."""
@dataclass @dataclass
class ProcessorConfigs: class ProcessorConfigs:
"""Common set of config required when running data connectors.""" """Common set of config required when running data connectors."""
@ -330,3 +336,26 @@ class IngestDocCleanupMixin:
): ):
logger.debug(f"Cleaning up {self}") logger.debug(f"Cleaning up {self}")
os.unlink(self.filename) os.unlink(self.filename)
class ConfigSessionHandleMixin:
@abstractmethod
def create_session_handle(self) -> BaseSessionHandle:
"""Creates a session handle that will be assigned on each IngestDoc to share
session related resources across all document handling for a given subprocess."""
class IngestDocSessionHandleMixin:
config: ConfigSessionHandleMixin
_session_handle: Optional[BaseSessionHandle] = None
@property
def session_handle(self):
"""If a session handle is not assigned, creates a new one and assigns it."""
if self._session_handle is None:
self._session_handle = self.config.create_session_handle()
return self._session_handle
@session_handle.setter
def session_handle(self, session_handle: BaseSessionHandle):
self._session_handle = session_handle