mirror of
https://github.com/getzep/graphiti.git
synced 2025-06-27 02:00:02 +00:00
add embedder tests (#430)
This commit is contained in:
parent
6b85e92105
commit
f2e95a5685
20
tests/embedder/embedder_fixtures.py
Normal file
20
tests/embedder/embedder_fixtures.py
Normal file
@ -0,0 +1,20 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
def create_embedding_values(multiplier: float = 0.1, dimension: int = 1536) -> list[float]:
    """Build a constant-valued embedding vector for embedder tests.

    Args:
        multiplier: The value stored in every position of the vector.
        dimension: The number of entries in the vector.

    Returns:
        A list containing ``multiplier`` repeated ``dimension`` times.
    """
    return [multiplier for _ in range(dimension)]
|
127
tests/embedder/test_gemini.py
Normal file
127
tests/embedder/test_gemini.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.embedder.gemini import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
GeminiEmbedder,
|
||||
GeminiEmbedderConfig,
|
||||
)
|
||||
from tests.embedder.embedder_fixtures import create_embedding_values
|
||||
|
||||
|
||||
def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock:
    """Build a stand-in for a single Gemini embedding object.

    The returned mock exposes a ``values`` attribute holding the vector produced by
    ``create_embedding_values(multiplier)``.
    """
    return MagicMock(values=create_embedding_values(multiplier))
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Provide a fake Gemini response whose ``embeddings`` list holds one mock embedding."""
    return MagicMock(embeddings=[create_gemini_embedding()])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Provide a fake Gemini response whose ``embeddings`` list holds three mock embeddings.

    The embeddings are built with multipliers 0.1, 0.2 and 0.3, in that order.
    """
    multipliers = (0.1, 0.2, 0.3)
    return MagicMock(embeddings=[create_gemini_embedding(value) for value in multipliers])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Yield a mocked Gemini client while ``google.genai.Client`` is patched.

    The yielded object is the patched class's ``return_value``; its
    ``aio.models.embed_content`` attribute is an ``AsyncMock`` so it can be awaited.
    """
    with patch('google.genai.Client') as patched_client_cls:
        fake_client = patched_client_cls.return_value

        # Assemble the aio -> models -> embed_content chain from the innermost piece outward.
        fake_models = MagicMock()
        fake_models.embed_content = AsyncMock()
        fake_aio = MagicMock()
        fake_aio.models = fake_models
        fake_client.aio = fake_aio

        yield fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Provide a ``GeminiEmbedder`` whose ``client`` attribute is the mocked client."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    # Swap in the mock so the tests can inspect the calls made through the client.
    embedder.client = mock_gemini_client
    return embedder
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_response: MagicMock
) -> None:
    """Check that ``create`` issues one embed_content call and returns the first embedding."""
    # Arrange: have the mocked endpoint hand back the canned response.
    embed_content = mock_gemini_client.aio.models.embed_content
    embed_content.return_value = mock_gemini_response

    # Act
    embedding = await gemini_embedder.create('Test input')

    # Assert: exactly one call, made with the default model and the input wrapped in a list.
    embed_content.assert_called_once()
    call_kwargs = embed_content.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['contents'] == ['Test input']

    # Assert: the result is the vector of the first embedding in the response.
    assert embedding == mock_gemini_response.embeddings[0].values
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    gemini_embedder: GeminiEmbedder, mock_gemini_client: Any, mock_gemini_batch_response: MagicMock
) -> None:
    """Check that ``create_batch`` sends the whole batch in one call and returns every vector."""
    # Arrange
    embed_content = mock_gemini_client.aio.models.embed_content
    embed_content.return_value = mock_gemini_batch_response
    input_batch = ['Input 1', 'Input 2', 'Input 3']

    # Act
    vectors = await gemini_embedder.create_batch(input_batch)

    # Assert: one call carrying the default model and the unmodified batch.
    embed_content.assert_called_once()
    call_kwargs = embed_content.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['contents'] == input_batch

    # Assert: one vector per mocked embedding, in response order.
    assert len(vectors) == 3
    assert vectors == [item.values for item in mock_gemini_batch_response.embeddings]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # pytest.main() returns the run's exit code. Propagate it so that running this file
    # as a script reports test failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main(['-xvs', __file__]))
|
126
tests/embedder/test_openai.py
Normal file
126
tests/embedder/test_openai.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.embedder.openai import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
OpenAIEmbedder,
|
||||
OpenAIEmbedderConfig,
|
||||
)
|
||||
from tests.embedder.embedder_fixtures import create_embedding_values
|
||||
|
||||
|
||||
def create_openai_embedding(multiplier: float = 0.1) -> MagicMock:
    """Build a stand-in for a single OpenAI embedding object.

    The returned mock exposes an ``embedding`` attribute holding the vector produced by
    ``create_embedding_values(multiplier)``.
    """
    return MagicMock(embedding=create_embedding_values(multiplier))
|
||||
|
||||
|
||||
@pytest.fixture
def mock_openai_response() -> MagicMock:
    """Provide a fake OpenAI response whose ``data`` list holds one mock embedding."""
    return MagicMock(data=[create_openai_embedding()])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_openai_batch_response() -> MagicMock:
    """Provide a fake OpenAI response whose ``data`` list holds three mock embeddings.

    The embeddings are built with multipliers 0.1, 0.2 and 0.3, in that order.
    """
    multipliers = (0.1, 0.2, 0.3)
    return MagicMock(data=[create_openai_embedding(value) for value in multipliers])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_openai_client() -> Generator[Any, Any, None]:
    """Yield a mocked OpenAI client while ``openai.AsyncOpenAI`` is patched.

    The yielded object is the patched class's ``return_value``; its ``embeddings.create``
    attribute is an ``AsyncMock`` so it can be awaited.
    """
    with patch('openai.AsyncOpenAI') as patched_client_cls:
        fake_client = patched_client_cls.return_value

        # Build the embeddings namespace first, then attach it to the client.
        fake_embeddings = MagicMock()
        fake_embeddings.create = AsyncMock()
        fake_client.embeddings = fake_embeddings

        yield fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def openai_embedder(mock_openai_client: Any) -> OpenAIEmbedder:
    """Provide an ``OpenAIEmbedder`` whose ``client`` attribute is the mocked client."""
    embedder = OpenAIEmbedder(config=OpenAIEmbedderConfig(api_key='test_api_key'))
    # Swap in the mock so the tests can inspect the calls made through the client.
    embedder.client = mock_openai_client
    return embedder
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    openai_embedder: OpenAIEmbedder, mock_openai_client: Any, mock_openai_response: MagicMock
) -> None:
    """Check that ``create`` issues one embeddings.create call and returns the first vector."""
    # Arrange: have the mocked endpoint hand back the canned response.
    create_endpoint = mock_openai_client.embeddings.create
    create_endpoint.return_value = mock_openai_response

    # Act
    embedding = await openai_embedder.create('Test input')

    # Assert: exactly one call, made with the default model and the raw input string.
    create_endpoint.assert_called_once()
    call_kwargs = create_endpoint.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['input'] == 'Test input'

    # Assert: the result is the first vector, cut to the configured embedding dimension.
    dim = openai_embedder.config.embedding_dim
    assert embedding == mock_openai_response.data[0].embedding[:dim]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    openai_embedder: OpenAIEmbedder, mock_openai_client: Any, mock_openai_batch_response: MagicMock
) -> None:
    """Check that ``create_batch`` sends the whole batch in one call and returns every vector."""
    # Arrange
    create_endpoint = mock_openai_client.embeddings.create
    create_endpoint.return_value = mock_openai_batch_response
    input_batch = ['Input 1', 'Input 2', 'Input 3']

    # Act
    vectors = await openai_embedder.create_batch(input_batch)

    # Assert: one call carrying the default model and the unmodified batch.
    create_endpoint.assert_called_once()
    call_kwargs = create_endpoint.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['input'] == input_batch

    # Assert: one vector per mocked item, each cut to the configured embedding dimension.
    assert len(vectors) == 3
    dim = openai_embedder.config.embedding_dim
    assert vectors == [item.embedding[:dim] for item in mock_openai_batch_response.data]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # pytest.main() returns the run's exit code. Propagate it so that running this file
    # as a script reports test failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main(['-xvs', __file__]))
|
142
tests/embedder/test_voyage.py
Normal file
142
tests/embedder/test_voyage.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Copyright 2024, Zep Software, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from graphiti_core.embedder.voyage import (
|
||||
DEFAULT_EMBEDDING_MODEL,
|
||||
VoyageAIEmbedder,
|
||||
VoyageAIEmbedderConfig,
|
||||
)
|
||||
from tests.embedder.embedder_fixtures import create_embedding_values
|
||||
|
||||
|
||||
@pytest.fixture
def mock_voyageai_response() -> MagicMock:
    """Provide a fake VoyageAI response whose ``embeddings`` list holds one raw vector."""
    return MagicMock(embeddings=[create_embedding_values()])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_voyageai_batch_response() -> MagicMock:
    """Provide a fake VoyageAI response whose ``embeddings`` list holds three raw vectors.

    The vectors are built with multipliers 0.1, 0.2 and 0.3, in that order.
    """
    multipliers = (0.1, 0.2, 0.3)
    return MagicMock(embeddings=[create_embedding_values(value) for value in multipliers])
|
||||
|
||||
|
||||
@pytest.fixture
def mock_voyageai_client() -> Generator[Any, Any, None]:
    """Yield a mocked VoyageAI client while ``voyageai.AsyncClient`` is patched.

    The yielded object is the patched class's ``return_value``; its ``embed`` attribute is
    an ``AsyncMock`` so it can be awaited.
    """
    with patch('voyageai.AsyncClient') as patched_client_cls:
        fake_client = patched_client_cls.return_value
        fake_client.embed = AsyncMock()
        yield fake_client
|
||||
|
||||
|
||||
@pytest.fixture
def voyageai_embedder(mock_voyageai_client: Any) -> VoyageAIEmbedder:
    """Provide a ``VoyageAIEmbedder`` whose ``client`` attribute is the mocked client."""
    embedder = VoyageAIEmbedder(config=VoyageAIEmbedderConfig(api_key='test_api_key'))
    # Swap in the mock so the tests can inspect the calls made through the client.
    embedder.client = mock_voyageai_client
    return embedder
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    voyageai_embedder: VoyageAIEmbedder,
    mock_voyageai_client: Any,
    mock_voyageai_response: MagicMock,
) -> None:
    """Check that ``create`` issues one embed call and returns the first vector as floats."""
    # Arrange: have the mocked endpoint hand back the canned response.
    embed = mock_voyageai_client.embed
    embed.return_value = mock_voyageai_response

    # Act
    embedding = await voyageai_embedder.create('Test input')

    # Assert: exactly one call, with the input wrapped in a list as the first positional
    # argument and the default model passed by keyword.
    embed.assert_called_once()
    recorded_call = embed.call_args
    assert recorded_call.args[0] == ['Test input']
    assert recorded_call.kwargs['model'] == DEFAULT_EMBEDDING_MODEL

    # Assert: the result is the first vector, cut to the configured dimension, as floats.
    dim = voyageai_embedder.config.embedding_dim
    first_vector = mock_voyageai_response.embeddings[0]
    assert embedding == [float(component) for component in first_vector[:dim]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    voyageai_embedder: VoyageAIEmbedder,
    mock_voyageai_client: Any,
    mock_voyageai_batch_response: MagicMock,
) -> None:
    """Check that ``create_batch`` sends the whole batch in one call and returns every vector."""
    # Arrange
    embed = mock_voyageai_client.embed
    embed.return_value = mock_voyageai_batch_response
    input_batch = ['Input 1', 'Input 2', 'Input 3']

    # Act
    vectors = await voyageai_embedder.create_batch(input_batch)

    # Assert: one call with the unmodified batch as the first positional argument and the
    # default model passed by keyword.
    embed.assert_called_once()
    recorded_call = embed.call_args
    assert recorded_call.args[0] == input_batch
    assert recorded_call.kwargs['model'] == DEFAULT_EMBEDDING_MODEL

    # Assert: one vector per mocked embedding, each cut to the configured dimension and
    # converted to floats.
    assert len(vectors) == 3
    dim = voyageai_embedder.config.embedding_dim
    expected_vectors = [
        [float(component) for component in raw_vector[:dim]]
        for raw_vector in mock_voyageai_batch_response.embeddings
    ]
    assert vectors == expected_vectors
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # pytest.main() returns the run's exit code. Propagate it so that running this file
    # as a script reports test failures to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main(['-xvs', __file__]))
|
Loading…
x
Reference in New Issue
Block a user