Mirror of https://github.com/Cinnamon/kotaemon.git (synced 2025-06-26 23:19:56 +00:00)
feat: add VoyageAI's rerank and embeddings models (#733)
* Introducing VoyageAI's rerank and embeddings models
* fix: comfort CI
* fix: update test case
---------
Co-authored-by: fzowl <zoltan@voyageai.com>
parent c33bedca9e
commit 5132288386
@@ -19,6 +19,9 @@ COHERE_API_KEY=<COHERE_API_KEY>
 # settings for Mistral
 # MISTRAL_API_KEY=placeholder
 
+# settings for VoyageAI
+VOYAGE_API_KEY=<VOYAGE_API_KEY>
+
 # settings for local models
 LOCAL_MODEL=qwen2.5:7b
 LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
@@ -172,6 +172,25 @@ if OPENAI_API_KEY:
         "default": IS_OPENAI_DEFAULT,
     }
 
+VOYAGE_API_KEY = config("VOYAGE_API_KEY", default="")
+if VOYAGE_API_KEY:
+    KH_EMBEDDINGS["voyageai"] = {
+        "spec": {
+            "__type__": "kotaemon.embeddings.VoyageAIEmbeddings",
+            "api_key": VOYAGE_API_KEY,
+            "model": config("VOYAGE_EMBEDDINGS_MODEL", default="voyage-3-large"),
+        },
+        "default": False,
+    }
+    KH_RERANKINGS["voyageai"] = {
+        "spec": {
+            "__type__": "kotaemon.rerankings.VoyageAIReranking",
+            "model_name": "rerank-2",
+            "api_key": VOYAGE_API_KEY,
+        },
+        "default": False,
+    }
+
 if config("LOCAL_MODEL", default=""):
     KH_LLMS["ollama"] = {
         "spec": {
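The two registry entries added above are plain dicts: "__type__" holds a dotted import path and the remaining keys become constructor arguments. As a rough illustration only (the helper below is hypothetical, not kotaemon's actual loader), such a spec could be resolved into a model instance like this:

import importlib

def resolve_spec(spec: dict):
    # "kotaemon.embeddings.VoyageAIEmbeddings" -> module path + class name
    module_path, _, class_name = spec["__type__"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # remaining keys (api_key, model, ...) are passed to the constructor
    kwargs = {k: v for k, v in spec.items() if k != "__type__"}
    return cls(**kwargs)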
@@ -11,6 +11,7 @@ from .langchain_based import (
 )
 from .openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 from .tei_endpoint_embed import TeiEndpointEmbeddings
+from .voyageai import VoyageAIEmbeddings
 
 __all__ = [
     "BaseEmbeddings",
@@ -25,4 +26,5 @@ __all__ = [
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "FastEmbedEmbeddings",
+    "VoyageAIEmbeddings",
 ]
libs/kotaemon/kotaemon/embeddings/voyageai.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Implements embeddings from [Voyage AI](https://voyageai.com).
"""

import importlib

from kotaemon.base import Document, DocumentWithEmbedding, Param

from .base import BaseEmbeddings

vo = None


def _import_voyageai():
    global vo
    if not vo:
        vo = importlib.import_module("voyageai")
    return vo


def _format_output(texts: list[str], embeddings: list[list]):
    """Formats the output of all `.embed` calls.
    Args:
        texts: List of original documents
        embeddings: Embeddings corresponding to each document
    """
    return [
        DocumentWithEmbedding(content=text, embedding=embedding)
        for text, embedding in zip(texts, embeddings)
    ]


class VoyageAIEmbeddings(BaseEmbeddings):
    """Voyage AI provides best-in-class embedding models and rerankers."""

    api_key: str = Param(None, help="Voyage API key", required=False)
    model: str = Param(
        "voyage-3",
        help=(
            "Model name to use. The Voyage "
            "[documentation](https://docs.voyageai.com/docs/embeddings) "
            "provides a list of all available embedding models."
        ),
        required=True,
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.api_key:
            raise ValueError("API key must be provided for VoyageAIEmbeddings.")

        self._client = _import_voyageai().Client(api_key=self.api_key)
        self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        texts = [t.content for t in self.prepare_input(text)]
        embeddings = self._client.embed(texts, model=self.model).embeddings
        return _format_output(texts, embeddings)

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        texts = [t.content for t in self.prepare_input(text)]
        embeddings = (await self._aclient.embed(texts, model=self.model)).embeddings
        return _format_output(texts, embeddings)
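A minimal usage sketch for the new class (it assumes the voyageai package is installed, uses a placeholder key, and relies on instances being callable through BaseEmbeddings, which the test later in this change exercises with model("Hello, world!")):

from kotaemon.embeddings import VoyageAIEmbeddings

# placeholder key; a real Voyage AI key is needed for an actual API call
embedder = VoyageAIEmbeddings(api_key="<VOYAGE_API_KEY>", model="voyage-3")
docs = embedder(["What is kotaemon?", "Voyage AI embedding models"])
for doc in docs:
    # each input comes back as a DocumentWithEmbedding
    print(doc.content, len(doc.embedding))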
@@ -1,5 +1,6 @@
 from .base import BaseReranking
 from .cohere import CohereReranking
 from .tei_fast_rerank import TeiFastReranking
+from .voyageai import VoyageAIReranking
 
-__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
+__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking", "VoyageAIReranking"]
libs/kotaemon/kotaemon/rerankings/voyageai.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from __future__ import annotations

import importlib

from decouple import config

from kotaemon.base import Document, Param

from .base import BaseReranking

vo = None


def _import_voyageai():
    global vo
    if not vo:
        vo = importlib.import_module("voyageai")
    return vo


class VoyageAIReranking(BaseReranking):
    """VoyageAI Reranking model"""

    model_name: str = Param(
        "rerank-2",
        help=(
            "ID of the model to use. You can go to [Supported Models]"
            "(https://docs.voyageai.com/docs/reranker) to see the supported models"
        ),
        required=True,
    )
    api_key: str = Param(
        config("VOYAGE_API_KEY", ""),
        help="VoyageAI API key",
        required=True,
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.api_key:
            raise ValueError("API key must be provided for VoyageAIReranking.")

        self._client = _import_voyageai().Client(api_key=self.api_key)
        self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)

    def run(self, documents: list[Document], query: str) -> list[Document]:
        """Use VoyageAI Reranker model to re-order documents
        with their relevance score"""
        compressed_docs: list[Document] = []

        if not documents:  # to avoid empty api call
            return compressed_docs

        _docs = [d.content for d in documents]
        response = self._client.rerank(
            model=self.model_name, query=query, documents=_docs
        )
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs
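A usage sketch for the reranker (placeholder key; the voyageai package and a real VOYAGE_API_KEY are required for an actual call):

from kotaemon.base import Document
from kotaemon.rerankings import VoyageAIReranking

reranker = VoyageAIReranking(model_name="rerank-2", api_key="<VOYAGE_API_KEY>")
docs = [Document(content="Voyage AI rerankers"), Document(content="Unrelated note")]
# run() returns the documents re-ordered, each with metadata["reranking_score"] set
for doc in reranker.run(documents=docs, query="voyage reranker"):
    print(doc.metadata["reranking_score"], doc.content)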
@@ -90,6 +90,7 @@ adv = [
     "tabulate",
     "unstructured>=0.15.8,<0.16",
     "wikipedia>=1.4.0,<1.5",
+    "voyageai>=0.3.0",
 ]
 dev = [
     "black",
@@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
         return False
 
 
+def if_voyageai_not_installed():
+    try:
+        import voyageai  # noqa: F401
+    except ImportError:
+        return True
+    else:
+        return False
+
+
 skip_when_haystack_not_installed = pytest.mark.skipif(
     if_haystack_not_installed(), reason="Haystack is not installed"
 )
@@ -97,3 +106,7 @@ skip_openai_lc_wrapper_test = pytest.mark.skipif(
 skip_llama_cpp_not_installed = pytest.mark.skipif(
     if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
 )
+
+skip_when_voyageai_not_installed = pytest.mark.skipif(
+    if_voyageai_not_installed(), reason="voyageai is not installed"
+)
@@ -1,22 +1,24 @@
 import json
 from pathlib import Path
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 from openai.types.create_embedding_response import CreateEmbeddingResponse
 
-from kotaemon.base import Document
+from kotaemon.base import Document, DocumentWithEmbedding
 from kotaemon.embeddings import (
     AzureOpenAIEmbeddings,
     FastEmbedEmbeddings,
     LCCohereEmbeddings,
     LCHuggingFaceEmbeddings,
     OpenAIEmbeddings,
+    VoyageAIEmbeddings,
 )
 
 from .conftest import (
     skip_when_cohere_not_installed,
     skip_when_fastembed_not_installed,
     skip_when_sentence_bert_not_installed,
     skip_when_voyageai_not_installed,
 )
 
 with open(Path(__file__).parent / "resources" / "embedding_openai_batch.json") as f:
@@ -155,3 +157,16 @@ def test_fastembed_embeddings():
     model = FastEmbedEmbeddings()
     output = model("Hello World")
     assert_embedding_result(output)
+
+
+voyage_output_mock = Mock()
+voyage_output_mock.embeddings = [[1.0, 2.1, 3.2]]
+
+
+@skip_when_voyageai_not_installed
+@patch("voyageai.Client.embed", return_value=voyage_output_mock)
+@patch("voyageai.AsyncClient.embed", return_value=voyage_output_mock)
+def test_voyageai_embeddings(sync_call, async_call):
+    model = VoyageAIEmbeddings(api_key="test")
+    output = model("Hello, world!")
+    assert all(isinstance(doc, DocumentWithEmbedding) for doc in output)
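The commit only adds a test for the embeddings class; a companion test for the reranker could follow the same mocking pattern. The sketch below is hypothetical and not part of this change; the mocked fields mirror exactly what run() reads from the response (results[i].index and results[i].relevance_score):

from unittest.mock import Mock, patch

from kotaemon.base import Document
from kotaemon.rerankings import VoyageAIReranking

rerank_output_mock = Mock()
rerank_output_mock.results = [
    Mock(index=1, relevance_score=0.9),
    Mock(index=0, relevance_score=0.1),
]

@patch("voyageai.Client.rerank", return_value=rerank_output_mock)
def test_voyageai_reranking(rerank_call):
    reranker = VoyageAIReranking(model_name="rerank-2", api_key="test")
    docs = [Document(content="first"), Document(content="second")]
    output = reranker.run(documents=docs, query="second")
    # documents come back ordered by the mocked relevance scores
    assert output[0].content == "second"
    assert output[0].metadata["reranking_score"] == 0.9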
@@ -62,6 +62,7 @@ class EmbeddingManager:
             LCMistralEmbeddings,
             OpenAIEmbeddings,
             TeiEndpointEmbeddings,
+            VoyageAIEmbeddings,
         )
 
         self._vendors = [
@@ -73,6 +74,7 @@ class EmbeddingManager:
             LCGoogleEmbeddings,
             LCMistralEmbeddings,
             TeiEndpointEmbeddings,
+            VoyageAIEmbeddings,
         ]
 
     def __getitem__(self, key: str) -> BaseEmbeddings:
@@ -52,9 +52,13 @@ class RerankingManager:
         self._default = item.name
 
     def load_vendors(self):
-        from kotaemon.rerankings import CohereReranking, TeiFastReranking
+        from kotaemon.rerankings import (
+            CohereReranking,
+            TeiFastReranking,
+            VoyageAIReranking,
+        )
 
-        self._vendors = [TeiFastReranking, CohereReranking]
+        self._vendors = [TeiFastReranking, CohereReranking, VoyageAIReranking]
 
     def __getitem__(self, key: str) -> BaseReranking:
         """Get model by name"""