feat: add VoyageAI's rerank and embeddings models (#733) #none

* Introducing VoyageAI's rerank and embeddings models

* fix: comfort CI

* fix: update test case

---------

Co-authored-by: fzowl <zoltan@voyageai.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2025-04-15 15:54:23 +07:00 committed by GitHub
parent c33bedca9e
commit 5132288386
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 194 additions and 5 deletions

View File

@ -19,6 +19,9 @@ COHERE_API_KEY=<COHERE_API_KEY>
# settings for Mistral
# MISTRAL_API_KEY=placeholder
# settings for VoyageAI
VOYAGE_API_KEY=<VOYAGE_API_KEY>
# settings for local models
LOCAL_MODEL=qwen2.5:7b
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text

View File

@ -172,6 +172,25 @@ if OPENAI_API_KEY:
"default": IS_OPENAI_DEFAULT,
}
VOYAGE_API_KEY = config("VOYAGE_API_KEY", default="")
if VOYAGE_API_KEY:
KH_EMBEDDINGS["voyageai"] = {
"spec": {
"__type__": "kotaemon.embeddings.VoyageAIEmbeddings",
"api_key": VOYAGE_API_KEY,
"model": config("VOYAGE_EMBEDDINGS_MODEL", default="voyage-3-large"),
},
"default": False,
}
KH_RERANKINGS["voyageai"] = {
"spec": {
"__type__": "kotaemon.rerankings.VoyageAIReranking",
"model_name": "rerank-2",
"api_key": VOYAGE_API_KEY,
},
"default": False,
}
if config("LOCAL_MODEL", default=""):
KH_LLMS["ollama"] = {
"spec": {

View File

@ -11,6 +11,7 @@ from .langchain_based import (
)
from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from .tei_endpoint_embed import TeiEndpointEmbeddings
from .voyageai import VoyageAIEmbeddings
__all__ = [
"BaseEmbeddings",
@ -25,4 +26,5 @@ __all__ = [
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
"FastEmbedEmbeddings",
"VoyageAIEmbeddings",
]

View File

@ -0,0 +1,66 @@
"""Implements embeddings from [Voyage AI](https://voyageai.com).
"""
import importlib
from kotaemon.base import Document, DocumentWithEmbedding, Param
from .base import BaseEmbeddings
vo = None
def _import_voyageai():
global vo
if not vo:
vo = importlib.import_module("voyageai")
return vo
def _format_output(texts: list[str], embeddings: list[list]):
"""Formats the output of all `.embed` calls.
Args:
texts: List of original documents
embeddings: Embeddings corresponding to each document
"""
return [
DocumentWithEmbedding(content=text, embedding=embedding)
for text, embedding in zip(texts, embeddings)
]
class VoyageAIEmbeddings(BaseEmbeddings):
"""Voyage AI provides best-in-class embedding models and rerankers."""
api_key: str = Param(None, help="Voyage API key", required=False)
model: str = Param(
"voyage-3",
help=(
"Model name to use. The Voyage "
"[documentation](https://docs.voyageai.com/docs/embeddings) "
"provides a list of all available embedding models."
),
required=True,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.api_key:
raise ValueError("API key must be provided for VoyageAIEmbeddings.")
self._client = _import_voyageai().Client(api_key=self.api_key)
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)
def invoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
) -> list[DocumentWithEmbedding]:
texts = [t.content for t in self.prepare_input(text)]
embeddings = self._client.embed(texts, model=self.model).embeddings
return _format_output(texts, embeddings)
async def ainvoke(
self, text: str | list[str] | Document | list[Document], *args, **kwargs
) -> list[DocumentWithEmbedding]:
texts = [t.content for t in self.prepare_input(text)]
embeddings = await self._aclient.embed(texts, model=self.model).embeddings
return _format_output(texts, embeddings)

View File

@ -1,5 +1,6 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .tei_fast_rerank import TeiFastReranking
from .voyageai import VoyageAIReranking
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking", "VoyageAIReranking"]

View File

@ -0,0 +1,63 @@
from __future__ import annotations
import importlib
from decouple import config
from kotaemon.base import Document, Param
from .base import BaseReranking
vo = None
def _import_voyageai():
global vo
if not vo:
vo = importlib.import_module("voyageai")
return vo
class VoyageAIReranking(BaseReranking):
"""VoyageAI Reranking model"""
model_name: str = Param(
"rerank-2",
help=(
"ID of the model to use. You can go to [Supported Models]"
"(https://docs.voyageai.com/docs/reranker) to see the supported models"
),
required=True,
)
api_key: str = Param(
config("VOYAGE_API_KEY", ""),
help="VoyageAI API key",
required=True,
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.api_key:
raise ValueError("API key must be provided for VoyageAIEmbeddings.")
self._client = _import_voyageai().Client(api_key=self.api_key)
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use VoyageAI Reranker model to re-order documents
with their relevance score"""
compressed_docs: list[Document] = []
if not documents: # to avoid empty api call
return compressed_docs
_docs = [d.content for d in documents]
response = self._client.rerank(
model=self.model_name, query=query, documents=_docs
)
for r in response.results:
doc = documents[r.index]
doc.metadata["reranking_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@ -90,6 +90,7 @@ adv = [
"tabulate",
"unstructured>=0.15.8,<0.16",
"wikipedia>=1.4.0,<1.5",
"voyageai>=0.3.0",
]
dev = [
"black",

View File

@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
return False
def if_voyageai_not_installed():
try:
import voyageai # noqa: F401
except ImportError:
return True
else:
return False
skip_when_haystack_not_installed = pytest.mark.skipif(
if_haystack_not_installed(), reason="Haystack is not installed"
)
@ -97,3 +106,7 @@ skip_openai_lc_wrapper_test = pytest.mark.skipif(
skip_llama_cpp_not_installed = pytest.mark.skipif(
if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
)
skip_when_voyageai_not_installed = pytest.mark.skipif(
if_voyageai_not_installed(), reason="voyageai is not installed"
)

View File

@ -1,22 +1,24 @@
import json
from pathlib import Path
from unittest.mock import patch
from unittest.mock import Mock, patch
from openai.types.create_embedding_response import CreateEmbeddingResponse
from kotaemon.base import Document
from kotaemon.base import Document, DocumentWithEmbedding
from kotaemon.embeddings import (
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
VoyageAIEmbeddings,
)
from .conftest import (
skip_when_cohere_not_installed,
skip_when_fastembed_not_installed,
skip_when_sentence_bert_not_installed,
skip_when_voyageai_not_installed,
)
with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
@ -155,3 +157,16 @@ def test_fastembed_embeddings():
model = FastEmbedEmbeddings()
output = model("Hello World")
assert_embedding_result(output)
voyage_output_mock = Mock()
voyage_output_mock.embeddings = [[1.0, 2.1, 3.2]]
@skip_when_voyageai_not_installed
@patch("voyageai.Client.embed", return_value=voyage_output_mock)
@patch("voyageai.AsyncClient.embed", return_value=voyage_output_mock)
def test_voyageai_embeddings(sync_call, async_call):
model = VoyageAIEmbeddings(api_key="test")
output = model("Hello, world!")
assert all(isinstance(doc, DocumentWithEmbedding) for doc in output)

View File

@ -62,6 +62,7 @@ class EmbeddingManager:
LCMistralEmbeddings,
OpenAIEmbeddings,
TeiEndpointEmbeddings,
VoyageAIEmbeddings,
)
self._vendors = [
@ -73,6 +74,7 @@ class EmbeddingManager:
LCGoogleEmbeddings,
LCMistralEmbeddings,
TeiEndpointEmbeddings,
VoyageAIEmbeddings,
]
def __getitem__(self, key: str) -> BaseEmbeddings:

View File

@ -52,9 +52,13 @@ class RerankingManager:
self._default = item.name
def load_vendors(self):
from kotaemon.rerankings import CohereReranking, TeiFastReranking
from kotaemon.rerankings import (
CohereReranking,
TeiFastReranking,
VoyageAIReranking,
)
self._vendors = [TeiFastReranking, CohereReranking]
self._vendors = [TeiFastReranking, CohereReranking, VoyageAIReranking]
def __getitem__(self, key: str) -> BaseReranking:
"""Get model by name"""