feat/migrate astra db (#3294)

### Description
Move astradb destination connector over to the new v2 ingest framework
This commit is contained in:
Roman Isecke 2024-06-25 14:00:47 -04:00 committed by GitHub
parent 3f581e6b7d
commit a7a53f6fcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 244 additions and 6 deletions

View File

@ -1,4 +1,4 @@
## 0.14.9-dev1
## 0.14.9-dev2
### Enhancements

View File

@ -1 +1 @@
__version__ = "0.14.9-dev1" # pragma: no cover
__version__ = "0.14.9-dev2" # pragma: no cover

View File

@ -2,6 +2,7 @@ from collections import Counter
import click
from .astra import astra_dest_cmd
from .chroma import chroma_dest_cmd
from .elasticsearch import elasticsearch_dest_cmd, elasticsearch_src_cmd
from .fsspec.azure import azure_dest_cmd, azure_src_cmd
@ -36,6 +37,7 @@ if duplicate_src_names:
)
dest_cmds = [
astra_dest_cmd,
azure_dest_cmd,
box_dest_cmd,
chroma_dest_cmd,

View File

@ -0,0 +1,85 @@
from dataclasses import dataclass
import click
from unstructured.ingest.v2.cli.base import DestCmd
from unstructured.ingest.v2.cli.interfaces import CliConfig
from unstructured.ingest.v2.cli.utils import Dict
from unstructured.ingest.v2.processes.connectors.astra import CONNECTOR_TYPE
@dataclass
class AstraCliConnectionConfig(CliConfig):
    """CLI flags for authenticating against an Astra DB instance."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        # Both credentials are mandatory; each can also come from its
        # environment variable, which click surfaces in the generated help.
        token_option = click.Option(
            ["--token"],
            required=True,
            type=str,
            help="Astra DB Token with access to the database.",
            envvar="ASTRA_DB_TOKEN",
            show_envvar=True,
        )
        endpoint_option = click.Option(
            ["--api-endpoint"],
            required=True,
            type=str,
            help="The API endpoint for the Astra DB.",
            envvar="ASTRA_DB_ENDPOINT",
            show_envvar=True,
        )
        return [token_option, endpoint_option]
@dataclass
class AstraCliUploaderConfig(CliConfig):
    """CLI flags controlling how processed elements are written to Astra DB."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        options = [
            click.Option(
                ["--collection-name"],
                # The backing AstraUploaderConfig declares collection_name with
                # no default, so require it here to fail fast with a clear CLI
                # error instead of a later dataclass TypeError.
                required=True,
                type=str,
                help="The name of the Astra DB collection to write into. "
                "Note that the collection name must only include letters, "
                "numbers, and underscores.",
            ),
            click.Option(
                ["--embedding-dimension"],
                # A default makes click's required=True meaningless (the default
                # always satisfies it), so the flag is simply optional with 384.
                default=384,
                type=int,
                help="The dimensionality of the embeddings",
            ),
            click.Option(
                ["--namespace"],
                required=False,
                default=None,
                type=str,
                help="The Astra DB namespace to write into.",
            ),
            click.Option(
                ["--requested-indexing-policy"],
                required=False,
                default=None,
                type=Dict(),
                # Fixed: the original concatenation was missing a space,
                # rendering as "collection.example:".
                help="The indexing policy to use for the collection. "
                'example: \'{"deny": ["metadata"]}\'',
            ),
            click.Option(
                ["--batch-size"],
                default=20,
                type=int,
                help="Number of records per batch",
            ),
        ]
        return options
# Register the Astra destination with the shared CLI machinery; the command
# name mirrors the connector type string used by the registry.
astra_dest_cmd = DestCmd(
    connection_config=AstraCliConnectionConfig,
    uploader_config=AstraCliUploaderConfig,
    cmd_name=CONNECTOR_TYPE,
)

View File

@ -0,0 +1,154 @@
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from unstructured import __name__ as integration_name
from unstructured.__version__ import __version__ as integration_version
from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.utils.data_prep import chunk_generator
from unstructured.ingest.v2.interfaces import (
AccessConfig,
ConnectionConfig,
FileData,
UploadContent,
Uploader,
UploaderConfig,
UploadStager,
UploadStagerConfig,
)
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.processes.connector_registry import (
DestinationRegistryEntry,
add_destination_entry,
)
from unstructured.utils import requires_dependencies
if TYPE_CHECKING:
from astrapy.db import AstraDBCollection
CONNECTOR_TYPE = "astra"
@dataclass
class AstraAccessConfig(AccessConfig):
    """Credentials for connecting to an Astra DB instance."""

    # Astra DB application token; forwarded to astrapy's AstraDB client.
    token: str
    # API endpoint URL for the database; forwarded to astrapy's AstraDB client.
    api_endpoint: str
@dataclass
class AstraConnectionConfig(ConnectionConfig):
    """Connection settings for the Astra destination connector."""

    # Identifies this connector in the registry (see CONNECTOR_TYPE above).
    connection_type: str = CONNECTOR_TYPE
    # sensitive=True flags the credentials for special handling by the
    # enhanced-dataclass machinery (presumably redaction in logs/serialized
    # configs — confirm in enhanced_dataclass).
    access_config: AstraAccessConfig = enhanced_field(sensitive=True)
@dataclass
class AstraUploadStagerConfig(UploadStagerConfig):
    # No Astra-specific staging options; this concrete type appears to exist so
    # the registry entry below has a config class to reference.
    pass
@dataclass
class AstraUploadStager(UploadStager):
    """Reshapes partitioned elements into Astra DB document dicts.

    Each element becomes a document with its embedding under "$vector", its
    text under "content", and all remaining fields nested under "metadata".
    """

    upload_stager_config: AstraUploadStagerConfig = field(
        default_factory=lambda: AstraUploadStagerConfig()
    )

    def conform_dict(self, element_dict: dict) -> dict:
        # pop() mutates the incoming dict so whatever remains can be stored
        # wholesale as metadata without duplicating the text/embedding fields.
        vector = element_dict.pop("embeddings", None)
        text = element_dict.pop("text", None)
        return {"$vector": vector, "content": text, "metadata": element_dict}

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        """Read elements JSON, conform each element, and write the result.

        Returns the path of the staged JSON file in output_dir.
        """
        with open(elements_filepath) as elements_file:
            raw_elements = json.load(elements_file)
        staged = [self.conform_dict(element_dict=e) for e in raw_elements]
        destination = Path(output_dir) / Path(f"{output_filename}.json")
        with open(destination, "w") as staged_file:
            json.dump(staged, staged_file)
        return destination
@dataclass
class AstraUploaderConfig(UploaderConfig):
    """Settings for writing staged documents into an Astra DB collection."""

    # Target collection; per the CLI help, must contain only letters, numbers,
    # and underscores.
    collection_name: str
    # Dimensionality passed to create_collection for the "$vector" field.
    embedding_dimension: int
    # Passed through to astrapy's AstraDB client; None presumably selects the
    # default namespace — confirm against astrapy behavior.
    namespace: Optional[str] = None
    # When set, forwarded as {"indexing": ...} in the collection options.
    requested_indexing_policy: Optional[dict[str, Any]] = None
    # Number of documents per insert_many call.
    batch_size: int = 20
@dataclass
class AstraUploader(Uploader):
    """Uploads staged element documents into an Astra DB collection."""

    connection_config: AstraConnectionConfig
    upload_config: AstraUploaderConfig

    @requires_dependencies(["astrapy"], extras="astra")
    def get_collection(self) -> "AstraDBCollection":
        """Create/connect to the configured collection and return it."""
        from astrapy.db import AstraDB

        cfg = self.upload_config
        # Only pass collection options when an indexing policy was requested.
        collection_options = (
            {"indexing": cfg.requested_indexing_policy}
            if cfg.requested_indexing_policy
            else None
        )
        # caller_name/version let Astra attribute traffic to this integration.
        client = AstraDB(
            api_endpoint=self.connection_config.access_config.api_endpoint,
            token=self.connection_config.access_config.token,
            namespace=cfg.namespace,
            caller_name=integration_name,
            caller_version=integration_version,
        )
        return client.create_collection(
            collection_name=cfg.collection_name,
            dimension=cfg.embedding_dimension,
            options=collection_options,
        )

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        """Load every staged file and insert its documents in batches."""
        docs: list = []
        for content in contents:
            with open(content.path) as staged_file:
                docs.extend(json.load(staged_file))
        logger.info(
            f"writing {len(docs)} objects to destination "
            f"collection {self.upload_config.collection_name}"
        )
        collection = self.get_collection()
        for batch in chunk_generator(docs, self.upload_config.batch_size):
            collection.insert_many(batch)
# Register the Astra destination so the v2 pipeline can resolve the "astra"
# connector type to these config/stager/uploader classes at runtime.
add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=AstraConnectionConfig,
        upload_stager_config=AstraUploadStagerConfig,
        upload_stager=AstraUploadStager,
        uploader_config=AstraUploaderConfig,
        uploader=AstraUploader,
    ),
)

View File

@ -31,9 +31,6 @@ from unstructured.utils import requires_dependencies
if TYPE_CHECKING:
from chromadb import Client
import typing as t
CONNECTOR_TYPE = "chroma"
@ -165,7 +162,7 @@ class ChromaUploader(Uploader):
raise ValueError(f"chroma error: {e}") from e
@staticmethod
def prepare_chroma_list(chunk: t.Tuple[t.Dict[str, t.Any]]) -> t.Dict[str, t.List[t.Any]]:
def prepare_chroma_list(chunk: tuple[dict[str, Any]]) -> dict[str, list[Any]]:
"""Helper function to break a tuple of dicts into list of parallel lists for ChromaDb.
({'id':1}, {'id':2}, {'id':3}) -> {'ids':[1,2,3]}"""
chroma_dict = {}