From 0fe0f15f30844eb9525d4548e006bbefa1036727 Mon Sep 17 00:00:00 2001 From: Roman Isecke <136338424+rbiseck3@users.noreply.github.com> Date: Thu, 6 Jun 2024 19:18:55 -0400 Subject: [PATCH] feat: migrate weaviate connector to new framework (#3160) ### Description Add weaviate output connector to those supported in the new v2 ingest framework. Some fixes were needed to the upoad stager step as this was the first connector moved over that leverages this part of the pipeline. --- CHANGELOG.md | 2 +- unstructured/__version__.py | 2 +- unstructured/ingest/v2/cli/base/cmd.py | 4 +- unstructured/ingest/v2/cli/cmds/__init__.py | 2 + unstructured/ingest/v2/cli/cmds/weaviate.py | 100 ++++++++ unstructured/ingest/v2/cli/interfaces.py | 1 - .../ingest/v2/interfaces/connector.py | 8 +- .../ingest/v2/interfaces/downloader.py | 4 +- .../ingest/v2/interfaces/upload_stager.py | 26 +- unstructured/ingest/v2/interfaces/uploader.py | 4 +- .../ingest/v2/pipeline/steps/stage.py | 28 ++- .../ingest/v2/processes/connectors/local.py | 25 +- .../v2/processes/connectors/weaviate.py | 236 ++++++++++++++++++ 13 files changed, 420 insertions(+), 22 deletions(-) create mode 100644 unstructured/ingest/v2/cli/cmds/weaviate.py create mode 100644 unstructured/ingest/v2/processes/connectors/weaviate.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ffbd87343..ceec72e0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## 0.14.5-dev6 +## 0.14.5-dev7 ### Enhancements diff --git a/unstructured/__version__.py b/unstructured/__version__.py index 40be1c9bb..119f994e3 100644 --- a/unstructured/__version__.py +++ b/unstructured/__version__.py @@ -1 +1 @@ -__version__ = "0.14.5-dev6" # pragma: no cover +__version__ = "0.14.5-dev7" # pragma: no cover diff --git a/unstructured/ingest/v2/cli/base/cmd.py b/unstructured/ingest/v2/cli/base/cmd.py index 76badac6c..d2212e475 100644 --- a/unstructured/ingest/v2/cli/base/cmd.py +++ b/unstructured/ingest/v2/cli/base/cmd.py @@ -74,7 +74,7 @@ class BaseCmd(ABC): f"setting destination on pipeline {dest} with options: {destination_options}" ) if uploader_stager := self.get_upload_stager(dest=dest, options=destination_options): - pipeline_kwargs["upload_stager"] = uploader_stager + pipeline_kwargs["stager"] = uploader_stager pipeline_kwargs["uploader"] = self.get_uploader(dest=dest, options=destination_options) else: # Default to local uploader @@ -148,7 +148,7 @@ class BaseCmd(ABC): dest_entry = destination_registry[dest] upload_stager_kwargs: dict[str, Any] = {} if upload_stager_config_cls := dest_entry.upload_stager_config: - upload_stager_kwargs["config"] = extract_config( + upload_stager_kwargs["upload_stager_config"] = extract_config( flat_data=options, config=upload_stager_config_cls ) if upload_stager_cls := dest_entry.upload_stager: diff --git a/unstructured/ingest/v2/cli/cmds/__init__.py b/unstructured/ingest/v2/cli/cmds/__init__.py index 6ce3ece14..93711190b 100644 --- a/unstructured/ingest/v2/cli/cmds/__init__.py +++ b/unstructured/ingest/v2/cli/cmds/__init__.py @@ -9,6 +9,7 @@ from .fsspec.gcs import gcs_dest_cmd, gcs_src_cmd from .fsspec.s3 import s3_dest_cmd, s3_src_cmd from .fsspec.sftp import sftp_dest_cmd, sftp_src_cmd from .local import local_dest_cmd, local_src_cmd +from .weaviate import weaviate_dest_cmd src_cmds = [ azure_src_cmd, @@ -37,6 +38,7 @@ dest_cmds = [ local_dest_cmd, s3_dest_cmd, sftp_dest_cmd, + weaviate_dest_cmd, ] duplicate_dest_names = [ diff --git a/unstructured/ingest/v2/cli/cmds/weaviate.py b/unstructured/ingest/v2/cli/cmds/weaviate.py new file mode 100644 index 000000000..aaa051d05 --- /dev/null +++ b/unstructured/ingest/v2/cli/cmds/weaviate.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass + +import click + +from unstructured.ingest.v2.cli.base import DestCmd +from unstructured.ingest.v2.cli.interfaces import CliConfig +from unstructured.ingest.v2.cli.utils import DelimitedString +from unstructured.ingest.v2.processes.connectors.weaviate import CONNECTOR_TYPE + + +@dataclass +class WeaviateCliConnectionConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + options = [ + click.Option( + ["--host-url"], + required=True, + help="Weaviate instance url", + ), + click.Option( + ["--class-name"], + default=None, + type=str, + help="Name of the class to push the records into, e.g: Pdf-elements", + ), + click.Option( + ["--access-token"], default=None, type=str, help="Used to create the bearer token." + ), + click.Option( + ["--refresh-token"], + default=None, + type=str, + help="Will tie this value to the bearer token. If not provided, " + "the authentication will expire once the lifetime of the access token is up.", + ), + click.Option( + ["--api-key"], + default=None, + type=str, + ), + click.Option( + ["--client-secret"], + default=None, + type=str, + ), + click.Option( + ["--scope"], + default=None, + type=DelimitedString(), + ), + click.Option( + ["--username"], + default=None, + type=str, + ), + click.Option( + ["--password"], + default=None, + type=str, + ), + click.Option( + ["--anonymous"], + is_flag=True, + default=False, + type=bool, + help="if set, all auth values will be ignored", + ), + ] + return options + + +@dataclass +class WeaviateCliUploaderConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + options = [ + click.Option( + ["--batch-size"], + default=100, + type=int, + help="Number of records per batch", + ) + ] + return options + + +@dataclass +class WeaviateCliUploadStagerConfig(CliConfig): + @staticmethod + def get_cli_options() -> list[click.Option]: + return [] + + +weaviate_dest_cmd = DestCmd( + cmd_name=CONNECTOR_TYPE, + connection_config=WeaviateCliConnectionConfig, + uploader_config=WeaviateCliUploaderConfig, + upload_stager_config=WeaviateCliUploadStagerConfig, +) diff --git a/unstructured/ingest/v2/cli/interfaces.py b/unstructured/ingest/v2/cli/interfaces.py index 559590e11..2a8a0e18b 100644 --- a/unstructured/ingest/v2/cli/interfaces.py +++ b/unstructured/ingest/v2/cli/interfaces.py @@ -19,7 +19,6 @@ class CliConfig(ABC): existing_opts = [] for param in cmd.params: existing_opts.extend(param.opts) - for param in params: for opt in param.opts: if opt in existing_opts: diff --git a/unstructured/ingest/v2/interfaces/connector.py b/unstructured/ingest/v2/interfaces/connector.py index f71f0ca2a..dc700fc94 100644 --- a/unstructured/ingest/v2/interfaces/connector.py +++ b/unstructured/ingest/v2/interfaces/connector.py @@ -1,8 +1,8 @@ from abc import ABC from dataclasses import dataclass -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar -from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field +from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin @dataclass @@ -16,7 +16,7 @@ AccessConfigT = TypeVar("AccessConfigT", bound=AccessConfig) @dataclass class ConnectionConfig(EnhancedDataClassJsonMixin): - access_config: Optional[AccessConfigT] = enhanced_field(sensitive=True, default=None) + access_config: AccessConfigT def get_access_config(self) -> dict[str, Any]: if not self.access_config: @@ -29,4 +29,4 @@ ConnectionConfigT = TypeVar("ConnectionConfigT", bound=ConnectionConfig) @dataclass class BaseConnector(ABC): - connection_config: Optional[ConnectionConfigT] = None + connection_config: ConnectionConfigT diff --git a/unstructured/ingest/v2/interfaces/downloader.py b/unstructured/ingest/v2/interfaces/downloader.py index aee4bc47e..a2c1ce805 100644 --- a/unstructured/ingest/v2/interfaces/downloader.py +++ b/unstructured/ingest/v2/interfaces/downloader.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, TypeVar @@ -19,7 +19,7 @@ DownloaderConfigT = TypeVar("DownloaderConfigT", bound=DownloaderConfig) class Downloader(BaseProcess, BaseConnector, ABC): connector_type: str - download_config: Optional[DownloaderConfigT] = field(default_factory=DownloaderConfig) + download_config: DownloaderConfigT @property def download_dir(self) -> Path: diff --git a/unstructured/ingest/v2/interfaces/upload_stager.py b/unstructured/ingest/v2/interfaces/upload_stager.py index e89ba331d..39e28355a 100644 --- a/unstructured/ingest/v2/interfaces/upload_stager.py +++ b/unstructured/ingest/v2/interfaces/upload_stager.py @@ -21,8 +21,28 @@ class UploadStager(BaseProcess, ABC): upload_stager_config: Optional[UploadStagerConfigT] = None @abstractmethod - def run(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path: + def run( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any + ) -> Path: pass - async def run_async(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path: - return self.run(elements_filepath=elements_filepath, file_data=file_data, **kwargs) + async def run_async( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any + ) -> Path: + return self.run( + elements_filepath=elements_filepath, + output_dir=output_dir, + output_filename=output_filename, + file_data=file_data, + **kwargs + ) diff --git a/unstructured/ingest/v2/interfaces/uploader.py b/unstructured/ingest/v2/interfaces/uploader.py index 03763e299..520628e5a 100644 --- a/unstructured/ingest/v2/interfaces/uploader.py +++ b/unstructured/ingest/v2/interfaces/uploader.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from pathlib import Path from typing import Any, TypeVar @@ -25,7 +25,7 @@ class UploadContent: @dataclass class Uploader(BaseProcess, BaseConnector, ABC): - upload_config: UploaderConfigT = field(default_factory=UploaderConfig) + upload_config: UploaderConfigT def is_async(self) -> bool: return False diff --git a/unstructured/ingest/v2/pipeline/steps/stage.py b/unstructured/ingest/v2/pipeline/steps/stage.py index e7a3644de..59bbe90c1 100644 --- a/unstructured/ingest/v2/pipeline/steps/stage.py +++ b/unstructured/ingest/v2/pipeline/steps/stage.py @@ -1,6 +1,8 @@ +import hashlib +import json from dataclasses import dataclass from pathlib import Path -from typing import TypedDict +from typing import Optional, TypedDict from unstructured.ingest.v2.interfaces.file_data import FileData from unstructured.ingest.v2.interfaces.upload_stager import UploadStager @@ -30,12 +32,16 @@ class UploadStageStep(PipelineStep): if self.process.upload_stager_config else None ) + self.cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Created {self.identifier} with configs: {config}") def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse: path = Path(path) staged_output_path = self.process.run( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path)) @@ -44,10 +50,24 @@ class UploadStageStep(PipelineStep): if semaphore := self.context.semaphore: async with semaphore: staged_output_path = await self.process.run_async( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) else: staged_output_path = await self.process.run_async( - elements_filepath=path, file_data=FileData.from_file(path=file_data_path) + elements_filepath=path, + file_data=FileData.from_file(path=file_data_path), + output_dir=self.cache_dir, + output_filename=self.get_hash(extras=[path.name]), ) return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path)) + + def get_hash(self, extras: Optional[list[str]]) -> str: + hashable_string = json.dumps( + self.process.upload_stager_config.to_dict(), sort_keys=True, ensure_ascii=True + ) + if extras: + hashable_string += "".join(extras) + return hashlib.sha256(hashable_string.encode()).hexdigest()[:12] diff --git a/unstructured/ingest/v2/processes/connectors/local.py b/unstructured/ingest/v2/processes/connectors/local.py index 00e7a4ab8..5cfeae7ef 100644 --- a/unstructured/ingest/v2/processes/connectors/local.py +++ b/unstructured/ingest/v2/processes/connectors/local.py @@ -8,6 +8,8 @@ from typing import Any, Generator, Optional from unstructured.documents.elements import DataSourceMetadata from unstructured.ingest.v2.interfaces import ( + AccessConfig, + ConnectionConfig, Downloader, DownloaderConfig, FileData, @@ -29,6 +31,16 @@ from unstructured.ingest.v2.processes.connector_registry import ( CONNECTOR_TYPE = "local" +@dataclass +class LocalAccessConfig(AccessConfig): + pass + + +@dataclass +class LocalConnectionConfig(ConnectionConfig): + access_config: LocalAccessConfig = field(default_factory=lambda: LocalAccessConfig()) + + @dataclass class LocalIndexerConfig(IndexerConfig): input_path: str @@ -43,6 +55,9 @@ class LocalIndexerConfig(IndexerConfig): @dataclass class LocalIndexer(Indexer): index_config: LocalIndexerConfig + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) connector_type: str = CONNECTOR_TYPE def list_files(self) -> list[Path]: @@ -115,7 +130,10 @@ class LocalDownloaderConfig(DownloaderConfig): @dataclass class LocalDownloader(Downloader): connector_type: str = CONNECTOR_TYPE - download_config: Optional[LocalDownloaderConfig] = None + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) + download_config: LocalDownloaderConfig = field(default_factory=lambda: LocalDownloaderConfig()) def get_download_path(self, file_data: FileData) -> Path: return Path(file_data.source_identifiers.fullpath) @@ -139,7 +157,10 @@ class LocalUploaderConfig(UploaderConfig): @dataclass class LocalUploader(Uploader): - upload_config: LocalUploaderConfig = field(default_factory=LocalUploaderConfig) + upload_config: LocalUploaderConfig = field(default_factory=lambda: LocalUploaderConfig()) + connection_config: LocalConnectionConfig = field( + default_factory=lambda: LocalConnectionConfig() + ) def is_async(self) -> bool: return False diff --git a/unstructured/ingest/v2/processes/connectors/weaviate.py b/unstructured/ingest/v2/processes/connectors/weaviate.py new file mode 100644 index 000000000..c273df4ef --- /dev/null +++ b/unstructured/ingest/v2/processes/connectors/weaviate.py @@ -0,0 +1,236 @@ +import json +from dataclasses import dataclass, field +from datetime import date, datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +from dateutil import parser + +from unstructured.ingest.enhanced_dataclass import enhanced_field +from unstructured.ingest.v2.interfaces import ( + AccessConfig, + ConnectionConfig, + FileData, + UploadContent, + Uploader, + UploaderConfig, + UploadStager, + UploadStagerConfig, +) +from unstructured.ingest.v2.logger import logger +from unstructured.ingest.v2.processes.connector_registry import ( + DestinationRegistryEntry, + add_destination_entry, +) + +if TYPE_CHECKING: + from weaviate import Client + +CONNECTOR_TYPE = "weaviate" + + +@dataclass +class WeaviateAccessConfig(AccessConfig): + access_token: Optional[str] + api_key: Optional[str] + client_secret: Optional[str] + password: Optional[str] + + +@dataclass +class WeaviateConnectionConfig(ConnectionConfig): + host_url: str + class_name: str + access_config: WeaviateAccessConfig = enhanced_field(sensitive=True) + username: Optional[str] = None + anonymous: bool = False + scope: Optional[list[str]] = None + refresh_token: Optional[str] = None + connector_type: str = CONNECTOR_TYPE + + +@dataclass +class WeaviateUploadStagerConfig(UploadStagerConfig): + pass + + +@dataclass +class WeaviateUploadStager(UploadStager): + upload_stager_config: WeaviateUploadStagerConfig = field( + default_factory=lambda: WeaviateUploadStagerConfig() + ) + + @staticmethod + def parse_date_string(date_string: str) -> date: + try: + timestamp = float(date_string) + return datetime.fromtimestamp(timestamp) + except Exception as e: + logger.debug(f"date {date_string} string not a timestamp: {e}") + return parser.parse(date_string) + + @classmethod + def conform_dict(cls, data: dict) -> None: + """ + Updates the element dictionary to conform to the Weaviate schema + """ + + # Dict as string formatting + if record_locator := data.get("metadata", {}).get("data_source", {}).get("record_locator"): + # Explicit casting otherwise fails schema type checking + data["metadata"]["data_source"]["record_locator"] = str(json.dumps(record_locator)) + + # Array of items as string formatting + if points := data.get("metadata", {}).get("coordinates", {}).get("points"): + data["metadata"]["coordinates"]["points"] = str(json.dumps(points)) + + if links := data.get("metadata", {}).get("links", {}): + data["metadata"]["links"] = str(json.dumps(links)) + + if permissions_data := ( + data.get("metadata", {}).get("data_source", {}).get("permissions_data") + ): + data["metadata"]["data_source"]["permissions_data"] = json.dumps(permissions_data) + + # Datetime formatting + if date_created := data.get("metadata", {}).get("data_source", {}).get("date_created"): + data["metadata"]["data_source"]["date_created"] = cls.parse_date_string( + date_created + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if date_modified := data.get("metadata", {}).get("data_source", {}).get("date_modified"): + data["metadata"]["data_source"]["date_modified"] = cls.parse_date_string( + date_modified + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if date_processed := data.get("metadata", {}).get("data_source", {}).get("date_processed"): + data["metadata"]["data_source"]["date_processed"] = cls.parse_date_string( + date_processed + ).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + if last_modified := data.get("metadata", {}).get("last_modified"): + data["metadata"]["last_modified"] = cls.parse_date_string(last_modified).strftime( + "%Y-%m-%dT%H:%M:%S.%fZ", + ) + + # String casting + if version := data.get("metadata", {}).get("data_source", {}).get("version"): + data["metadata"]["data_source"]["version"] = str(version) + + if page_number := data.get("metadata", {}).get("page_number"): + data["metadata"]["page_number"] = str(page_number) + + if regex_metadata := data.get("metadata", {}).get("regex_metadata"): + data["metadata"]["regex_metadata"] = str(json.dumps(regex_metadata)) + + def run( + self, + elements_filepath: Path, + file_data: FileData, + output_dir: Path, + output_filename: str, + **kwargs: Any, + ) -> Path: + with open(elements_filepath) as elements_file: + elements_contents = json.load(elements_file) + for element in elements_contents: + self.conform_dict(data=element) + output_path = Path(output_dir) / Path(f"{output_filename}.json") + with open(output_path, "w") as output_file: + json.dump(elements_contents, output_file) + return output_path + + +@dataclass +class WeaviateUploaderConfig(UploaderConfig): + batch_size: int = 100 + + +@dataclass +class WeaviateUploader(Uploader): + upload_config: WeaviateUploaderConfig + connection_config: WeaviateConnectionConfig + client: Optional["Client"] = field(init=False) + + def __post_init__(self): + from weaviate import Client + + auth = self._resolve_auth_method() + self.client = Client(url=self.connection_config.host_url, auth_client_secret=auth) + + def is_async(self) -> bool: + return True + + def _resolve_auth_method(self): + access_configs = self.connection_config.access_config + connection_config = self.connection_config + if connection_config.anonymous: + return None + + if access_configs.access_token: + from weaviate.auth import AuthBearerToken + + return AuthBearerToken( + access_token=access_configs.access_token, + refresh_token=connection_config.refresh_token, + ) + elif access_configs.api_key: + from weaviate.auth import AuthApiKey + + return AuthApiKey(api_key=access_configs.api_key) + elif access_configs.client_secret: + from weaviate.auth import AuthClientCredentials + + return AuthClientCredentials( + client_secret=access_configs.client_secret, scope=connection_config.scope + ) + elif connection_config.username and access_configs.password: + from weaviate.auth import AuthClientPassword + + return AuthClientPassword( + username=connection_config.username, + password=access_configs.password, + scope=connection_config.scope, + ) + return None + + def run(self, contents: list[UploadContent], **kwargs: Any) -> None: + raise NotImplementedError + + async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None: + with open(path) as elements_file: + elements_dict = json.load(elements_file) + + logger.info( + f"writing {len(elements_dict)} objects to destination " + f"class {self.connection_config.class_name} " + f"at {self.connection_config.host_url}", + ) + + self.client.batch.configure(batch_size=self.upload_config.batch_size) + with self.client.batch as b: + for e in elements_dict: + vector = e.pop("embeddings", None) + b.add_data_object( + e, + self.connection_config.class_name, + vector=vector, + ) + + +add_destination_entry( + destination_type=CONNECTOR_TYPE, + entry=DestinationRegistryEntry( + connection_config=WeaviateConnectionConfig, + uploader=WeaviateUploader, + uploader_config=WeaviateUploaderConfig, + upload_stager=WeaviateUploadStager, + upload_stager_config=WeaviateUploadStagerConfig, + ), +)