2023-09-22 21:54:11 +02:00
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
import openai
|
|
|
|
from openai.util import convert_to_openai_object
|
|
|
|
import numpy as np
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
2023-09-22 21:54:11 +02:00
|
|
|
|
|
|
|
|
|
|
|
def mock_openai_response(model: str = "text-embedding-ada-002", **kwargs) -> openai.openai_object.OpenAIObject:
    """Build a fake single-item embedding response in the OpenAI object format.

    Meant to be installed as the ``side_effect`` of a patched ``create`` call:
    extra keyword arguments supplied by the caller (such as ``input``) are
    accepted and ignored.

    :param model: Model name echoed back in the response's ``model`` field.
    :return: The response payload converted via ``convert_to_openai_object``.
    """
    # Random values are fine here: the tests only inspect the vector's length and element types.
    fake_vector = np.random.rand(1536).tolist()
    embedding_item = {"object": "embedding", "index": 0, "embedding": fake_vector}
    usage_stats = {"prompt_tokens": 4, "total_tokens": 4}

    payload = dict(object="list", data=[embedding_item], model=model, usage=usage_stats)
    return convert_to_openai_object(payload)
|
|
|
|
|
|
|
|
|
|
|
|
class TestOpenAITextEmbedder:
    """Unit tests for ``OpenAITextEmbedder``.

    Constructing an ``OpenAITextEmbedder`` writes to the module-level
    ``openai.api_key`` and ``openai.organization`` settings (asserted in
    ``test_init_with_parameters``). Those settings are process-wide, so
    ``setup_method`` clears them before every test and ``teardown_method``
    restores the previous values afterwards. This keeps the tests independent
    of execution order and stops them leaking fake credentials into other
    test modules.
    """

    def setup_method(self, method):
        """Snapshot the global OpenAI settings, then clear them for a clean start."""
        self._saved_openai_settings = (openai.api_key, openai.organization)
        openai.api_key = None
        openai.organization = None

    def teardown_method(self, method):
        """Restore the global OpenAI settings captured in ``setup_method``."""
        openai.api_key, openai.organization = self._saved_openai_settings

    def test_init_default(self, monkeypatch):
        """Without explicit arguments the API key is read from ``OPENAI_API_KEY``."""
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        embedder = OpenAITextEmbedder()

        assert openai.api_key == "fake-api-key"
        assert embedder.model_name == "text-embedding-ada-002"
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""

    def test_init_with_parameters(self):
        """Explicit init arguments are stored and propagated to the openai module."""
        embedder = OpenAITextEmbedder(
            api_key="fake-api-key",
            model_name="model",
            organization="fake-organization",
            prefix="prefix",
            suffix="suffix",
        )
        assert openai.api_key == "fake-api-key"
        assert embedder.model_name == "model"
        assert embedder.organization == "fake-organization"
        assert openai.organization == "fake-organization"
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"

    def test_init_fail_wo_api_key(self, monkeypatch):
        """Initialization fails when no API key is available from any source."""
        # setup_method has already cleared openai.api_key; also remove the env var.
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError, match="OpenAITextEmbedder expects an OpenAI API key"):
            OpenAITextEmbedder()

    def test_to_dict(self):
        """Serialization with default parameters omits the API key."""
        component = OpenAITextEmbedder(api_key="fake-api-key")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
            "init_parameters": {
                "model_name": "text-embedding-ada-002",
                "organization": None,
                "prefix": "",
                "suffix": "",
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Serialization reflects custom init parameters (API key still omitted)."""
        component = OpenAITextEmbedder(
            api_key="fake-api-key",
            model_name="model",
            organization="fake-organization",
            prefix="prefix",
            suffix="suffix",
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
            "init_parameters": {
                "model_name": "model",
                "organization": "fake-organization",
                "prefix": "prefix",
                "suffix": "suffix",
            },
        }

    def test_run(self):
        """``run`` wraps the text with prefix/suffix and returns embedding plus metadata."""
        model = "text-similarity-ada-001"

        with patch("haystack.components.embedders.openai_text_embedder.openai.Embedding") as openai_embedding_patch:
            # The fake response echoes back the requested model name.
            openai_embedding_patch.create.side_effect = mock_openai_response

            embedder = OpenAITextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix")
            result = embedder.run(text="The food was delicious")

            openai_embedding_patch.create.assert_called_once_with(
                model=model, input="prefix The food was delicious suffix"
            )

        assert len(result["embedding"]) == 1536
        assert all(isinstance(x, float) for x in result["embedding"])
        assert result["metadata"] == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}}

    def test_run_wrong_input_format(self):
        """Non-string input is rejected with a TypeError."""
        embedder = OpenAITextEmbedder(api_key="fake-api-key")

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="OpenAITextEmbedder expects a string as an input"):
            embedder.run(text=list_integers_input)
|