diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cd508828..8869fbe9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.16.17-dev1 ### Enhancements +- **Refactoring the VoyageAI integration** to use voyageai package directly, allowing extra features. ### Features diff --git a/test_unstructured/embed/test_voyageai.py b/test_unstructured/embed/test_voyageai.py index b759e6153..f0a24bde7 100644 --- a/test_unstructured/embed/test_voyageai.py +++ b/test_unstructured/embed/test_voyageai.py @@ -1,17 +1,21 @@ +from unittest.mock import Mock + from unstructured.documents.elements import Text from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder def test_embed_documents_does_not_break_element_to_dict(mocker): # Mocked client with the desired behavior for embed_documents + embed_response = Mock() + embed_response.embeddings = [[1], [2]] mock_client = mocker.MagicMock() - mock_client.embed_documents.return_value = [1, 2] + mock_client.embed.return_value = embed_response # Mock get_client to return our mock_client mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client) encoder = VoyageAIEmbeddingEncoder( - config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-law-2") + config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large") ) elements = encoder.embed_documents( elements=[Text("This is sentence 1"), Text("This is sentence 2")], diff --git a/test_unstructured_ingest/src/local-embed-voyageai.sh b/test_unstructured_ingest/src/local-embed-voyageai.sh index c5f3be1fe..83fe3586a 100755 --- a/test_unstructured_ingest/src/local-embed-voyageai.sh +++ b/test_unstructured_ingest/src/local-embed-voyageai.sh @@ -37,7 +37,7 @@ PYTHONPATH=${PYTHONPATH:-.} "$RUN_SCRIPT" \ --work-dir "$WORK_DIR" \ --embedding-provider "voyageai" \ --embedding-api-key "$VOYAGE_API_KEY" \ - --embedding-model-name "voyage-large-2" + --embedding-model-name "voyage-3-large" set +e diff --git a/unstructured/embed/__init__.py b/unstructured/embed/__init__.py index 7b5a49e98..2eda6a326 100644 --- a/unstructured/embed/__init__.py +++ b/unstructured/embed/__init__.py @@ -13,7 +13,7 @@ EMBEDDING_PROVIDER_TO_CLASS_MAP = { "langchain-huggingface": HuggingFaceEmbeddingEncoder, "langchain-aws-bedrock": BedrockEmbeddingEncoder, "langchain-vertexai": VertexAIEmbeddingEncoder, - "langchain-voyageai": VoyageAIEmbeddingEncoder, + "voyageai": VoyageAIEmbeddingEncoder, "mixedbread-ai": MixedbreadAIEmbeddingEncoder, "octoai": OctoAIEmbeddingEncoder, } diff --git a/unstructured/embed/voyageai.py b/unstructured/embed/voyageai.py index c5dd5b61c..35d231884 100644 --- a/unstructured/embed/voyageai.py +++ b/unstructured/embed/voyageai.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, cast import numpy as np from pydantic import Field, SecretStr @@ -9,30 +9,46 @@ from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig from unstructured.utils import requires_dependencies if TYPE_CHECKING: - from langchain_voyageai import VoyageAIEmbeddings + from voyageai import Client + +DEFAULT_VOYAGE_2_BATCH_SIZE = 72 +DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30 +DEFAULT_VOYAGE_3_BATCH_SIZE = 10 +DEFAULT_BATCH_SIZE = 7 class VoyageAIEmbeddingConfig(EmbeddingConfig): api_key: SecretStr model_name: str + show_progress_bar: bool = False batch_size: Optional[int] = Field(default=None) truncation: Optional[bool] = Field(default=None) + output_dimension: Optional[int] = Field(default=None) @requires_dependencies( - ["langchain", "langchain_voyageai"], + ["voyageai"], extras="embed-voyageai", ) - def get_client(self) -> "VoyageAIEmbeddings": - """Creates a Langchain VoyageAI python client to embed elements.""" - from langchain_voyageai import VoyageAIEmbeddings + def get_client(self) -> "Client": + """Creates a VoyageAI python client to embed elements.""" + from voyageai import Client - return VoyageAIEmbeddings( - voyage_api_key=self.api_key, - model=self.model_name, - batch_size=self.batch_size, - truncation=self.truncation, + return Client( + api_key=self.api_key.get_secret_value(), ) + def get_batch_size(self): + if self.batch_size is None: + if self.model_name in ["voyage-2", "voyage-02"]: + self.batch_size = DEFAULT_VOYAGE_2_BATCH_SIZE + elif self.model_name == "voyage-3-lite": + self.batch_size = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE + elif self.model_name == "voyage-3": + self.batch_size = DEFAULT_VOYAGE_3_BATCH_SIZE + else: + self.batch_size = DEFAULT_BATCH_SIZE + return self.batch_size + @dataclass class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder): @@ -56,12 +72,29 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder): def embed_documents(self, elements: List[Element]) -> List[Element]: client = self.config.get_client() - embeddings = client.embed_documents([str(e) for e in elements]) + embeddings: List[List[float]] = [] + + _iter = self._get_batch_iterator(elements) + for i in _iter: + r = client.embed( + texts=[str(e) for e in elements[i : i + self.config.get_batch_size()]], + model=self.config.model_name, + input_type="document", + truncation=self.config.truncation, + output_dimension=self.config.output_dimension, + ).embeddings + embeddings.extend(cast(Iterable[List[float]], r)) return self._add_embeddings_to_elements(elements, embeddings) def embed_query(self, query: str) -> List[float]: client = self.config.get_client() - return client.embed_query(query) + return client.embed( + texts=[query], + model=self.config.model_name, + input_type="query", + truncation=self.config.truncation, + output_dimension=self.config.output_dimension, + ).embeddings[0] @staticmethod def _add_embeddings_to_elements(elements, embeddings) -> List[Element]: @@ -71,3 +104,19 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder): element.embeddings = embeddings[i] elements_w_embedding.append(element) return elements + + def _get_batch_iterator(self, elements: List[Element]) -> Iterable: + if self.config.show_progress_bar: + try: + from tqdm.auto import tqdm # type: ignore + except ImportError as e: + raise ImportError( + "Must have tqdm installed if `show_progress_bar` is set to True. " + "Please install with `pip install tqdm`." + ) from e + + _iter = tqdm(range(0, len(elements), self.config.get_batch_size())) + else: + _iter = range(0, len(elements), self.config.get_batch_size()) # type: ignore + + return _iter