"""CLI options and command registration for the Astra DB destination connector."""

from dataclasses import dataclass

import click

from unstructured.ingest.v2.cli.base import DestCmd
from unstructured.ingest.v2.cli.interfaces import CliConfig
from unstructured.ingest.v2.cli.utils import Dict
from unstructured.ingest.v2.processes.connectors.astra import CONNECTOR_TYPE


@dataclass
class AstraCliConnectionConfig(CliConfig):
    """Click options describing how to reach and authenticate against Astra DB."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        """Return the connection options.

        Both options are mandatory and may also be supplied through the
        environment variables named in ``envvar``.
        """
        return [
            click.Option(
                ["--token"],
                required=True,
                type=str,
                help="Astra DB Token with access to the database.",
                envvar="ASTRA_DB_TOKEN",
                show_envvar=True,
            ),
            click.Option(
                ["--api-endpoint"],
                required=True,
                type=str,
                help="The API endpoint for the Astra DB.",
                envvar="ASTRA_DB_ENDPOINT",
                show_envvar=True,
            ),
        ]


@dataclass
class AstraCliUploaderConfig(CliConfig):
    """Click options controlling which collection is written and how."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        """Return the uploader options for the Astra DB destination command."""
        return [
            click.Option(
                ["--collection-name"],
                # The uploader config declares `collection_name` without a
                # default, so a missing value can never produce a working
                # upload; reject it at argument-parsing time instead.
                required=True,
                type=str,
                help="The name of the Astra DB collection to write into. "
                "Note that the collection name must only include letters, "
                "numbers, and underscores.",
            ),
            click.Option(
                ["--embedding-dimension"],
                # A default always satisfies click's "required" check, so the
                # option is effectively optional; do not advertise it as required.
                default=384,
                type=int,
                help="The dimensionality of the embeddings",
            ),
            click.Option(
                ["--namespace"],
                required=False,
                default=None,
                type=str,
                help="The Astra DB namespace to write into.",
            ),
            click.Option(
                ["--requested-indexing-policy"],
                required=False,
                default=None,
                type=Dict(),
                # Trailing space after the sentence keeps the two adjacent
                # literals from running together in the rendered help.
                help="The indexing policy to use for the collection. "
                'example: \'{"deny": ["metadata"]}\' ',
            ),
            click.Option(
                ["--batch-size"],
                default=20,
                type=int,
                help="Number of records per batch",
            ),
        ]


# Destination command exposed to the ingest CLI under the connector's type name.
astra_dest_cmd = DestCmd(
    cmd_name=CONNECTOR_TYPE,
    connection_config=AstraCliConnectionConfig,
    uploader_config=AstraCliUploaderConfig,
)
"""Astra DB destination connector: stages element JSON and uploads it in batches."""

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

from unstructured import __name__ as integration_name
from unstructured.__version__ import __version__ as integration_version
from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.utils.data_prep import chunk_generator
from unstructured.ingest.v2.interfaces import (
    AccessConfig,
    ConnectionConfig,
    FileData,
    UploadContent,
    Uploader,
    UploaderConfig,
    UploadStager,
    UploadStagerConfig,
)
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.processes.connector_registry import (
    DestinationRegistryEntry,
    add_destination_entry,
)
from unstructured.utils import requires_dependencies

if TYPE_CHECKING:
    from astrapy.db import AstraDBCollection

CONNECTOR_TYPE = "astra"


@dataclass
class AstraAccessConfig(AccessConfig):
    """Credentials and endpoint used to open the Astra DB connection."""

    token: str
    api_endpoint: str


@dataclass
class AstraConnectionConfig(ConnectionConfig):
    """Connection settings; the access config is flagged as sensitive."""

    connection_type: str = CONNECTOR_TYPE
    access_config: AstraAccessConfig = enhanced_field(sensitive=True)


@dataclass
class AstraUploadStagerConfig(UploadStagerConfig):
    """Stager configuration; this connector defines no extra settings."""


@dataclass
class AstraUploadStager(UploadStager):
    """Rewrites an elements JSON file into the record layout uploaded to Astra DB."""

    upload_stager_config: AstraUploadStagerConfig = field(
        default_factory=AstraUploadStagerConfig
    )

    def conform_dict(self, element_dict: dict) -> dict:
        """Map one element dict onto the ``$vector`` / ``content`` / ``metadata`` layout.

        Args:
            element_dict: A single element as loaded from the elements JSON file.

        Returns:
            A new dict whose ``$vector`` is the element's ``embeddings``, whose
            ``content`` is its ``text`` (each ``None`` when absent) and whose
            ``metadata`` holds every remaining key.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        remaining = dict(element_dict)
        return {
            "$vector": remaining.pop("embeddings", None),
            "content": remaining.pop("text", None),
            "metadata": remaining,
        }

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        """Conform every element in ``elements_filepath`` and write the result.

        Args:
            elements_filepath: JSON file containing a list of element dicts.
            file_data: Unused by this stager.
            output_dir: Directory that receives the staged file.
            output_filename: Base name of the staged file; ``.json`` is appended.
            **kwargs: Ignored.

        Returns:
            Path of the staged JSON file.
        """
        # Explicit encoding: JSON interchange is UTF-8, and the platform
        # default is locale-dependent.
        with open(elements_filepath, encoding="utf-8") as elements_file:
            elements_contents = json.load(elements_file)
        conformed_elements = [
            self.conform_dict(element_dict=element) for element in elements_contents
        ]
        output_path = Path(output_dir) / Path(f"{output_filename}.json")
        with open(output_path, "w", encoding="utf-8") as output_file:
            json.dump(conformed_elements, output_file)
        return output_path


@dataclass
class AstraUploaderConfig(UploaderConfig):
    """Settings for the target collection and the upload batching."""

    collection_name: str
    embedding_dimension: int
    namespace: Optional[str] = None
    requested_indexing_policy: Optional[dict[str, Any]] = None
    batch_size: int = 20


@dataclass
class AstraUploader(Uploader):
    """Writes staged records into an Astra DB collection in fixed-size batches."""

    connection_config: AstraConnectionConfig
    upload_config: AstraUploaderConfig

    @requires_dependencies(["astrapy"], extras="astra")
    def get_collection(self) -> "AstraDBCollection":
        """Build the Astra DB client and return the configured collection.

        Returns:
            The object returned by ``AstraDB.create_collection`` for the
            configured collection name and embedding dimension.
        """
        from astrapy.db import AstraDB

        upload_config = self.upload_config
        access_config = self.connection_config.access_config

        # Only forward an indexing policy when the user actually requested one.
        indexing_policy = upload_config.requested_indexing_policy
        options = {"indexing": indexing_policy} if indexing_policy else None

        # caller_name/version for AstraDB tracking
        astra_db = AstraDB(
            api_endpoint=access_config.api_endpoint,
            token=access_config.token,
            namespace=upload_config.namespace,
            caller_name=integration_name,
            caller_version=integration_version,
        )

        return astra_db.create_collection(
            collection_name=upload_config.collection_name,
            dimension=upload_config.embedding_dimension,
            options=options,
        )

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        """Load every staged file in ``contents`` and insert its records.

        Args:
            contents: Upload items whose ``path`` points at a staged JSON list.
            **kwargs: Ignored.
        """
        records: list[dict[str, Any]] = []
        for content in contents:
            # Explicit encoding keeps reads independent of the platform locale.
            with open(content.path, encoding="utf-8") as elements_file:
                records.extend(json.load(elements_file))

        logger.info(
            f"writing {len(records)} objects to destination "
            f"collection {self.upload_config.collection_name}"
        )

        astra_batch_size = self.upload_config.batch_size
        collection = self.get_collection()

        for chunk in chunk_generator(records, astra_batch_size):
            collection.insert_many(chunk)


# Register the connector so it can be looked up by its destination type.
add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=AstraConnectionConfig,
        upload_stager_config=AstraUploadStagerConfig,
        upload_stager=AstraUploadStager,
        uploader_config=AstraUploaderConfig,
        uploader=AstraUploader,
    ),
)
a/unstructured/ingest/v2/processes/connectors/chroma.py b/unstructured/ingest/v2/processes/connectors/chroma.py index 295db82be..f8da8afdf 100644 --- a/unstructured/ingest/v2/processes/connectors/chroma.py +++ b/unstructured/ingest/v2/processes/connectors/chroma.py @@ -31,9 +31,6 @@ from unstructured.utils import requires_dependencies if TYPE_CHECKING: from chromadb import Client - -import typing as t - CONNECTOR_TYPE = "chroma" @@ -165,7 +162,7 @@ class ChromaUploader(Uploader): raise ValueError(f"chroma error: {e}") from e @staticmethod - def prepare_chroma_list(chunk: t.Tuple[t.Dict[str, t.Any]]) -> t.Dict[str, t.List[t.Any]]: + def prepare_chroma_list(chunk: tuple[dict[str, Any]]) -> dict[str, list[Any]]: """Helper function to break a tuple of dicts into list of parallel lists for ChromaDb. ({'id':1}, {'id':2}, {'id':3}) -> {'ids':[1,2,3]}""" chroma_dict = {}