# tests/embedder/embedder_fixtures.py
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


def create_embedding_values(multiplier: float = 0.1, dimension: int = 1536) -> list[float]:
    """Build a constant-valued embedding vector for embedder tests.

    Args:
        multiplier: Value stored in every position of the vector.
        dimension: Number of positions in the vector; must not be negative.

    Returns:
        A new list holding ``multiplier`` repeated ``dimension`` times.

    Raises:
        ValueError: If ``dimension`` is negative.
    """
    # ``[x] * n`` quietly yields an empty list for negative ``n``; an empty
    # vector compares equal to any slice of itself, so a typo in a test could
    # make an assertion pass vacuously. Fail loudly instead.
    if dimension < 0:
        raise ValueError(f'dimension must be non-negative, got {dimension}')
    return [multiplier] * dimension
# tests/embedder/test_gemini.py
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.embedder.gemini import (
    DEFAULT_EMBEDDING_MODEL,
    GeminiEmbedder,
    GeminiEmbedderConfig,
)
from tests.embedder.embedder_fixtures import create_embedding_values


def create_gemini_embedding(multiplier: float = 0.1) -> MagicMock:
    """Return a MagicMock whose ``values`` attribute is a constant vector."""
    embedding = MagicMock()
    embedding.values = create_embedding_values(multiplier)
    return embedding


@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Provide a mocked response whose ``embeddings`` holds one embedding."""
    response = MagicMock()
    response.embeddings = [create_gemini_embedding()]
    return response


@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Provide a mocked response with three embeddings (0.1, 0.2, 0.3)."""
    response = MagicMock()
    response.embeddings = [create_gemini_embedding(scale) for scale in (0.1, 0.2, 0.3)]
    return response


@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Yield the instance returned by a patched ``google.genai.Client``.

    The patch remains active until the test using this fixture finishes.
    """
    with patch('google.genai.Client') as client_cls:
        instance = client_cls.return_value
        # Spell out the aio -> models -> embed_content chain so the call the
        # tests await is backed by an AsyncMock.
        instance.aio = MagicMock()
        instance.aio.models = MagicMock()
        instance.aio.models.embed_content = AsyncMock()
        yield instance


@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Build a GeminiEmbedder and replace its ``client`` with the mock."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    embedder.client = mock_gemini_client
    return embedder


@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    gemini_embedder: GeminiEmbedder,
    mock_gemini_client: Any,
    mock_gemini_response: MagicMock,
) -> None:
    """Check ``create`` passes model/contents keywords and returns the first vector."""
    # Arrange: bind the mocked endpoint once and give it a canned response.
    embed_content = mock_gemini_client.aio.models.embed_content
    embed_content.return_value = mock_gemini_response

    # Act
    embedding = await gemini_embedder.create('Test input')

    # Assert: exactly one call, made with the expected keyword arguments.
    embed_content.assert_called_once()
    call_kwargs = embed_content.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['contents'] == ['Test input']

    # Assert: the returned vector is the first embedding's values.
    assert embedding == mock_gemini_response.embeddings[0].values


@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    gemini_embedder: GeminiEmbedder,
    mock_gemini_client: Any,
    mock_gemini_batch_response: MagicMock,
) -> None:
    """Check ``create_batch`` sends the batch in one call and returns every vector."""
    # Arrange
    embed_content = mock_gemini_client.aio.models.embed_content
    embed_content.return_value = mock_gemini_batch_response
    texts = ['Input 1', 'Input 2', 'Input 3']

    # Act
    embeddings = await gemini_embedder.create_batch(texts)

    # Assert: exactly one call, made with the expected keyword arguments.
    embed_content.assert_called_once()
    call_kwargs = embed_content.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['contents'] == texts

    # Assert: one result per mocked embedding, in order.
    assert len(embeddings) == 3
    assert embeddings == [item.values for item in mock_gemini_batch_response.embeddings]


if __name__ == '__main__':
    pytest.main(['-xvs', __file__])
# tests/embedder/test_openai.py
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.embedder.openai import (
    DEFAULT_EMBEDDING_MODEL,
    OpenAIEmbedder,
    OpenAIEmbedderConfig,
)
from tests.embedder.embedder_fixtures import create_embedding_values


def create_openai_embedding(multiplier: float = 0.1) -> MagicMock:
    """Return a MagicMock whose ``embedding`` attribute is a constant vector."""
    item = MagicMock()
    item.embedding = create_embedding_values(multiplier)
    return item


@pytest.fixture
def mock_openai_response() -> MagicMock:
    """Provide a mocked response whose ``data`` holds one embedding item."""
    response = MagicMock()
    response.data = [create_openai_embedding()]
    return response


@pytest.fixture
def mock_openai_batch_response() -> MagicMock:
    """Provide a mocked response with three embedding items (0.1, 0.2, 0.3)."""
    response = MagicMock()
    response.data = [create_openai_embedding(scale) for scale in (0.1, 0.2, 0.3)]
    return response


@pytest.fixture
def mock_openai_client() -> Generator[Any, Any, None]:
    """Yield the instance returned by a patched ``openai.AsyncOpenAI``.

    The patch remains active until the test using this fixture finishes.
    """
    with patch('openai.AsyncOpenAI') as client_cls:
        instance = client_cls.return_value
        # ``embeddings.create`` is awaited by the tests, so it must be an AsyncMock.
        instance.embeddings = MagicMock()
        instance.embeddings.create = AsyncMock()
        yield instance


@pytest.fixture
def openai_embedder(mock_openai_client: Any) -> OpenAIEmbedder:
    """Build an OpenAIEmbedder and replace its ``client`` with the mock."""
    embedder = OpenAIEmbedder(config=OpenAIEmbedderConfig(api_key='test_api_key'))
    embedder.client = mock_openai_client
    return embedder


@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    openai_embedder: OpenAIEmbedder,
    mock_openai_client: Any,
    mock_openai_response: MagicMock,
) -> None:
    """Check ``create`` passes model/input keywords and returns the truncated vector."""
    # Arrange: bind the mocked endpoint once and give it a canned response.
    create_call = mock_openai_client.embeddings.create
    create_call.return_value = mock_openai_response

    # Act
    embedding = await openai_embedder.create('Test input')

    # Assert: exactly one call, made with the expected keyword arguments.
    create_call.assert_called_once()
    call_kwargs = create_call.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['input'] == 'Test input'

    # Assert: the result equals the first vector cut to the configured dimension.
    dim = openai_embedder.config.embedding_dim
    assert embedding == mock_openai_response.data[0].embedding[:dim]


@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    openai_embedder: OpenAIEmbedder,
    mock_openai_client: Any,
    mock_openai_batch_response: MagicMock,
) -> None:
    """Check ``create_batch`` sends the batch in one call and returns every vector."""
    # Arrange
    create_call = mock_openai_client.embeddings.create
    create_call.return_value = mock_openai_batch_response
    texts = ['Input 1', 'Input 2', 'Input 3']

    # Act
    embeddings = await openai_embedder.create_batch(texts)

    # Assert: exactly one call, made with the expected keyword arguments.
    create_call.assert_called_once()
    call_kwargs = create_call.call_args.kwargs
    assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
    assert call_kwargs['input'] == texts

    # Assert: one result per mocked item, each cut to the configured dimension.
    assert len(embeddings) == 3
    dim = openai_embedder.config.embedding_dim
    assert embeddings == [item.embedding[:dim] for item in mock_openai_batch_response.data]


if __name__ == '__main__':
    pytest.main(['-xvs', __file__])
# tests/embedder/test_voyage.py
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.embedder.voyage import (
    DEFAULT_EMBEDDING_MODEL,
    VoyageAIEmbedder,
    VoyageAIEmbedderConfig,
)
from tests.embedder.embedder_fixtures import create_embedding_values


@pytest.fixture
def mock_voyageai_response() -> MagicMock:
    """Provide a mocked response whose ``embeddings`` holds one raw vector."""
    response = MagicMock()
    response.embeddings = [create_embedding_values()]
    return response


@pytest.fixture
def mock_voyageai_batch_response() -> MagicMock:
    """Provide a mocked response with three raw vectors (0.1, 0.2, 0.3)."""
    response = MagicMock()
    response.embeddings = [create_embedding_values(scale) for scale in (0.1, 0.2, 0.3)]
    return response


@pytest.fixture
def mock_voyageai_client() -> Generator[Any, Any, None]:
    """Yield the instance returned by a patched ``voyageai.AsyncClient``.

    The patch remains active until the test using this fixture finishes.
    """
    with patch('voyageai.AsyncClient') as client_cls:
        instance = client_cls.return_value
        # ``embed`` is awaited by the tests, so it must be an AsyncMock.
        instance.embed = AsyncMock()
        yield instance


@pytest.fixture
def voyageai_embedder(mock_voyageai_client: Any) -> VoyageAIEmbedder:
    """Build a VoyageAIEmbedder and replace its ``client`` with the mock."""
    embedder = VoyageAIEmbedder(config=VoyageAIEmbedderConfig(api_key='test_api_key'))
    embedder.client = mock_voyageai_client
    return embedder


@pytest.mark.asyncio
async def test_create_calls_api_correctly(
    voyageai_embedder: VoyageAIEmbedder,
    mock_voyageai_client: Any,
    mock_voyageai_response: MagicMock,
) -> None:
    """Check ``create`` wraps the text in a list and returns the truncated float vector."""
    # Arrange: bind the mocked endpoint once and give it a canned response.
    embed = mock_voyageai_client.embed
    embed.return_value = mock_voyageai_response

    # Act
    embedding = await voyageai_embedder.create('Test input')

    # Assert: exactly one call; texts go positionally, the model as a keyword.
    embed.assert_called_once()
    recorded = embed.call_args
    assert recorded.args[0] == ['Test input']
    assert recorded.kwargs['model'] == DEFAULT_EMBEDDING_MODEL

    # Assert: the result equals the first vector cut to the configured
    # dimension, with every element converted to float.
    dim = voyageai_embedder.config.embedding_dim
    expected = [float(value) for value in mock_voyageai_response.embeddings[0][:dim]]
    assert embedding == expected


@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
    voyageai_embedder: VoyageAIEmbedder,
    mock_voyageai_client: Any,
    mock_voyageai_batch_response: MagicMock,
) -> None:
    """Check ``create_batch`` sends the batch in one call and returns every vector."""
    # Arrange
    embed = mock_voyageai_client.embed
    embed.return_value = mock_voyageai_batch_response
    texts = ['Input 1', 'Input 2', 'Input 3']

    # Act
    embeddings = await voyageai_embedder.create_batch(texts)

    # Assert: exactly one call; texts go positionally, the model as a keyword.
    embed.assert_called_once()
    recorded = embed.call_args
    assert recorded.args[0] == texts
    assert recorded.kwargs['model'] == DEFAULT_EMBEDDING_MODEL

    # Assert: one result per mocked vector, each cut to the configured
    # dimension and converted element-wise to float.
    assert len(embeddings) == 3
    dim = voyageai_embedder.config.embedding_dim
    expected = [
        [float(value) for value in vector[:dim]]
        for vector in mock_voyageai_batch_response.embeddings
    ]
    assert embeddings == expected


if __name__ == '__main__':
    pytest.main(['-xvs', __file__])