Mirror of https://github.com/Unstructured-IO/unstructured.git (synced 2025-06-27 02:30:08 +00:00)
Refactoring VoyageAI integration (#3878)

Using VoyageAI's python package directly, allowing more features than are available through langchain.

Parent: 238f985dda
Commit: 0fbdd4ea36
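In short: the encoder now calls the voyageai client's `embed` method directly instead of going through langchain_voyageai's `VoyageAIEmbeddings` wrapper, which exposes parameters such as `input_type`, `truncation`, and `output_dimension`. A minimal before/after sketch (the API key, texts, and parameter values below are placeholders, not taken from the diff):

```python
# Before: embeddings went through the langchain wrapper (removed in this commit).
# from langchain_voyageai import VoyageAIEmbeddings
# embeddings = VoyageAIEmbeddings(voyage_api_key="your-voyage-api-key", model="voyage-law-2")
# vectors = embeddings.embed_documents(["some text"])

# After: the voyageai client is used directly, exposing extra parameters.
from voyageai import Client

client = Client(api_key="your-voyage-api-key")  # placeholder key
result = client.embed(
    texts=["some text"],
    model="voyage-3-large",
    input_type="document",   # "query" is used for query embeddings
    truncation=True,
    output_dimension=1024,   # illustrative value; now configurable
)
vectors = result.embeddings
```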
@@ -1,6 +1,7 @@
 ## 0.16.17-dev1
 
 ### Enhancements
 
+- **Refactoring the VoyageAI integration** to use voyageai package directly, allowing extra features.
 
 ### Features
@@ -1,17 +1,21 @@
+from unittest.mock import Mock
+
 from unstructured.documents.elements import Text
 from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder
 
 
 def test_embed_documents_does_not_break_element_to_dict(mocker):
     # Mocked client with the desired behavior for embed_documents
+    embed_response = Mock()
+    embed_response.embeddings = [[1], [2]]
     mock_client = mocker.MagicMock()
-    mock_client.embed_documents.return_value = [1, 2]
+    mock_client.embed.return_value = embed_response
 
     # Mock get_client to return our mock_client
     mocker.patch.object(VoyageAIEmbeddingConfig, "get_client", return_value=mock_client)
 
     encoder = VoyageAIEmbeddingEncoder(
-        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-law-2")
+        config=VoyageAIEmbeddingConfig(api_key="api_key", model_name="voyage-3-large")
     )
     elements = encoder.embed_documents(
         elements=[Text("This is sentence 1"), Text("This is sentence 2")],
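Outside of the test, the refactored encoder can be exercised end to end roughly as follows. This is a hedged usage sketch, not part of the diff; it assumes a valid key in the `VOYAGE_API_KEY` environment variable:

```python
# Hedged usage sketch of the refactored encoder against the real API.
import os

from unstructured.documents.elements import Text
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder

encoder = VoyageAIEmbeddingEncoder(
    config=VoyageAIEmbeddingConfig(
        api_key=os.environ["VOYAGE_API_KEY"],  # assumed to be set
        model_name="voyage-3-large",
    )
)
elements = encoder.embed_documents(
    elements=[Text("This is sentence 1"), Text("This is sentence 2")],
)
# Each element now carries its embedding vector and still serializes cleanly.
print(len(elements[0].embeddings), type(elements[0].to_dict()))
```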
@@ -37,7 +37,7 @@ PYTHONPATH=${PYTHONPATH:-.} "$RUN_SCRIPT" \
   --work-dir "$WORK_DIR" \
   --embedding-provider "voyageai" \
   --embedding-api-key "$VOYAGE_API_KEY" \
-  --embedding-model-name "voyage-large-2"
+  --embedding-model-name "voyage-3-large"
 
 set +e
@@ -13,7 +13,7 @@ EMBEDDING_PROVIDER_TO_CLASS_MAP = {
     "langchain-huggingface": HuggingFaceEmbeddingEncoder,
     "langchain-aws-bedrock": BedrockEmbeddingEncoder,
     "langchain-vertexai": VertexAIEmbeddingEncoder,
-    "langchain-voyageai": VoyageAIEmbeddingEncoder,
+    "voyageai": VoyageAIEmbeddingEncoder,
     "mixedbread-ai": MixedbreadAIEmbeddingEncoder,
     "octoai": OctoAIEmbeddingEncoder,
 }
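For reference, a hedged sketch of how the provider key can be resolved to an encoder class. The `get_embedder` helper is hypothetical, and the import path for `EMBEDDING_PROVIDER_TO_CLASS_MAP` is an assumption based on the hunk header:

```python
# Hedged sketch; get_embedder is hypothetical and the import path is assumed.
from unstructured.embed import EMBEDDING_PROVIDER_TO_CLASS_MAP  # path assumed
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig


def get_embedder(provider: str, config):
    # Look up the encoder class registered for this provider string.
    try:
        encoder_cls = EMBEDDING_PROVIDER_TO_CLASS_MAP[provider]
    except KeyError:
        raise ValueError(f"unknown embedding provider: {provider}")
    return encoder_cls(config=config)


embedder = get_embedder(
    "voyageai",  # previously registered as "langchain-voyageai"
    VoyageAIEmbeddingConfig(api_key="your-voyage-api-key", model_name="voyage-3-large"),
)
```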
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, cast
 
 import numpy as np
 from pydantic import Field, SecretStr
@@ -9,30 +9,46 @@ from unstructured.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
 from unstructured.utils import requires_dependencies
 
 if TYPE_CHECKING:
-    from langchain_voyageai import VoyageAIEmbeddings
+    from voyageai import Client
+
+DEFAULT_VOYAGE_2_BATCH_SIZE = 72
+DEFAULT_VOYAGE_3_LITE_BATCH_SIZE = 30
+DEFAULT_VOYAGE_3_BATCH_SIZE = 10
+DEFAULT_BATCH_SIZE = 7
 
 
 class VoyageAIEmbeddingConfig(EmbeddingConfig):
     api_key: SecretStr
     model_name: str
+    show_progress_bar: bool = False
     batch_size: Optional[int] = Field(default=None)
     truncation: Optional[bool] = Field(default=None)
+    output_dimension: Optional[int] = Field(default=None)
 
     @requires_dependencies(
-        ["langchain", "langchain_voyageai"],
+        ["voyageai"],
         extras="embed-voyageai",
     )
-    def get_client(self) -> "VoyageAIEmbeddings":
-        """Creates a Langchain VoyageAI python client to embed elements."""
-        from langchain_voyageai import VoyageAIEmbeddings
+    def get_client(self) -> "Client":
+        """Creates a VoyageAI python client to embed elements."""
+        from voyageai import Client
 
-        return VoyageAIEmbeddings(
-            voyage_api_key=self.api_key,
-            model=self.model_name,
-            batch_size=self.batch_size,
-            truncation=self.truncation,
+        return Client(
+            api_key=self.api_key.get_secret_value(),
         )
 
+    def get_batch_size(self):
+        if self.batch_size is None:
+            if self.model_name in ["voyage-2", "voyage-02"]:
+                self.batch_size = DEFAULT_VOYAGE_2_BATCH_SIZE
+            elif self.model_name == "voyage-3-lite":
+                self.batch_size = DEFAULT_VOYAGE_3_LITE_BATCH_SIZE
+            elif self.model_name == "voyage-3":
+                self.batch_size = DEFAULT_VOYAGE_3_BATCH_SIZE
+            else:
+                self.batch_size = DEFAULT_BATCH_SIZE
+        return self.batch_size
+
 
 @dataclass
 class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
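The new `get_batch_size` defaults follow directly from the constants above: 72 for voyage-2/voyage-02, 30 for voyage-3-lite, 10 for voyage-3, and 7 otherwise, with an explicitly configured `batch_size` taking precedence. A small sketch (the API key is a placeholder):

```python
# Demonstrates the per-model batch-size defaults introduced in this hunk.
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig

for model in ["voyage-2", "voyage-3-lite", "voyage-3", "voyage-3-large"]:
    config = VoyageAIEmbeddingConfig(api_key="your-voyage-api-key", model_name=model)
    print(model, config.get_batch_size())
# Expected per the defaults above: 72, 30, 10, 7 (an explicit batch_size overrides these).
```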
@@ -56,12 +72,29 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
 
     def embed_documents(self, elements: List[Element]) -> List[Element]:
         client = self.config.get_client()
-        embeddings = client.embed_documents([str(e) for e in elements])
+        embeddings: List[List[float]] = []
+
+        _iter = self._get_batch_iterator(elements)
+        for i in _iter:
+            r = client.embed(
+                texts=[str(e) for e in elements[i : i + self.config.get_batch_size()]],
+                model=self.config.model_name,
+                input_type="document",
+                truncation=self.config.truncation,
+                output_dimension=self.config.output_dimension,
+            ).embeddings
+            embeddings.extend(cast(Iterable[List[float]], r))
         return self._add_embeddings_to_elements(elements, embeddings)
 
     def embed_query(self, query: str) -> List[float]:
         client = self.config.get_client()
-        return client.embed_query(query)
+        return client.embed(
+            texts=[query],
+            model=self.config.model_name,
+            input_type="query",
+            truncation=self.config.truncation,
+            output_dimension=self.config.output_dimension,
+        ).embeddings[0]
 
     @staticmethod
     def _add_embeddings_to_elements(elements, embeddings) -> List[Element]:
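The batching in `embed_documents` slices the element list in steps of `get_batch_size()` and issues one `client.embed` call per slice, concatenating the returned vectors in order. A tiny sketch of just that arithmetic (the texts and batch size are illustrative):

```python
# Sketch of the batching arithmetic used by embed_documents above.
texts = [f"sentence {i}" for i in range(23)]
batch_size = 10  # e.g. the default for "voyage-3"

batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
print([len(b) for b in batches])  # [10, 10, 3] -> three client.embed calls
```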
@@ -71,3 +104,19 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
             element.embeddings = embeddings[i]
             elements_w_embedding.append(element)
         return elements
+
+    def _get_batch_iterator(self, elements: List[Element]) -> Iterable:
+        if self.config.show_progress_bar:
+            try:
+                from tqdm.auto import tqdm  # type: ignore
+            except ImportError as e:
+                raise ImportError(
+                    "Must have tqdm installed if `show_progress_bar` is set to True. "
+                    "Please install with `pip install tqdm`."
+                ) from e
+
+            _iter = tqdm(range(0, len(elements), self.config.get_batch_size()))
+        else:
+            _iter = range(0, len(elements), self.config.get_batch_size())  # type: ignore
+
+        return _iter
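Finally, the new `show_progress_bar` option wraps the batch range in tqdm, raising a helpful ImportError if tqdm is not installed. A hedged sketch of enabling it (placeholder key, illustrative batch size):

```python
# Hedged sketch: enabling the optional progress bar (requires `pip install tqdm`).
from unstructured.documents.elements import Text
from unstructured.embed.voyageai import VoyageAIEmbeddingConfig, VoyageAIEmbeddingEncoder

encoder = VoyageAIEmbeddingEncoder(
    config=VoyageAIEmbeddingConfig(
        api_key="your-voyage-api-key",  # placeholder
        model_name="voyage-3-lite",
        show_progress_bar=True,  # wraps the batch range in tqdm
        batch_size=8,            # explicit value skips the per-model default
    )
)
elements = encoder.embed_documents(elements=[Text(f"sentence {i}") for i in range(100)])
```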