mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-12-16 01:34:56 +00:00
feat: per-process ingest connections (#1058)
* adds per process connections for Google Drive connector
This commit is contained in:
parent
dd0f582585
commit
668d0f1b01
@ -1,9 +1,10 @@
|
||||
## 0.10.2
|
||||
## 0.10.3
|
||||
|
||||
### Enhancements
|
||||
* Bump unstructured-inference==0.5.13:
|
||||
- Fix extracted image elements being included in layout merge, addresses the issue
|
||||
where an entire-page image in a PDF was not passed to the layout model when using hi_res.
|
||||
* Adds ability to reuse connections per process in unstructured-ingest
|
||||
|
||||
### Features
|
||||
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from unstructured.ingest.doc_processor.generalized import (
|
||||
process_document,
|
||||
)
|
||||
from unstructured.ingest.interfaces import BaseIngestDoc, IngestDocSessionHandleMixin
|
||||
|
||||
|
||||
@dataclass
class IngestDocWithSessionHandle(IngestDocSessionHandleMixin, BaseIngestDoc):
    """Minimal IngestDoc type that includes the session handle mixin.

    It adds no members of its own; the tests in this module use it only as a
    ``spec`` for mocks that must pass the mixin ``isinstance`` check.
    """
|
||||
|
||||
def test_process_document_with_session_handle(mocker):
    """Test that process_document shares the module-level session handle with a doc
    whose type includes the session handle mixin.

    Also checks that the doc is fetched, written, and cleaned up exactly once each,
    and that the value returned by the doc's process_file() is passed back to the caller.
    """
    mock_session_handle = mocker.MagicMock()
    # Pre-populate the module-level handle so process_document assigns it to the doc
    # instead of asking the doc to create a new one.
    mocker.patch(
        "unstructured.ingest.doc_processor.generalized.session_handle",
        mock_session_handle,
    )
    mock_doc = mocker.MagicMock(spec=IngestDocWithSessionHandle)

    result = process_document(mock_doc)

    mock_doc.get_file.assert_called_once_with()
    # "once" (not just "called") so an accidental double write is caught, matching
    # the checks on get_file and cleanup_file.
    mock_doc.write_result.assert_called_once_with()
    mock_doc.cleanup_file.assert_called_once_with()
    assert result == mock_doc.process_file.return_value
    assert mock_doc.session_handle == mock_session_handle
|
||||
|
||||
|
||||
def test_process_document_no_session_handle(mocker):
    """Test that process_document does not assign a session handle when the IngestDoc
    does not have the session handle mixin."""
    mocker.patch(
        "unstructured.ingest.doc_processor.generalized.session_handle",
        mocker.MagicMock(),
    )
    plain_doc = mocker.MagicMock(spec=BaseIngestDoc)

    process_document(plain_doc)

    # BaseIngestDoc is the spec, so the attribute only exists if it was assigned.
    assert not hasattr(plain_doc, "session_handle")
|
||||
@ -144,7 +144,6 @@ def test_partition_file():
|
||||
assert data_source_metadata["date_processed"] == TEST_DATE_PROCESSSED
|
||||
|
||||
|
||||
@freeze_time(TEST_DATE_PROCESSSED)
|
||||
def test_process_file_fields_include_default(mocker, partition_test_results):
|
||||
"""Validate when metadata_include and metadata_exclude are not set, all fields:
|
||||
("element_id", "text", "type", "metadata") are included"""
|
||||
@ -162,10 +161,6 @@ def test_process_file_fields_include_default(mocker, partition_test_results):
|
||||
isd_elems = test_ingest_doc.process_file()
|
||||
assert len(isd_elems)
|
||||
assert mock_partition.call_count == 1
|
||||
assert (
|
||||
mock_partition.call_args.kwargs["data_source_metadata"].date_processed
|
||||
== TEST_DATE_PROCESSSED
|
||||
)
|
||||
for elem in isd_elems:
|
||||
assert {"element_id", "text", "type", "metadata"} == set(elem.keys())
|
||||
data_source_metadata = elem["metadata"]["data_source"]
|
||||
|
||||
@ -1 +1 @@
|
||||
__version__ = "0.10.2" # pragma: no cover
|
||||
__version__ = "0.10.3" # pragma: no cover
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from mimetypes import guess_extension
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from unstructured.file_utils.filetype import EXT_TO_FILETYPE
|
||||
from unstructured.file_utils.google_filetype import GOOGLE_DRIVE_EXPORT_TYPES
|
||||
@ -12,17 +12,28 @@ from unstructured.ingest.interfaces import (
|
||||
BaseConnector,
|
||||
BaseConnectorConfig,
|
||||
BaseIngestDoc,
|
||||
BaseSessionHandle,
|
||||
ConfigSessionHandleMixin,
|
||||
ConnectorCleanupMixin,
|
||||
IngestDocCleanupMixin,
|
||||
IngestDocSessionHandleMixin,
|
||||
StandardConnectorConfig,
|
||||
)
|
||||
from unstructured.ingest.logger import logger
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from googleapiclient.discovery import Resource as GoogleAPIResource
|
||||
|
||||
FILE_FORMAT = "{id}-{name}{ext}"
|
||||
DIRECTORY_FORMAT = "{id}-{name}"
|
||||
|
||||
|
||||
@dataclass
class GoogleDriveSessionHandle(BaseSessionHandle):
    """Session handle that carries the Google Drive API service object."""

    # The annotation is a string because GoogleAPIResource is imported only under
    # TYPE_CHECKING and is therefore not available at runtime.
    service: "GoogleAPIResource"
|
||||
|
||||
|
||||
@requires_dependencies(["googleapiclient"], extras="google-drive")
|
||||
def create_service_account_object(key_path, id=None):
|
||||
"""
|
||||
@ -65,7 +76,7 @@ def create_service_account_object(key_path, id=None):
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleGoogleDriveConfig(BaseConnectorConfig):
|
||||
class SimpleGoogleDriveConfig(ConfigSessionHandleMixin, BaseConnectorConfig):
|
||||
"""Connector config where drive_id is the id of the document to process or
|
||||
the folder to process all documents from."""
|
||||
|
||||
@ -81,11 +92,16 @@ class SimpleGoogleDriveConfig(BaseConnectorConfig):
|
||||
f"Extension not supported. "
|
||||
f"Value MUST be one of {', '.join([k for k in EXT_TO_FILETYPE if k is not None])}.",
|
||||
)
|
||||
self.service = create_service_account_object(self.service_account_key, self.drive_id)
|
||||
|
||||
def create_session_handle(self) -> GoogleDriveSessionHandle:
    """Build a new session handle whose service is created from the configured
    service account key."""
    return GoogleDriveSessionHandle(
        service=create_service_account_object(self.service_account_key),
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
class GoogleDriveIngestDoc(IngestDocSessionHandleMixin, IngestDocCleanupMixin, BaseIngestDoc):
|
||||
config: SimpleGoogleDriveConfig
|
||||
file_meta: Dict
|
||||
|
||||
@ -103,8 +119,6 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
self.config.service = create_service_account_object(self.config.service_account_key)
|
||||
|
||||
if self.file_meta.get("mimeType", "").startswith("application/vnd.google-apps"):
|
||||
export_mime = GOOGLE_DRIVE_EXPORT_TYPES.get(
|
||||
self.file_meta.get("mimeType"), # type: ignore
|
||||
@ -117,12 +131,12 @@ class GoogleDriveIngestDoc(IngestDocCleanupMixin, BaseIngestDoc):
|
||||
)
|
||||
return
|
||||
|
||||
request = self.config.service.files().export_media(
|
||||
request = self.session_handle.service.files().export_media(
|
||||
fileId=self.file_meta.get("id"),
|
||||
mimeType=export_mime,
|
||||
)
|
||||
else:
|
||||
request = self.config.service.files().get_media(fileId=self.file_meta.get("id"))
|
||||
request = self.session_handle.service.files().get_media(fileId=self.file_meta.get("id"))
|
||||
file = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(file, request)
|
||||
downloaded = False
|
||||
@ -170,12 +184,13 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
|
||||
|
||||
def _list_objects(self, drive_id, recursive=False):
|
||||
files = []
|
||||
service = self.config.create_session_handle().service
|
||||
|
||||
def traverse(drive_id, download_dir, output_dir, recursive=False):
|
||||
page_token = None
|
||||
while True:
|
||||
response = (
|
||||
self.config.service.files()
|
||||
service.files()
|
||||
.list(
|
||||
spaces="drive",
|
||||
fields="nextPageToken, files(id, name, mimeType)",
|
||||
@ -244,6 +259,4 @@ class GoogleDriveConnector(ConnectorCleanupMixin, BaseConnector):
|
||||
|
||||
def get_ingest_docs(self):
|
||||
files = self._list_objects(self.config.drive_id, self.config.recursive)
|
||||
# Setting to None because service object can't be pickled for multiprocessing.
|
||||
self.config.service = None
|
||||
return [GoogleDriveIngestDoc(self.standard_config, self.config, file) for file in files]
|
||||
|
||||
@ -6,8 +6,15 @@ from typing import Any, Dict, List, Optional
|
||||
from unstructured_inference.models.base import get_model
|
||||
|
||||
from unstructured.ingest.interfaces import BaseIngestDoc as IngestDoc
|
||||
from unstructured.ingest.interfaces import (
|
||||
BaseSessionHandle,
|
||||
IngestDocSessionHandleMixin,
|
||||
)
|
||||
from unstructured.ingest.logger import logger
|
||||
|
||||
# module-level variable to store session handle
|
||||
session_handle: Optional[BaseSessionHandle] = None
|
||||
|
||||
|
||||
def initialize():
|
||||
"""Download default model or model specified by UNSTRUCTURED_HI_RES_MODEL_NAME environment
|
||||
@ -30,8 +37,16 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
|
||||
partition_kwargs
|
||||
ultimately the parameters passed to partition()
|
||||
"""
|
||||
global session_handle
|
||||
isd_elems_no_filename = None
|
||||
try:
|
||||
if isinstance(doc, IngestDocSessionHandleMixin):
|
||||
if session_handle is None:
|
||||
# create via doc.session_handle, which is a property that creates a
|
||||
# session handle if one is not already defined
|
||||
session_handle = doc.session_handle
|
||||
else:
|
||||
doc.session_handle = session_handle
|
||||
# does the work necessary to load file into filesystem
|
||||
# in the future, get_file_handle() could also be supported
|
||||
doc.get_file()
|
||||
@ -39,7 +54,7 @@ def process_document(doc: "IngestDoc", **partition_kwargs) -> Optional[List[Dict
|
||||
isd_elems_no_filename = doc.process_file(**partition_kwargs)
|
||||
|
||||
# Note, this may be a no-op if the IngestDoc doesn't do anything to persist
|
||||
# the results. Instead, the MainProcess (caller) may work with the aggregate
|
||||
# the results. Instead, the Processor (caller) may work with the aggregate
|
||||
# results across all docs in memory.
|
||||
doc.write_result()
|
||||
except Exception:
|
||||
|
||||
@ -18,6 +18,12 @@ from unstructured.partition.auto import partition
|
||||
from unstructured.staging.base import convert_to_dict
|
||||
|
||||
|
||||
@dataclass
class BaseSessionHandle(ABC):
    """Abstract base class for resources that are shared within a single process.

    For example, a connection used to make the requests that fetch documents.
    """
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorConfigs:
|
||||
"""Common set of config required when running data connectors."""
|
||||
@ -330,3 +336,26 @@ class IngestDocCleanupMixin:
|
||||
):
|
||||
logger.debug(f"Cleaning up {self}")
|
||||
os.unlink(self.filename)
|
||||
|
||||
|
||||
class ConfigSessionHandleMixin(ABC):
    """Mixin for connector configs that can create a per-process session handle.

    Derives from ``ABC`` so that ``@abstractmethod`` is actually enforced: on a plain
    class the decorator has no effect, and a subclass that forgot to implement
    ``create_session_handle`` would silently return ``None`` instead of failing
    at instantiation time.
    """

    @abstractmethod
    def create_session_handle(self) -> "BaseSessionHandle":
        """Creates a session handle that will be assigned on each IngestDoc to share
        session related resources across all document handling for a given subprocess."""
|
||||
|
||||
|
||||
class IngestDocSessionHandleMixin:
    """Mixin that gives an IngestDoc a lazily created, assignable session handle."""

    config: ConfigSessionHandleMixin
    _session_handle: Optional[BaseSessionHandle] = None

    @property
    def session_handle(self):
        """Return the assigned session handle; if none is assigned yet, create one
        from the config, store it, and return it."""
        handle = self._session_handle
        if handle is None:
            handle = self.config.create_session_handle()
            self._session_handle = handle
        return handle

    @session_handle.setter
    def session_handle(self, session_handle: BaseSessionHandle):
        """Assign an existing session handle to this doc."""
        self._session_handle = session_handle
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user