mirror of
https://github.com/Unstructured-IO/unstructured.git
synced 2025-11-02 02:53:31 +00:00
feat/migrate astra db (#3294)
### Description
Moves the AstraDB destination connector over to the new v2 ingest framework.
This commit is contained in:
parent
3f581e6b7d
commit
a7a53f6fcb
@ -1,4 +1,4 @@
|
||||
## 0.14.9-dev1
|
||||
## 0.14.9-dev2
|
||||
|
||||
### Enhancements
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
__version__ = "0.14.9-dev1" # pragma: no cover
|
||||
__version__ = "0.14.9-dev2" # pragma: no cover
|
||||
|
||||
@ -2,6 +2,7 @@ from collections import Counter
|
||||
|
||||
import click
|
||||
|
||||
from .astra import astra_dest_cmd
|
||||
from .chroma import chroma_dest_cmd
|
||||
from .elasticsearch import elasticsearch_dest_cmd, elasticsearch_src_cmd
|
||||
from .fsspec.azure import azure_dest_cmd, azure_src_cmd
|
||||
@ -36,6 +37,7 @@ if duplicate_src_names:
|
||||
)
|
||||
|
||||
dest_cmds = [
|
||||
astra_dest_cmd,
|
||||
azure_dest_cmd,
|
||||
box_dest_cmd,
|
||||
chroma_dest_cmd,
|
||||
|
||||
85
unstructured/ingest/v2/cli/cmds/astra.py
Normal file
85
unstructured/ingest/v2/cli/cmds/astra.py
Normal file
@ -0,0 +1,85 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import click
|
||||
|
||||
from unstructured.ingest.v2.cli.base import DestCmd
|
||||
from unstructured.ingest.v2.cli.interfaces import CliConfig
|
||||
from unstructured.ingest.v2.cli.utils import Dict
|
||||
from unstructured.ingest.v2.processes.connectors.astra import CONNECTOR_TYPE
|
||||
|
||||
|
||||
@dataclass
class AstraCliConnectionConfig(CliConfig):
    """CLI flags for authenticating against an Astra DB instance."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        # Both credentials are mandatory and may also be supplied via
        # environment variables (shown in --help via show_envvar).
        credential_specs = [
            ("--token", "ASTRA_DB_TOKEN", "Astra DB Token with access to the database."),
            ("--api-endpoint", "ASTRA_DB_ENDPOINT", "The API endpoint for the Astra DB."),
        ]
        return [
            click.Option(
                [flag],
                required=True,
                type=str,
                help=help_text,
                envvar=env_name,
                show_envvar=True,
            )
            for flag, env_name, help_text in credential_specs
        ]
|
||||
|
||||
|
||||
@dataclass
class AstraCliUploaderConfig(CliConfig):
    """CLI flags controlling how staged documents are written to Astra DB."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        options = [
            click.Option(
                ["--collection-name"],
                required=False,
                type=str,
                help="The name of the Astra DB collection to write into. "
                "Note that the collection name must only include letters, "
                "numbers, and underscores.",
            ),
            click.Option(
                ["--embedding-dimension"],
                # Fix: the option had both required=True and a default; a
                # default guarantees a value, so required=True was dead and
                # misleading. Behavior is unchanged.
                required=False,
                default=384,
                type=int,
                help="The dimensionality of the embeddings",
            ),
            click.Option(
                ["--namespace"],
                required=False,
                default=None,
                type=str,
                help="The Astra DB namespace to write into.",
            ),
            click.Option(
                ["--requested-indexing-policy"],
                required=False,
                default=None,
                type=Dict(),
                # Fix: the two adjacent literals previously concatenated with
                # no separator, rendering "...collection.example: ...".
                help="The indexing policy to use for the collection. "
                'example: \'{"deny": ["metadata"]}\' ',
            ),
            click.Option(
                ["--batch-size"],
                default=20,
                type=int,
                help="Number of records per batch",
            ),
        ]
        return options
|
||||
|
||||
|
||||
# Destination command for the astra connector; cmd_name reuses the
# connector's registry key so CLI name and registry entry stay in sync.
# NOTE(review): no upload_stager_config is passed here even though the
# connector registers one — confirm DestCmd does not need it.
astra_dest_cmd = DestCmd(
    cmd_name=CONNECTOR_TYPE,
    connection_config=AstraCliConnectionConfig,
    uploader_config=AstraCliUploaderConfig,
)
|
||||
154
unstructured/ingest/v2/processes/connectors/astra.py
Normal file
154
unstructured/ingest/v2/processes/connectors/astra.py
Normal file
@ -0,0 +1,154 @@
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from unstructured import __name__ as integration_name
|
||||
from unstructured.__version__ import __version__ as integration_version
|
||||
from unstructured.ingest.enhanced_dataclass import enhanced_field
|
||||
from unstructured.ingest.utils.data_prep import chunk_generator
|
||||
from unstructured.ingest.v2.interfaces import (
|
||||
AccessConfig,
|
||||
ConnectionConfig,
|
||||
FileData,
|
||||
UploadContent,
|
||||
Uploader,
|
||||
UploaderConfig,
|
||||
UploadStager,
|
||||
UploadStagerConfig,
|
||||
)
|
||||
from unstructured.ingest.v2.logger import logger
|
||||
from unstructured.ingest.v2.processes.connector_registry import (
|
||||
DestinationRegistryEntry,
|
||||
add_destination_entry,
|
||||
)
|
||||
from unstructured.utils import requires_dependencies
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrapy.db import AstraDBCollection
|
||||
|
||||
CONNECTOR_TYPE = "astra"
|
||||
|
||||
|
||||
@dataclass
class AstraAccessConfig(AccessConfig):
    # Astra DB application token used to authenticate requests.
    token: str
    # Full API endpoint URL of the target Astra DB instance.
    api_endpoint: str
|
||||
|
||||
|
||||
@dataclass
class AstraConnectionConfig(ConnectionConfig):
    # Registry key for this connector ("astra").
    connection_type: str = CONNECTOR_TYPE
    # Credentials are flagged sensitive=True via enhanced_field so the
    # framework can treat them specially (e.g. on serialization).
    access_config: AstraAccessConfig = enhanced_field(sensitive=True)
|
||||
|
||||
|
||||
@dataclass
class AstraUploadStagerConfig(UploadStagerConfig):
    # No stager-specific settings; present so the registry entry below can
    # reference a concrete config type.
    pass
|
||||
|
||||
|
||||
@dataclass
class AstraUploadStager(UploadStager):
    """Reshapes unstructured element dicts into Astra DB document format."""

    upload_stager_config: AstraUploadStagerConfig = field(
        default_factory=lambda: AstraUploadStagerConfig()
    )

    def conform_dict(self, element_dict: dict) -> dict:
        # Astra stores the embedding under "$vector" and the raw text under
        # "content"; whatever remains of the element becomes metadata.
        vector = element_dict.pop("embeddings", None)
        text = element_dict.pop("text", None)
        return {"$vector": vector, "content": text, "metadata": element_dict}

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        # Load the staged elements, conform each one, and persist the result
        # as JSON next to the other staged output.
        with open(elements_filepath) as elements_file:
            raw_elements = json.load(elements_file)
        staged = [self.conform_dict(element_dict=element) for element in raw_elements]
        destination = Path(output_dir) / f"{output_filename}.json"
        with open(destination, "w") as output_file:
            json.dump(staged, output_file)
        return destination
|
||||
|
||||
|
||||
@dataclass
class AstraUploaderConfig(UploaderConfig):
    # Target collection; per the CLI help, the name must contain only
    # letters, numbers, and underscores.
    collection_name: str
    # Dimensionality of the "$vector" embeddings in the collection.
    embedding_dimension: int
    # Optional Astra DB namespace to write into.
    namespace: Optional[str] = None
    # Optional indexing policy forwarded as collection options,
    # e.g. {"deny": ["metadata"]}.
    requested_indexing_policy: Optional[dict[str, Any]] = None
    # Number of documents sent per insert_many call.
    batch_size: int = 20
|
||||
|
||||
|
||||
@dataclass
class AstraUploader(Uploader):
    """Writes staged documents into an Astra DB collection in batches."""

    connection_config: AstraConnectionConfig
    upload_config: AstraUploaderConfig

    @requires_dependencies(["astrapy"], extras="astra")
    def get_collection(self) -> "AstraDBCollection":
        from astrapy.db import AstraDB

        access = self.connection_config.access_config
        upload = self.upload_config

        # Only forward an options payload when an indexing policy was supplied.
        collection_options = None
        if upload.requested_indexing_policy:
            collection_options = {"indexing": upload.requested_indexing_policy}

        # Build the Astra DB client; caller_name/version are passed for
        # AstraDB-side tracking of this integration.
        client = AstraDB(
            api_endpoint=access.api_endpoint,
            token=access.token,
            namespace=upload.namespace,
            caller_name=integration_name,
            caller_version=integration_version,
        )

        # Create and connect to the collection.
        return client.create_collection(
            collection_name=upload.collection_name,
            dimension=upload.embedding_dimension,
            options=collection_options,
        )

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        # Flatten every staged file's elements into one list of documents.
        elements_dict: list = []
        for content in contents:
            with open(content.path) as elements_file:
                elements_dict.extend(json.load(elements_file))

        logger.info(
            f"writing {len(elements_dict)} objects to destination "
            f"collection {self.upload_config.collection_name}"
        )

        collection = self.get_collection()
        # Insert in fixed-size batches to bound request size.
        for chunk in chunk_generator(elements_dict, self.upload_config.batch_size):
            collection.insert_many(chunk)
|
||||
|
||||
|
||||
# Register the complete set of astra destination components (connection,
# staging, upload) under CONNECTOR_TYPE so the v2 ingest pipeline can look
# them up by name.
add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=AstraConnectionConfig,
        upload_stager_config=AstraUploadStagerConfig,
        upload_stager=AstraUploadStager,
        uploader_config=AstraUploaderConfig,
        uploader=AstraUploader,
    ),
)
|
||||
@ -31,9 +31,6 @@ from unstructured.utils import requires_dependencies
|
||||
if TYPE_CHECKING:
|
||||
from chromadb import Client
|
||||
|
||||
|
||||
import typing as t
|
||||
|
||||
CONNECTOR_TYPE = "chroma"
|
||||
|
||||
|
||||
@ -165,7 +162,7 @@ class ChromaUploader(Uploader):
|
||||
raise ValueError(f"chroma error: {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def prepare_chroma_list(chunk: t.Tuple[t.Dict[str, t.Any]]) -> t.Dict[str, t.List[t.Any]]:
|
||||
def prepare_chroma_list(chunk: tuple[dict[str, Any]]) -> dict[str, list[Any]]:
|
||||
"""Helper function to break a tuple of dicts into list of parallel lists for ChromaDb.
|
||||
({'id':1}, {'id':2}, {'id':3}) -> {'ids':[1,2,3]}"""
|
||||
chroma_dict = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user