import hashlib
import json
import sys
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from time import time
from typing import TYPE_CHECKING, Any, Generator, Optional

from unstructured.documents.elements import DataSourceMetadata
from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.error import SourceConnectionNetworkError
from unstructured.ingest.utils.data_prep import generator_batching_wbytes
from unstructured.ingest.v2.interfaces import (
    AccessConfig,
    ConnectionConfig,
    Downloader,
    DownloaderConfig,
    DownloadResponse,
    FileData,
    Indexer,
    IndexerConfig,
    UploadContent,
    Uploader,
    UploaderConfig,
    UploadStager,
    UploadStagerConfig,
    download_responses,
)
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.processes.connector_registry import (
    DestinationRegistryEntry,
    SourceRegistryEntry,
    add_destination_entry,
    add_source_entry,
)
from unstructured.staging.base import flatten_dict
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch as ElasticsearchClient

CONNECTOR_TYPE = "elasticsearch"


@dataclass
class ElasticsearchAccessConfig(AccessConfig):
    password: Optional[str] = None
    api_key: Optional[str] = enhanced_field(default=None, overload_name="es_api_key")
    bearer_auth: Optional[str] = None
    ssl_assert_fingerprint: Optional[str] = None


@dataclass
class ElasticsearchConnectionConfig(ConnectionConfig):
    hosts: Optional[list[str]] = None
    username: Optional[str] = None
    cloud_id: Optional[str] = None
    api_key_id: Optional[str] = None
    ca_certs: Optional[str] = None
    access_config: ElasticsearchAccessConfig = enhanced_field(sensitive=True)

    def get_client_kwargs(self) -> dict:
        # Update auth related fields to conform to what the SDK expects based on the
        # supported methods:
        # https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html
        client_kwargs = {
            "hosts": self.hosts,
        }
        if self.ca_certs:
            client_kwargs["ca_certs"] = self.ca_certs
        if self.access_config.password and (
            self.cloud_id or self.ca_certs or self.access_config.ssl_assert_fingerprint
        ):
            client_kwargs["basic_auth"] = ("elastic", self.access_config.password)
        elif not self.cloud_id and self.username and self.access_config.password:
            client_kwargs["basic_auth"] = (self.username, self.access_config.password)
        elif self.access_config.api_key and self.api_key_id:
            client_kwargs["api_key"] = (self.api_key_id, self.access_config.api_key)
        return client_kwargs

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    def get_client(self) -> "ElasticsearchClient":
        from elasticsearch import Elasticsearch as ElasticsearchClient

        return ElasticsearchClient(**self.get_client_kwargs())


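# A minimal usage sketch (illustrative only; the host, credentials, and values
# below are made up, not taken from this module):
#
#   connection_config = ElasticsearchConnectionConfig(
#       hosts=["http://localhost:9200"],
#       username="es_user",
#       access_config=ElasticsearchAccessConfig(password="changeme"),
#   )
#   client = connection_config.get_client()
#
# With a username and password and no cloud_id, get_client_kwargs() resolves to
# {"hosts": ["http://localhost:9200"], "basic_auth": ("es_user", "changeme")}.

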
@dataclass
class ElasticsearchIndexerConfig(IndexerConfig):
    index_name: str
    batch_size: int = 100


@dataclass
class ElasticsearchIndexer(Indexer):
    connection_config: ElasticsearchConnectionConfig
    index_config: ElasticsearchIndexerConfig
    client: "ElasticsearchClient" = field(init=False)
    connector_type: str = CONNECTOR_TYPE

    def __post_init__(self):
        self.client = self.connection_config.get_client()

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    def _get_doc_ids(self) -> set[str]:
        """Fetches all document ids in an index."""
        from elasticsearch.helpers import scan

        # An empty stored_fields list means only hit metadata (including _id)
        # is returned, not document bodies.
        scan_query: dict = {"stored_fields": [], "query": {"match_all": {}}}
        hits = scan(
            self.client,
            query=scan_query,
            scroll="1m",
            index=self.index_config.index_name,
        )

        return {hit["_id"] for hit in hits}

    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
        all_ids = self._get_doc_ids()
        ids = list(all_ids)
        batch_size = self.index_config.batch_size
        id_batches: list[frozenset[str]] = [
            frozenset(ids[i * batch_size : (i + 1) * batch_size])  # noqa: E203
            for i in range((len(ids) + batch_size - 1) // batch_size)
        ]
        for batch in id_batches:
            # Make sure the hash is always a positive number to create the identifier
            identified = str(hash(batch) + sys.maxsize + 1)
            yield FileData(
                identifier=identified,
                connector_type=CONNECTOR_TYPE,
                metadata=DataSourceMetadata(
                    url=f"{self.connection_config.hosts[0]}/{self.index_config.index_name}",
                    date_processed=str(time()),
                ),
                additional_metadata={
                    "ids": list(batch),
                    "index_name": self.index_config.index_name,
                },
            )


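# Each FileData yielded by ElasticsearchIndexer.run() carries everything the
# downloader needs in its additional_metadata, e.g. (illustrative ids and
# index name):
#
#   {"ids": ["dGQ5MXQ", "aXJ3enM"], "index_name": "movies"}

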
@dataclass
class ElasticsearchDownloaderConfig(DownloaderConfig):
    fields: list[str] = field(default_factory=list)


@dataclass
class ElasticsearchDownloader(Downloader):
    connection_config: ElasticsearchConnectionConfig
    download_config: ElasticsearchDownloaderConfig
    connector_type: str = CONNECTOR_TYPE

    def is_async(self) -> bool:
        return True

    def get_identifier(self, index_name: str, record_id: str) -> str:
        f = f"{index_name}-{record_id}"
        if self.download_config.fields:
            f = "{}-{}".format(
                f,
                hashlib.sha256(",".join(self.download_config.fields).encode()).hexdigest()[:8],
            )
        return f

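    # For reference, identifiers look like "<index_name>-<record_id>", with the
    # first 8 hex characters of a SHA-256 digest of the requested field list
    # appended when download_config.fields is set, e.g. "movies-dGQ5MXQ-9f86d081"
    # (illustrative values).
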
    def map_es_results(self, es_results: dict) -> str:
        doc_body = es_results["_source"]
        flattened_dict = flatten_dict(dictionary=doc_body)
        str_values = [str(value) for value in flattened_dict.values()]
        concatenated_values = "\n".join(str_values)
        return concatenated_values

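    # A sketch of the mapping above (illustrative document): a hit whose _source
    # is {"title": "Carmine Street Guitars", "year": 2018} is flattened to its
    # values and written out as the two-line string
    # "Carmine Street Guitars\n2018".
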
    def generate_download_response(
        self, result: dict, index_name: str, file_data: FileData
    ) -> DownloadResponse:
        record_id = result["_id"]
        filename_id = self.get_identifier(index_name=index_name, record_id=record_id)
        filename = f"{filename_id}.txt"
        download_path = self.download_dir / Path(filename)
        logger.debug(
            f"Downloading results from index {index_name} and id {record_id} to {download_path}"
        )
        download_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(download_path, "w", encoding="utf8") as f:
                f.write(self.map_es_results(es_results=result))
        except Exception as e:
            logger.error(
                f"failed to download from index {index_name} "
                f"and id {record_id} to {download_path}: {e}",
                exc_info=True,
            )
            raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
        return DownloadResponse(
            file_data=FileData(
                identifier=filename_id,
                connector_type=CONNECTOR_TYPE,
                metadata=DataSourceMetadata(
                    version=str(result["_version"]) if "_version" in result else None,
                    date_processed=str(time()),
                    record_locator={
                        "hosts": self.connection_config.hosts,
                        "index_name": index_name,
                        "document_id": record_id,
                    },
                ),
            ),
            path=download_path,
        )

    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
        raise NotImplementedError()

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
        from elasticsearch import AsyncElasticsearch as AsyncElasticsearchClient
        from elasticsearch.helpers import async_scan

        index_name: str = file_data.additional_metadata["index_name"]
        ids: list[str] = file_data.additional_metadata["ids"]

        scan_query = {
            "_source": self.download_config.fields,
            "version": True,
            "query": {"ids": {"values": ids}},
        }

        responses = []
        async with AsyncElasticsearchClient(**self.connection_config.get_client_kwargs()) as client:
            async for result in async_scan(
                client,
                query=scan_query,
                scroll="1m",
                index=index_name,
            ):
                responses.append(
                    self.generate_download_response(
                        result=result, index_name=index_name, file_data=file_data
                    )
                )
        return responses


@dataclass
class ElasticsearchUploadStagerConfig(UploadStagerConfig):
    index_name: str


@dataclass
class ElasticsearchUploadStager(UploadStager):
    upload_stager_config: ElasticsearchUploadStagerConfig

    def conform_dict(self, data: dict) -> dict:
        resp = {
            "_index": self.upload_stager_config.index_name,
            "_id": str(uuid.uuid4()),
            "_source": {
                "element_id": data.pop("element_id", None),
                "embeddings": data.pop("embeddings", None),
                "text": data.pop("text", None),
                "type": data.pop("type", None),
            },
        }
        if "metadata" in data and isinstance(data["metadata"], dict):
            resp["_source"]["metadata"] = flatten_dict(data["metadata"], separator="-")
        return resp

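    # A sketch of the conversion above (illustrative element): an element like
    #   {"element_id": "e1", "text": "Hello", "type": "NarrativeText",
    #    "metadata": {"page": {"number": 1}}}
    # becomes a bulk-indexable action:
    #   {"_index": "<index_name>", "_id": "<uuid4>",
    #    "_source": {"element_id": "e1", "embeddings": None, "text": "Hello",
    #                "type": "NarrativeText", "metadata": {"page-number": 1}}}
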
    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        with open(elements_filepath) as elements_file:
            elements_contents = json.load(elements_file)
        conformed_elements = [self.conform_dict(data=element) for element in elements_contents]
        output_path = Path(output_dir) / Path(f"{output_filename}.json")
        with open(output_path, "w") as output_file:
            json.dump(conformed_elements, output_file)
        return output_path


@dataclass
class ElasticsearchUploaderConfig(UploaderConfig):
    index_name: str
    batch_size_bytes: int = 15_000_000
    thread_count: int = 4


@dataclass
class ElasticsearchUploader(Uploader):
    upload_config: ElasticsearchUploaderConfig
    connection_config: ElasticsearchConnectionConfig

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        elements_dict = []
        for content in contents:
            with open(content.path) as elements_file:
                elements = json.load(elements_file)
                elements_dict.extend(elements)
        logger.info(
            "writing document batches to destination"
            f" index named {self.upload_config.index_name}"
            f" at {self.connection_config.hosts}"
            f" with batch size (in bytes) {self.upload_config.batch_size_bytes}"
            f" and {self.upload_config.thread_count} threads"
        )
        from elasticsearch.helpers import parallel_bulk

        for batch in generator_batching_wbytes(
            elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
        ):
            for success, info in parallel_bulk(
                self.connection_config.get_client(),
                batch,
                thread_count=self.upload_config.thread_count,
            ):
                if not success:
                    logger.error(
                        f"upload failed for a batch in elasticsearch destination connector: {info}"
                    )


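# A minimal destination-side sketch (illustrative names and settings; assumes
# the `connection_config` from the earlier example and staged JSON produced by
# ElasticsearchUploadStager.run above):
#
#   uploader = ElasticsearchUploader(
#       connection_config=connection_config,
#       upload_config=ElasticsearchUploaderConfig(index_name="elements"),
#   )
#   uploader.run(contents=upload_contents)  # list[UploadContent] for the staged files

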
add_source_entry(
    source_type=CONNECTOR_TYPE,
    entry=SourceRegistryEntry(
        connection_config=ElasticsearchConnectionConfig,
        indexer=ElasticsearchIndexer,
        indexer_config=ElasticsearchIndexerConfig,
        downloader=ElasticsearchDownloader,
        downloader_config=ElasticsearchDownloaderConfig,
    ),
)

add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=ElasticsearchConnectionConfig,
        upload_stager_config=ElasticsearchUploadStagerConfig,
        upload_stager=ElasticsearchUploadStager,
        uploader_config=ElasticsearchUploaderConfig,
        uploader=ElasticsearchUploader,
    ),
)
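
# Importing this module registers "elasticsearch" as both a source and a
# destination with the v2 connector registry, so the ingest pipeline can look
# up the indexer, downloader, stager, and uploader classes above by connector
# type.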