feat: migrate weaviate connector to new framework (#3160)

### Description
Add weaviate output connector to those supported in the new v2 ingest
framework. Some fixes were needed to the upload stager step as this was
the first connector moved over that leverages this part of the pipeline.
This commit is contained in:
Roman Isecke 2024-06-06 19:18:55 -04:00 committed by GitHub
parent a883fc9df2
commit 0fe0f15f30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 420 additions and 22 deletions

View File

@ -1,4 +1,4 @@
## 0.14.5-dev6
## 0.14.5-dev7
### Enhancements

View File

@ -1 +1 @@
__version__ = "0.14.5-dev6" # pragma: no cover
__version__ = "0.14.5-dev7" # pragma: no cover

View File

@ -74,7 +74,7 @@ class BaseCmd(ABC):
f"setting destination on pipeline {dest} with options: {destination_options}"
)
if uploader_stager := self.get_upload_stager(dest=dest, options=destination_options):
pipeline_kwargs["upload_stager"] = uploader_stager
pipeline_kwargs["stager"] = uploader_stager
pipeline_kwargs["uploader"] = self.get_uploader(dest=dest, options=destination_options)
else:
# Default to local uploader
@ -148,7 +148,7 @@ class BaseCmd(ABC):
dest_entry = destination_registry[dest]
upload_stager_kwargs: dict[str, Any] = {}
if upload_stager_config_cls := dest_entry.upload_stager_config:
upload_stager_kwargs["config"] = extract_config(
upload_stager_kwargs["upload_stager_config"] = extract_config(
flat_data=options, config=upload_stager_config_cls
)
if upload_stager_cls := dest_entry.upload_stager:

View File

@ -9,6 +9,7 @@ from .fsspec.gcs import gcs_dest_cmd, gcs_src_cmd
from .fsspec.s3 import s3_dest_cmd, s3_src_cmd
from .fsspec.sftp import sftp_dest_cmd, sftp_src_cmd
from .local import local_dest_cmd, local_src_cmd
from .weaviate import weaviate_dest_cmd
src_cmds = [
azure_src_cmd,
@ -37,6 +38,7 @@ dest_cmds = [
local_dest_cmd,
s3_dest_cmd,
sftp_dest_cmd,
weaviate_dest_cmd,
]
duplicate_dest_names = [

View File

@ -0,0 +1,100 @@
from dataclasses import dataclass
import click
from unstructured.ingest.v2.cli.base import DestCmd
from unstructured.ingest.v2.cli.interfaces import CliConfig
from unstructured.ingest.v2.cli.utils import DelimitedString
from unstructured.ingest.v2.processes.connectors.weaviate import CONNECTOR_TYPE
@dataclass
class WeaviateCliConnectionConfig(CliConfig):
    """CLI options that map onto the weaviate connector's connection and
    access configs (host, target class, and the various auth credentials)."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        """Return the click options for configuring a weaviate connection.

        Exactly one auth mechanism is expected: bearer token, api key,
        client secret, username/password, or --anonymous.
        """
        options = [
            click.Option(
                ["--host-url"],
                required=True,
                help="Weaviate instance url",
            ),
            click.Option(
                ["--class-name"],
                # Required: the connector's WeaviateConnectionConfig.class_name
                # has no default, so omitting this flag would build an invalid config.
                required=True,
                type=str,
                help="Name of the class to push the records into, e.g: Pdf-elements",
            ),
            click.Option(
                ["--access-token"], default=None, type=str, help="Used to create the bearer token."
            ),
            click.Option(
                ["--refresh-token"],
                default=None,
                type=str,
                help="Will tie this value to the bearer token. If not provided, "
                "the authentication will expire once the lifetime of the access token is up.",
            ),
            click.Option(
                ["--api-key"],
                default=None,
                type=str,
                help="Api key for api-key based authentication.",
            ),
            click.Option(
                ["--client-secret"],
                default=None,
                type=str,
                help="Client secret for client-credentials based authentication.",
            ),
            click.Option(
                ["--scope"],
                default=None,
                type=DelimitedString(),
                help="Scope(s) passed along with client-secret or username/password auth.",
            ),
            click.Option(
                ["--username"],
                default=None,
                type=str,
                help="Username for username/password based authentication.",
            ),
            click.Option(
                ["--password"],
                default=None,
                type=str,
                help="Password for username/password based authentication.",
            ),
            click.Option(
                ["--anonymous"],
                is_flag=True,
                default=False,
                type=bool,
                help="if set, all auth values will be ignored",
            ),
        ]
        return options
@dataclass
class WeaviateCliUploaderConfig(CliConfig):
    """CLI flags controlling how records are written to Weaviate."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        """Expose the uploader's tunable settings as click options."""
        batch_size_option = click.Option(
            ["--batch-size"],
            default=100,
            type=int,
            help="Number of records per batch",
        )
        return [batch_size_option]
@dataclass
class WeaviateCliUploadStagerConfig(CliConfig):
    """The weaviate upload stager exposes no CLI-configurable settings."""

    @staticmethod
    def get_cli_options() -> list[click.Option]:
        """Return an empty option list; the stager needs no flags."""
        options: list[click.Option] = []
        return options
# Destination command wiring for the CLI: the command name is taken from the
# connector module's CONNECTOR_TYPE so CLI and connector registry stay in sync.
weaviate_dest_cmd = DestCmd(
    cmd_name=CONNECTOR_TYPE,
    connection_config=WeaviateCliConnectionConfig,
    uploader_config=WeaviateCliUploaderConfig,
    upload_stager_config=WeaviateCliUploadStagerConfig,
)

View File

@ -19,7 +19,6 @@ class CliConfig(ABC):
existing_opts = []
for param in cmd.params:
existing_opts.extend(param.opts)
for param in params:
for opt in param.opts:
if opt in existing_opts:

View File

@ -1,8 +1,8 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Optional, TypeVar
from typing import Any, TypeVar
from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field
from unstructured.ingest.enhanced_dataclass import EnhancedDataClassJsonMixin
@dataclass
@ -16,7 +16,7 @@ AccessConfigT = TypeVar("AccessConfigT", bound=AccessConfig)
@dataclass
class ConnectionConfig(EnhancedDataClassJsonMixin):
access_config: Optional[AccessConfigT] = enhanced_field(sensitive=True, default=None)
access_config: AccessConfigT
def get_access_config(self) -> dict[str, Any]:
if not self.access_config:
@ -29,4 +29,4 @@ ConnectionConfigT = TypeVar("ConnectionConfigT", bound=ConnectionConfig)
@dataclass
class BaseConnector(ABC):
connection_config: Optional[ConnectionConfigT] = None
connection_config: ConnectionConfigT

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, TypeVar
@ -19,7 +19,7 @@ DownloaderConfigT = TypeVar("DownloaderConfigT", bound=DownloaderConfig)
class Downloader(BaseProcess, BaseConnector, ABC):
connector_type: str
download_config: Optional[DownloaderConfigT] = field(default_factory=DownloaderConfig)
download_config: DownloaderConfigT
@property
def download_dir(self) -> Path:

View File

@ -21,8 +21,28 @@ class UploadStager(BaseProcess, ABC):
upload_stager_config: Optional[UploadStagerConfigT] = None
@abstractmethod
def run(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
def run(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
pass
async def run_async(self, elements_filepath: Path, file_data: FileData, **kwargs: Any) -> Path:
return self.run(elements_filepath=elements_filepath, file_data=file_data, **kwargs)
async def run_async(
self,
elements_filepath: Path,
file_data: FileData,
output_dir: Path,
output_filename: str,
**kwargs: Any
) -> Path:
return self.run(
elements_filepath=elements_filepath,
output_dir=output_dir,
output_filename=output_filename,
file_data=file_data,
**kwargs
)

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypeVar
@ -25,7 +25,7 @@ class UploadContent:
@dataclass
class Uploader(BaseProcess, BaseConnector, ABC):
upload_config: UploaderConfigT = field(default_factory=UploaderConfig)
upload_config: UploaderConfigT
def is_async(self) -> bool:
return False

View File

@ -1,6 +1,8 @@
import hashlib
import json
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
from typing import Optional, TypedDict
from unstructured.ingest.v2.interfaces.file_data import FileData
from unstructured.ingest.v2.interfaces.upload_stager import UploadStager
@ -30,12 +32,16 @@ class UploadStageStep(PipelineStep):
if self.process.upload_stager_config
else None
)
self.cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Created {self.identifier} with configs: {config}")
def _run(self, path: str, file_data_path: str) -> UploadStageStepResponse:
path = Path(path)
staged_output_path = self.process.run(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))
@ -44,10 +50,24 @@ class UploadStageStep(PipelineStep):
if semaphore := self.context.semaphore:
async with semaphore:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
else:
staged_output_path = await self.process.run_async(
elements_filepath=path, file_data=FileData.from_file(path=file_data_path)
elements_filepath=path,
file_data=FileData.from_file(path=file_data_path),
output_dir=self.cache_dir,
output_filename=self.get_hash(extras=[path.name]),
)
return UploadStageStepResponse(file_data_path=file_data_path, path=str(staged_output_path))
def get_hash(self, extras: Optional[list[str]]) -> str:
hashable_string = json.dumps(
self.process.upload_stager_config.to_dict(), sort_keys=True, ensure_ascii=True
)
if extras:
hashable_string += "".join(extras)
return hashlib.sha256(hashable_string.encode()).hexdigest()[:12]

View File

@ -8,6 +8,8 @@ from typing import Any, Generator, Optional
from unstructured.documents.elements import DataSourceMetadata
from unstructured.ingest.v2.interfaces import (
AccessConfig,
ConnectionConfig,
Downloader,
DownloaderConfig,
FileData,
@ -29,6 +31,16 @@ from unstructured.ingest.v2.processes.connector_registry import (
CONNECTOR_TYPE = "local"
@dataclass
class LocalAccessConfig(AccessConfig):
pass
@dataclass
class LocalConnectionConfig(ConnectionConfig):
access_config: LocalAccessConfig = field(default_factory=lambda: LocalAccessConfig())
@dataclass
class LocalIndexerConfig(IndexerConfig):
input_path: str
@ -43,6 +55,9 @@ class LocalIndexerConfig(IndexerConfig):
@dataclass
class LocalIndexer(Indexer):
index_config: LocalIndexerConfig
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
connector_type: str = CONNECTOR_TYPE
def list_files(self) -> list[Path]:
@ -115,7 +130,10 @@ class LocalDownloaderConfig(DownloaderConfig):
@dataclass
class LocalDownloader(Downloader):
connector_type: str = CONNECTOR_TYPE
download_config: Optional[LocalDownloaderConfig] = None
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
download_config: LocalDownloaderConfig = field(default_factory=lambda: LocalDownloaderConfig())
def get_download_path(self, file_data: FileData) -> Path:
return Path(file_data.source_identifiers.fullpath)
@ -139,7 +157,10 @@ class LocalUploaderConfig(UploaderConfig):
@dataclass
class LocalUploader(Uploader):
upload_config: LocalUploaderConfig = field(default_factory=LocalUploaderConfig)
upload_config: LocalUploaderConfig = field(default_factory=lambda: LocalUploaderConfig())
connection_config: LocalConnectionConfig = field(
default_factory=lambda: LocalConnectionConfig()
)
def is_async(self) -> bool:
return False

View File

@ -0,0 +1,236 @@
import json
from dataclasses import dataclass, field
from datetime import date, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
from dateutil import parser
from unstructured.ingest.enhanced_dataclass import enhanced_field
from unstructured.ingest.v2.interfaces import (
AccessConfig,
ConnectionConfig,
FileData,
UploadContent,
Uploader,
UploaderConfig,
UploadStager,
UploadStagerConfig,
)
from unstructured.ingest.v2.logger import logger
from unstructured.ingest.v2.processes.connector_registry import (
DestinationRegistryEntry,
add_destination_entry,
)
if TYPE_CHECKING:
from weaviate import Client
CONNECTOR_TYPE = "weaviate"
@dataclass
class WeaviateAccessConfig(AccessConfig):
    """Sensitive credential values for a Weaviate connection.

    Fields are Optional-typed but carry no defaults, so each must be passed
    explicitly (possibly as None). Which one is used is decided by
    precedence in the uploader's auth resolution.
    """

    # Bearer access token; may be paired with the connection config's refresh_token.
    access_token: Optional[str]
    # Static api key auth.
    api_key: Optional[str]
    # Client-credentials auth.
    client_secret: Optional[str]
    # Used together with the connection config's username.
    password: Optional[str]
@dataclass
class WeaviateConnectionConfig(ConnectionConfig):
    """Non-sensitive connection settings for a Weaviate instance; secrets
    live in WeaviateAccessConfig."""

    # Base url of the weaviate instance.
    host_url: str
    # Weaviate class the objects are written into.
    class_name: str
    # Marked sensitive so credential values are redacted on serialization.
    access_config: WeaviateAccessConfig = enhanced_field(sensitive=True)
    # Username for username/password auth; only used together with a password.
    username: Optional[str] = None
    # When True, authentication is skipped entirely.
    anonymous: bool = False
    # Scope(s) forwarded to client-credentials and username/password auth.
    scope: Optional[list[str]] = None
    # Paired with the access token to keep the bearer token refreshed.
    refresh_token: Optional[str] = None
    connector_type: str = CONNECTOR_TYPE
@dataclass
class WeaviateUploadStagerConfig(UploadStagerConfig):
    """The weaviate stager has no settings beyond the base stager config."""

    pass
@dataclass
class WeaviateUploadStager(UploadStager):
    """Rewrites a staged elements JSON file so every element conforms to the
    Weaviate schema: nested dicts/arrays are serialized to JSON strings,
    dates are normalized to RFC3339-style strings, and loosely-typed scalars
    are cast to str."""

    upload_stager_config: WeaviateUploadStagerConfig = field(
        default_factory=lambda: WeaviateUploadStagerConfig()
    )

    # Date format Weaviate expects (plain class attribute, not a dataclass field).
    _DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

    @staticmethod
    def parse_date_string(date_string: str) -> datetime:
        """Parse *date_string* as a unix timestamp if possible, otherwise
        fall back to dateutil's free-form parser."""
        try:
            timestamp = float(date_string)
            return datetime.fromtimestamp(timestamp)
        except Exception as e:
            logger.debug(f"date {date_string} string not a timestamp: {e}")
        return parser.parse(date_string)

    @classmethod
    def _format_date(cls, date_string: str) -> str:
        """Normalize a date string to the format Weaviate expects."""
        return cls.parse_date_string(date_string).strftime(cls._DATE_FORMAT)

    @classmethod
    def conform_dict(cls, data: dict) -> None:
        """
        Updates the element dictionary (in place) to conform to the Weaviate schema
        """
        metadata = data.get("metadata", {})
        data_source = metadata.get("data_source", {})
        # Dict as string formatting (json.dumps already returns str; no extra cast needed)
        if record_locator := data_source.get("record_locator"):
            # Explicit casting otherwise fails schema type checking
            data["metadata"]["data_source"]["record_locator"] = json.dumps(record_locator)
        # Array of items as string formatting
        if points := metadata.get("coordinates", {}).get("points"):
            data["metadata"]["coordinates"]["points"] = json.dumps(points)
        if links := metadata.get("links", {}):
            data["metadata"]["links"] = json.dumps(links)
        if permissions_data := data_source.get("permissions_data"):
            data["metadata"]["data_source"]["permissions_data"] = json.dumps(permissions_data)
        # Datetime formatting; same normalization for every data_source date field
        for date_field in ("date_created", "date_modified", "date_processed"):
            if date_value := data_source.get(date_field):
                data["metadata"]["data_source"][date_field] = cls._format_date(date_value)
        if last_modified := metadata.get("last_modified"):
            data["metadata"]["last_modified"] = cls._format_date(last_modified)
        # String casting
        if version := data_source.get("version"):
            data["metadata"]["data_source"]["version"] = str(version)
        if page_number := metadata.get("page_number"):
            data["metadata"]["page_number"] = str(page_number)
        if regex_metadata := metadata.get("regex_metadata"):
            data["metadata"]["regex_metadata"] = json.dumps(regex_metadata)

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        """Conform every element in *elements_filepath* and write the result
        to ``output_dir/output_filename.json``; returns the staged path."""
        with open(elements_filepath) as elements_file:
            elements_contents = json.load(elements_file)
        for element in elements_contents:
            self.conform_dict(data=element)
        output_path = Path(output_dir) / f"{output_filename}.json"
        with open(output_path, "w") as output_file:
            json.dump(elements_contents, output_file)
        return output_path
@dataclass
class WeaviateUploaderConfig(UploaderConfig):
    """Settings for the weaviate uploader."""

    # Number of objects per weaviate batch (passed to client.batch.configure).
    batch_size: int = 100
@dataclass
class WeaviateUploader(Uploader):
    """Uploads staged element JSON files into a Weaviate class via batch import."""

    upload_config: WeaviateUploaderConfig
    connection_config: WeaviateConnectionConfig
    # Built in __post_init__ from the connection config; excluded from __init__.
    client: Optional["Client"] = field(init=False)

    def __post_init__(self):
        # Local import so the weaviate dependency is only needed when this
        # connector is actually instantiated.
        from weaviate import Client

        auth = self._resolve_auth_method()
        self.client = Client(url=self.connection_config.host_url, auth_client_secret=auth)

    def is_async(self) -> bool:
        return True

    def _resolve_auth_method(self):
        """Build the weaviate auth credential from the configured values.

        First match wins, in this precedence order: anonymous (no auth),
        bearer access token, api key, client secret, username/password.
        Returns None when anonymous is set or nothing matches.
        """
        access_configs = self.connection_config.access_config
        connection_config = self.connection_config
        if connection_config.anonymous:
            return None
        if access_configs.access_token:
            from weaviate.auth import AuthBearerToken

            return AuthBearerToken(
                access_token=access_configs.access_token,
                refresh_token=connection_config.refresh_token,
            )
        elif access_configs.api_key:
            from weaviate.auth import AuthApiKey

            return AuthApiKey(api_key=access_configs.api_key)
        elif access_configs.client_secret:
            from weaviate.auth import AuthClientCredentials

            return AuthClientCredentials(
                client_secret=access_configs.client_secret, scope=connection_config.scope
            )
        elif connection_config.username and access_configs.password:
            from weaviate.auth import AuthClientPassword

            return AuthClientPassword(
                username=connection_config.username,
                password=access_configs.password,
                scope=connection_config.scope,
            )
        return None

    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
        # Sync batch path is intentionally unsupported; this uploader is async-only.
        raise NotImplementedError

    async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
        """Read the staged JSON at *path* and batch-write each element as a
        weaviate data object, using any popped "embeddings" value as the vector.

        NOTE(review): the weaviate Client used here appears to be synchronous,
        so this coroutine blocks the event loop while uploading — confirm.
        """
        with open(path) as elements_file:
            elements_dict = json.load(elements_file)
        logger.info(
            f"writing {len(elements_dict)} objects to destination "
            f"class {self.connection_config.class_name} "
            f"at {self.connection_config.host_url}",
        )

        self.client.batch.configure(batch_size=self.upload_config.batch_size)
        with self.client.batch as b:
            for e in elements_dict:
                # Embeddings are removed from the payload and sent as the object vector.
                vector = e.pop("embeddings", None)
                b.add_data_object(
                    e,
                    self.connection_config.class_name,
                    vector=vector,
                )
# Register this connector in the destination registry so the pipeline can
# resolve "weaviate" as a destination type.
add_destination_entry(
    destination_type=CONNECTOR_TYPE,
    entry=DestinationRegistryEntry(
        connection_config=WeaviateConnectionConfig,
        uploader=WeaviateUploader,
        uploader_config=WeaviateUploaderConfig,
        upload_stager=WeaviateUploadStager,
        upload_stager_config=WeaviateUploadStagerConfig,
    ),
)