Refactoring VoyageAI integration (#3878)

Using VoyageAI's Python package directly, allowing more features than are available
through the LangChain wrapper. A sketch of the new call path follows the commit metadata below.
fzowl 2025-01-28 22:45:40 +01:00 committed by GitHub
parent 238f985dda
commit 0fbdd4ea36
5 changed files with 71 additions and 17 deletions
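
Concretely, the encoder now builds a voyageai Client itself instead of wrapping langchain_voyageai.VoyageAIEmbeddings. A minimal sketch of the new call path (the API key is a placeholder):

from voyageai import Client

client = Client(api_key="YOUR_VOYAGE_API_KEY")  # placeholder key for illustration
response = client.embed(
    texts=["This is sentence 1", "This is sentence 2"],
    model="voyage-3-large",
    input_type="document",
)
print(len(response.embeddings))  # one embedding vector per input text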

View File

@@ -1,6 +1,7 @@
## 0.16.17-dev1
### Enhancements
- **Refactoring the VoyageAI integration** to use the voyageai package directly, allowing extra features.
### Features

View File

@@ -1,17 +1,21 @@
from unittest.mock import Mock
from unstructured.documents.elements import Text
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder
def test_embed_documents_does_not_break_element_to_dict(mocker):
# Mocked client with the desired behavior for embed_documents
embed_response = Mock()
embed_response.embeddings = [[1], [2]]
mock_client = mocker.MagicMock()
mock_client.embed_documents.return_value = [1, 2]
mock_client.embed.return_value = embed_response
# Mock get_client to return our mock_client
mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client)
encoder = VoyageAIEmbeddingEncoder(
config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-law-2")
config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large")
)
elements = encoder.embed_documents(
elements=[Text("This is sentence 1"), Text("This is sentence 2")],
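
A companion test for embed_query could reuse the same mocking pattern; a sketch only, not part of this commit:

def test_embed_query_returns_first_embedding(mocker):
    embed_response = Mock()
    embed_response.embeddings = [[0.1, 0.2]]
    mock_client = mocker.MagicMock()
    mock_client.embed.return_value = embed_response
    mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client)
    encoder = VoyageAIEmbeddingEncoder(
        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large")
    )
    # embed_query sends a single-item batch and returns the first embedding
    assert encoder.embed_query("This is a query") == [0.1, 0.2]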

View File

@@ -37,7 +37,7 @@ PYTHONPATH=${PYTHONPATH:-.} "$RUN_SCRIPT" \
--work-dir "$WORK_DIR" \
--embedding-provider "voyageai" \
--embedding-api-key "$VOYAGE_API_KEY" \
--embedding-model-name "voyage-large-2"
--embedding-model-name "voyage-3-large"
set +e

View File

@@ -13,7 +13,7 @@ EMBEDDING_PROVIDER_TO_CLASS_MAP = {
"langchain-huggingface": HuggingFaceEmbeddingEncoder,
"langchain-aws-bedrock": BedrockEmbeddingEncoder,
"langchain-vertexai": VertexAIEmbeddingEncoder,
"langchain-voyageai": VoyageAIEmbeddingEncoder,
"voyageai": VoyageAIEmbeddingEncoder,
"mixedbread-ai": MixedbreadAIEmbeddingEncoder,
"octoai": OctoAIEmbeddingEncoder,
}
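
Callers now resolve the encoder with the shorter provider key. A sketch of the lookup; the import location of EMBEDDING_PROVIDER_TO_CLASS_MAP is assumed here:

from unstructured.embed import EMBEDDING_PROVIDER_TO_CLASS_MAP  # assumed import location
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig

encoder_cls = EMBEDDING_PROVIDER_TO_CLASS_MAP["voyageai"]  # previously "langchain-voyageai"
encoder = encoder_cls(
    config=VoyageAIEmbeddingConfig(api_key="placeholder-key", model_name="voyage-3-large")
)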

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional, cast
import numpy as np
from pydantic import Field, SecretStr
@@ -9,30 +9,46 @@ from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
from unstructured.utils import requires_dependencies
if TYPE_CHECKING:
from langchain_voyageai import VoyageAIEmbeddings
from voyageai import Client
DEFAULT_VOYAGE_2_BATCH_SIZE = 72
DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
DEFAULT_VOYAGE_3_BATCH_SIZE = 10
DEFAULT_BATCH_SIZE = 7
class VoyageAIEmbeddingConfig(EmbeddingConfig):
api_key: SecretStr
model_name: str
show_progress_bar: bool = False
batch_size: Optional[int] = Field(default=None)
truncation: Optional[bool] = Field(default=None)
output_dimension: Optional[int] = Field(default=None)
@requires_dependencies(
["langchain", "langchain_voyageai"],
["voyageai"],
extras="embed-voyageai",
)
def get_client(self) -> "VoyageAIEmbeddings":
"""Creates a Langchain VoyageAI python client to embed elements."""
from langchain_voyageai import VoyageAIEmbeddings
def get_client(self) -> "Client":
"""Creates a VoyageAI python client to embed elements."""
from voyageai import Client
return VoyageAIEmbeddings(
voyage_api_key=self.api_key,
model=self.model_name,
batch_size=self.batch_size,
truncation=self.truncation,
return Client(
api_key=self.api_key.get_secret_value(),
)
def get_batch_size(self):
if self.batch_size is None:
if self.model_name in ["voyage-2", "voyage-02"]:
self.batch_size = DEFAULT_VOYAGE_2_BATCH_SIZE
elif self.model_name == "voyage-3-lite":
self.batch_size = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE
elif self.model_name == "voyage-3":
self.batch_size = DEFAULT_VOYAGE_3_BATCH_SIZE
else:
self.batch_size = DEFAULT_BATCH_SIZE
return self.batch_size
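
When batch_size is left unset, get_batch_size falls back to a per-model default; a quick sketch of the behaviour defined above:

cfg = VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-lite")
assert cfg.get_batch_size() == DEFAULT_VOYAGE_3_LITE_BATCH_SIZE  # 30

cfg = VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large", batch_size=16)
assert cfg.get_batch_size() == 16  # an explicit batch_size always wins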
@dataclass
class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
@@ -56,12 +72,29 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
def embed_documents(self, elements: List[Element]) -> List[Element]:
client = self.config.get_client()
embeddings = client.embed_documents([str(e) for e in elements])
embeddings: List[List[float]] = []
_iter = self._get_batch_iterator(elements)
for i in _iter:
r = client.embed(
texts=[str(e) for e in elements[i : i + self.config.get_batch_size()]],
model=self.config.model_name,
input_type="document",
truncation=self.config.truncation,
output_dimension=self.config.output_dimension,
).embeddings
embeddings.extend(cast(Iterable[List[float]], r))
return self._add_embeddings_to_elements(elements, embeddings)
def embed_query(self, query: str) -> List[float]:
client = self.config.get_client()
return client.embed_query(query)
return client.embed(
texts=[query],
model=self.config.model_name,
input_type="query",
truncation=self.config.truncation,
output_dimension=self.config.output_dimension,
).embeddings[0]
@staticmethod
def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
@@ -71,3 +104,19 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
element.embeddings = embeddings[i]
elements_w_embedding.append(element)
return elements
def _get_batch_iterator(self, elements: List[Element]) -> Iterable:
if self.config.show_progress_bar:
try:
from tqdm.auto import tqdm # type: ignore
except ImportError as e:
raise ImportError(
"Must have tqdm installed if `show_progress_bar` is set to True. "
"Please install with `pip install tqdm`."
) from e
_iter = tqdm(range(0, len(elements), self.config.get_batch_size()))
else:
_iter = range(0, len(elements), self.config.get_batch_size()) # type: ignore
return _iter