"""Elasticsearch source and destination connectors for the v2 ingest framework.

Migrated from the v1 elasticsearch destination connector (PR #3224).
"""

import hashlib
import json
import sys
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from time import time
from typing import TYPE_CHECKING, Any, Generator, Optional

from unstructured.documents.elements import DataSourceMetadata
from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.error import SourceConnectionNetworkError
from unstructured.ingest.utils.data_prep import generator_batching_wbytes
from unstructured.ingest.v2.interfaces import (
    AccessConfig,
    ConnectionConfig,
    Downloader,
    DownloaderConfig,
    DownloadResponse,
    FileData,
    Indexer,
    IndexerConfig,
    UploadContent,
    Uploader,
    UploaderConfig,
    UploadStager,
    UploadStagerConfig,
    download_responses,
)
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.processes.connector_registry import (
    DestinationRegistryEntry,
    SourceRegistryEntry,
    add_destination_entry,
    add_source_entry,
)
from unstructured.staging.base import flatten_dict
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch as ElasticsearchClient

CONNECTOR_TYPE = "elasticsearch"
@dataclass
class ElasticsearchAccessConfig(AccessConfig):
    password: Optional[str] = None
    api_key: Optional[str] = enhanced_field(default=None, overload_name="es_api_key")
    bearer_auth: Optional[str] = None
    ssl_assert_fingerprint: Optional[str] = None


@dataclass
class ElasticsearchConnectionConfig(ConnectionConfig):
    hosts: Optional[list[str]] = None
    username: Optional[str] = None
    cloud_id: Optional[str] = None
    api_key_id: Optional[str] = None
    ca_certs: Optional[str] = None
    access_config: ElasticsearchAccessConfig = enhanced_field(sensitive=True)

    def get_client_kwargs(self) -> dict:
        # Update auth related fields to conform to what the SDK expects based on the
        # supported methods:
        # https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html
        client_kwargs = {
            "hosts": self.hosts,
        }
        if self.ca_certs:
            client_kwargs["ca_certs"] = self.ca_certs
        if self.access_config.password and (
            self.cloud_id or self.ca_certs or self.access_config.ssl_assert_fingerprint
        ):
            client_kwargs["basic_auth"] = ("elastic", self.access_config.password)
        elif not self.cloud_id and self.username and self.access_config.password:
            client_kwargs["basic_auth"] = (self.username, self.access_config.password)
        elif self.access_config.api_key and self.api_key_id:
            client_kwargs["api_key"] = (self.api_key_id, self.access_config.api_key)
        return client_kwargs

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    def get_client(self) -> "ElasticsearchClient":
        from elasticsearch import Elasticsearch as ElasticsearchClient

        return ElasticsearchClient(**self.get_client_kwargs())
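# A minimal usage sketch (hypothetical host and credentials, not part of this module):
#
#     config = ElasticsearchConnectionConfig(
#         hosts=["http://localhost:9200"],
#         username="elastic",
#         access_config=ElasticsearchAccessConfig(password="changeme"),
#     )
#     client = config.get_client()
#
# With only username/password supplied, get_client_kwargs() resolves to
# basic_auth=("elastic", "changeme") via the second branch above.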
@dataclass
class ElasticsearchIndexerConfig(IndexerConfig):
    index_name: str
    batch_size: int = 100


@dataclass
class ElasticsearchIndexer(Indexer):
    connection_config: ElasticsearchConnectionConfig
    index_config: ElasticsearchIndexerConfig
    client: "ElasticsearchClient" = field(init=False)
    connector_type: str = CONNECTOR_TYPE

    def __post_init__(self):
        self.client = self.connection_config.get_client()

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    def _get_doc_ids(self) -> set[str]:
        """Fetches all document ids in an index."""
        from elasticsearch.helpers import scan

        # stored_fields: [] limits each hit to its metadata, so only the _id is fetched
        scan_query: dict = {"stored_fields": [], "query": {"match_all": {}}}
        hits = scan(
            self.client,
            query=scan_query,
            scroll="1m",
            index=self.index_config.index_name,
        )
        return {hit["_id"] for hit in hits}
    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
        all_ids = self._get_doc_ids()
        ids = list(all_ids)
        batch_size = self.index_config.batch_size
        id_batches: list[frozenset[str]] = [
            frozenset(ids[i : i + batch_size]) for i in range(0, len(ids), batch_size)
        ]
        for batch in id_batches:
            # Shift by sys.maxsize + 1 so the hash is always positive before it is
            # used as the batch identifier
            identifier = str(hash(batch) + sys.maxsize + 1)
            yield FileData(
                identifier=identifier,
                connector_type=CONNECTOR_TYPE,
                metadata=DataSourceMetadata(
                    url=f"{self.connection_config.hosts[0]}/{self.index_config.index_name}",
                    date_processed=str(time()),
                ),
                additional_metadata={
                    "ids": list(batch),
                    "index_name": self.index_config.index_name,
                },
            )
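# Each emitted FileData describes one batch of document ids, e.g. (illustrative values):
#
#     FileData(
#         identifier="922337203685...",  # derived from the batch hash
#         connector_type="elasticsearch",
#         additional_metadata={"ids": ["doc-1", "doc-2"], "index_name": "my-index"},
#         ...
#     )
#
# The downloader later uses additional_metadata to fetch the full documents.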
@dataclass
class ElasticsearchDownloaderConfig(DownloaderConfig):
    fields: list[str] = field(default_factory=list)


@dataclass
class ElasticsearchDownloader(Downloader):
    connection_config: ElasticsearchConnectionConfig
    download_config: ElasticsearchDownloaderConfig
    connector_type: str = CONNECTOR_TYPE

    def is_async(self) -> bool:
        return True

    def get_identifier(self, index_name: str, record_id: str) -> str:
        f = f"{index_name}-{record_id}"
        if self.download_config.fields:
            f = "{}-{}".format(
                f,
                hashlib.sha256(",".join(self.download_config.fields).encode()).hexdigest()[:8],
            )
        return f
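    # For example (hypothetical values): index "my-index", document id "abc123", and
    # fields ["title", "body"] yield "my-index-abc123-" plus the first 8 hex chars of
    # sha256("title,body"), so downloads with different field selections don't collide.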
    def map_es_results(self, es_results: dict) -> str:
        doc_body = es_results["_source"]
        flattened_dict = flatten_dict(dictionary=doc_body)
        str_values = [str(value) for value in flattened_dict.values()]
        concatenated_values = "\n".join(str_values)
        return concatenated_values
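    # Sketch of the mapping (illustrative document): a hit whose
    # _source == {"title": "Foo", "meta": {"author": "Bar"}} is flattened and its
    # values joined with newlines, producing the text "Foo\nBar" that gets written
    # to the download file.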
    def generate_download_response(
        self, result: dict, index_name: str, file_data: FileData
    ) -> DownloadResponse:
        record_id = result["_id"]
        filename_id = self.get_identifier(index_name=index_name, record_id=record_id)
        filename = f"{filename_id}.txt"
        download_path = self.download_dir / Path(filename)
        logger.debug(
            f"Downloading results from index {index_name} and id {record_id} to {download_path}"
        )
        download_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(download_path, "w", encoding="utf8") as f:
                f.write(self.map_es_results(es_results=result))
        except Exception as e:
            logger.error(
                f"failed to download from index {index_name} "
                f"and id {record_id} to {download_path}: {e}",
                exc_info=True,
            )
            raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
        return DownloadResponse(
            file_data=FileData(
                identifier=filename_id,
                connector_type=CONNECTOR_TYPE,
                metadata=DataSourceMetadata(
                    version=str(result["_version"]) if "_version" in result else None,
                    date_processed=str(time()),
                    record_locator={
                        "hosts": self.connection_config.hosts,
                        "index_name": index_name,
                        "document_id": record_id,
                    },
                ),
            ),
            path=download_path,
        )
    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
        raise NotImplementedError()

    @requires_dependencies(["elasticsearch"], extras="elasticsearch")
    async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
        from elasticsearch import AsyncElasticsearch as AsyncElasticsearchClient
        from elasticsearch.helpers import async_scan

        index_name: str = file_data.additional_metadata["index_name"]
        ids: list[str] = file_data.additional_metadata["ids"]
        scan_query = {
            "_source": self.download_config.fields,
            "version": True,
            "query": {"ids": {"values": ids}},
        }
        # Avoid shadowing the imported download_responses type used in the annotation
        responses = []
        async with AsyncElasticsearchClient(**self.connection_config.get_client_kwargs()) as client:
            async for result in async_scan(
                client,
                query=scan_query,
                scroll="1m",
                index=index_name,
            ):
                responses.append(
                    self.generate_download_response(
                        result=result, index_name=index_name, file_data=file_data
                    )
                )
        return responses
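# Note: "version": True asks Elasticsearch to return each hit's _version, which
# generate_download_response records as the file's version; the "_source" entry
# limits the returned document body to the fields configured on the downloader.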
@dataclass
class ElasticsearchUploadStagerConfig(UploadStagerConfig):
    index_name: str


@dataclass
class ElasticsearchUploadStager(UploadStager):
    upload_stager_config: ElasticsearchUploadStagerConfig

    def conform_dict(self, data: dict) -> dict:
        resp = {
            "_index": self.upload_stager_config.index_name,
            "_id": str(uuid.uuid4()),
            "_source": {
                "element_id": data.pop("element_id", None),
                "embeddings": data.pop("embeddings", None),
                "text": data.pop("text", None),
                "type": data.pop("type", None),
            },
        }
        if "metadata" in data and isinstance(data["metadata"], dict):
            resp["_source"]["metadata"] = flatten_dict(data["metadata"], separator="-")
        return resp
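    # Each conformed element is a ready-made bulk action, e.g. (illustrative values):
    #
    #     {
    #         "_index": "my-index",
    #         "_id": "<uuid4>",
    #         "_source": {"element_id": "...", "embeddings": None, "text": "Hello",
    #                     "type": "NarrativeText", "metadata": {"filename": "doc.pdf"}},
    #     }
    #
    # which parallel_bulk in the uploader can consume directly.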
    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        with open(elements_filepath) as elements_file:
            elements_contents = json.load(elements_file)
        conformed_elements = [self.conform_dict(data=element) for element in elements_contents]
        output_path = Path(output_dir) / Path(f"{output_filename}.json")
        with open(output_path, "w") as output_file:
            json.dump(conformed_elements, output_file)
        return output_path
@dataclass
class ElasticsearchUploaderConfig(UploaderConfig):
    index_name: str
    batch_size_bytes: int = 15_000_000
    thread_count: int = 4
@dataclass
class ElasticsearchUploader(Uploader):
    upload_config: ElasticsearchUploaderConfig
    connection_config: ElasticsearchConnectionConfig

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        from elasticsearch.helpers import parallel_bulk

        elements_dict = []
        for content in contents:
            with open(content.path) as elements_file:
                elements = json.load(elements_file)
                elements_dict.extend(elements)
        logger.info(
            f"writing document batches to destination "
            f"index named {self.upload_config.index_name} "
            f"at {self.connection_config.hosts} "
            f"with batch size {self.upload_config.batch_size_bytes} bytes "
            f"using {self.upload_config.thread_count} threads"
        )
        for batch in generator_batching_wbytes(
            elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
        ):
            for success, info in parallel_bulk(
                self.connection_config.get_client(),
                batch,
                thread_count=self.upload_config.thread_count,
            ):
                if not success:
                    logger.error(
                        "upload failed for a batch in elasticsearch destination "
                        f"connector: {info}"
                    )
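# generator_batching_wbytes slices the staged elements into payloads capped at
# batch_size_bytes (15 MB by default) so each bulk request stays well under typical
# Elasticsearch request-size limits, while parallel_bulk fans each payload out
# across the configured number of threads.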
add_source_entry(
    source_type=CONNECTOR_TYPE,
    entry=SourceRegistryEntry(
        connection_config=ElasticsearchConnectionConfig,
        indexer=ElasticsearchIndexer,
        indexer_config=ElasticsearchIndexerConfig,
        downloader=ElasticsearchDownloader,
        downloader_config=ElasticsearchDownloaderConfig,
    ),
)
add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=ElasticsearchConnectionConfig,
        upload_stager_config=ElasticsearchUploadStagerConfig,
        upload_stager=ElasticsearchUploadStager,
        uploader_config=ElasticsearchUploaderConfig,
        uploader=ElasticsearchUploader,
    ),
)