feat: per-process ingest connections (#1058)

* adds per process connections for Google Drive connector
This commit is contained in:
ryannikolaidis 2023-08-17 10:34:08 -07:00 committed by GitHub
parent dd0f582585
commit 668d0f1b01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 19 deletions

View File

@ -1,9 +1,10 @@
## 0.10.2
## 0.10.3
### Enhancements
* Bump unstructured-inference==0.5.13:
- Fix extracted image elements being included in layout merge, addresses the issue
where an entire-page image in a PDF was not passed to the layout model when using hi_res.
* Adds ability to reuse connections per process in unstructured-ingest
### Features

View File

@ -0,0 +1,39 @@
from dataclasses import dataclass
import pytest
from unstructured.ingest.doc_processor.generalized import (
process_document,
)
from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin
@dataclass
class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc):
pass
def test_process_document_with_session_handle(mocker):
"""Test that the process_document function calls the doc_processor_fn with the correct
arguments, assigns the session handle, and returns the correct results."""
mock_session_handle = mocker.MagicMock()
mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mock_session_handle)
mock_doc = mocker.MagicMock(spec=(IngestDocWithSessionHandle))
result = process_document(mock_doc)
mock_doc.get_file.assert_called_once_with()
mock_doc.write_result.assert_called_with()
mock_doc.cleanup_file.assert_called_once_with()
assert result == mock_doc.process_file.return_value
assert mock_doc.session_handle == mock_session_handle
def test_process_document_no_session_handle(mocker):
"""Test that the process_document function calls does not assign session handle the IngestDoc
does not have the session handle mixin."""
mocker.patch("unstructured.ingest.doc_processor.generalized.session_handle", mocker.MagicMock())
mock_doc = mocker.MagicMock(spec=(BaseIngestDoc))
process_document(mock_doc)
assert not hasattr(mock_doc, "session_handle")

View File

@ -144,7 +144,6 @@ def test_partition_file():
assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED
@freeze_time(TEST_DATE_PROCESSSED)
def test_process_file_fields_include_default(mocker, partition_test_results):
"""Validate when metadata_include and metadata_exclude are not set, all fields:
("element_id", "text", "type", "metadata") are included"""
@ -162,10 +161,6 @@ def test_process_file_fields_include_default(mocker, partition_test_results):
isd_elems = test_ingest_doc.process_file()
assert len(isd_elems)
assert mock_partition.call_count == 1
assert (
mock_partition.call_args.kwargs["data_source_metadata"].date_processed
== TEST_DATE_PROCESSSED
)
for elem in isd_elems:
assert {"element_id", "text", "type", "metadata"} == set(elem.keys())
data_source_metadata = elem["metadata"]["data_source"]

View File

@ -1 +1 @@
__version__ = "0.10.2" # pragma: no cover
__version__ = "0.10.3" # pragma: no cover

View File

@ -4,7 +4,7 @@ import os
from dataclasses import dataclass
from mimetypes import guess_extension
from pathlib import Path
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional
from unstructured.file_utils.filetype import EXT_TO_FILETYPE
from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES
@ -12,17 +12,28 @@ from unstructured.ingest.interfaces import (
BaseConnector,
BaseConnectorConfig,
BaseIngestDoc,
BaseSessionHandle,
ConfigSessionHandleMixin,
ConnectorCleanupMixin,
IngestDocCleanupMixin,
IngestDocSessionHandleMixin,
StandardConnectorConfig,
)
from unstructured.ingest.logger import logger
from unstructured.utils import requires_dependencies
if TYPE_CHECKING:
from googleapiclient.discovery import Resource as GoogleAPIResource
FILE_FORMAT = "{id}-{name}{ext}"
DIRECTORY_FORMAT = "{id}-{name}"
@dataclass
class GoogleDriveSessionHandle(BaseSessionHandle):
service: "GoogleAPIResource"
@requires_dependencies(["googleapiclient"], extras="google-drive")
def create_service_account_object(key_path, id=None):
"""
@ -65,7 +76,7 @@ def create_service_account_object(key_path, id=None):
@dataclass
class SimpleGoogleDriveConfig(BaseConnectorConfig):
class SimpleGoogleDriveConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
"""Connector config where drive_id is the id of the document to process or
the folder to process all documents from."""
@ -81,11 +92,16 @@ class SimpleGoogleDriveConfig(BaseConnectorConfig):
f"Extension not supported. "
f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.",
)
self.service = create_service_account_object(self.service_account_key, self.drive_id)
def create_session_handle(
self,
) -> GoogleDriveSessionHandle:
service = create_service_account_object(self.service_account_key)
return GoogleDriveSessionHandle(service=service)
@dataclass
class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc):
config: SimpleGoogleDriveConfig
file_meta: Dict
@ -103,8 +119,6 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload
self.config.service = create_service_account_object(self.config.service_account_key)
if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"):
export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get(
self.file_meta.get("mimeType"), # type: ignore
@ -117,12 +131,12 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
)
return
request = self.config.service.files().export_media(
request = self.session_handle.service.files().export_media(
fileId=self.file_meta.get("id"),
mimeType=export_mime,
)
else:
request = self.config.service.files().get_media(fileId=self.file_meta.get("id"))
request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id"))
file = io.BytesIO()
downloader = MediaIoBaseDownload(file, request)
downloaded = False
@ -170,12 +184,13 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
def _list_objects(self, drive_id, recursive=False):
files = []
service = self.config.create_session_handle().service
def traverse(drive_id, download_dir, output_dir, recursive=False):
page_token = None
while True:
response = (
self.config.service.files()
service.files()
.list(
spaces="drive",
fields="nextPageToken, files(id, name, mimeType)",
@ -244,6 +259,4 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
def get_ingest_docs(self):
files = self._list_objects(self.config.drive_id, self.config.recursive)
# Setting to None because service object can't be pickled for multiprocessing.
self.config.service = None
return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files]

View File

@ -6,8 +6,15 @@ from typing import Any, Dict, List, Optional
from unstructured_inference.models.base import get_model
from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc
from unstructured.ingest.interfaces import (
BaseSessionHandle,
IngestDocSessionHandleMixin,
)
from unstructured.ingest.logger import logger
# module-level variable to store session handle
session_handle: Optional[BaseSessionHandle] = None
def initialize():
"""Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment
@ -30,8 +37,16 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
partition_kwargs
ultimately the parameters passed to partition()
"""
global session_handle
isd_elems_no_filename = None
try:
if isinstance(doc, IngestDocSessionHandleMixin):
if session_handle is None:
# create via doc.session_handle, which is a property that creates a
# session handle if one is not already defined
session_handle = doc.session_handle
else:
doc.session_handle = session_handle
# does the work necessary to load file into filesystem
# in the future, get_file_handle() could also be supported
doc.get_file()
@ -39,7 +54,7 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
isd_elems_no_filename = doc.process_file(**partition_kwargs)
# Note, this may be a no-op if the IngestDoc doesn't do anything to persist
# the results. Instead, the MainProcess (caller) may work with the aggregate
# the results. Instead, the Processor (caller) may work with the aggregate
# results across all docs in memory.
doc.write_result()
except Exception:

View File

@ -18,6 +18,12 @@ from unstructured.partition.auto import partition
from unstructured.staging.base import convert_to_dict
@dataclass
class BaseSessionHandle(ABC):
"""Abstract Base Class for sharing resources that are local to an individual process.
e.g., a connection for making a request for fetching documents."""
@dataclass
class ProcessorConfigs:
"""Common set of config required when running data connectors."""
@ -330,3 +336,26 @@ class IngestDocCleanupMixin:
):
logger.debug(f"Cleaning up {self}")
os.unlink(self.filename)
class ConfigSessionHandleMixin:
@abstractmethod
def create_session_handle(self) -> BaseSessionHandle:
"""Creates a session handle that will be assigned on each IngestDoc to share
session related resources across all document handling for a given subprocess."""
class IngestDocSessionHandleMixin:
config: ConfigSessionHandleMixin
_session_handle: Optional[BaseSessionHandle] = None
@property
def session_handle(self):
"""If a session handle is not assigned, creates a new one and assigns it."""
if self._session_handle is None:
self._session_handle = self.config.create_session_handle()
return self._session_handle
@session_handle.setter
def session_handle(self, session_handle: BaseSessionHandle):
self._session_handle = session_handle