mirror of
https://github.com/microsoft/autogen.git
synced 2025-06-26 22:30:10 +00:00
Simplify Azure AI Search Tool (#6511)
## Why are these changes needed?
Simplified the Azure AI Search tool and fixed bugs in the code.
## Related issue number
Closes #6430
## Checks
- [X] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [X] I've made sure all auto checks have passed.
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
978cbd2e89
commit
87cf4f07dd
@ -4,6 +4,7 @@ from ._ai_search import (
|
||||
SearchQuery,
|
||||
SearchResult,
|
||||
SearchResults,
|
||||
VectorizableTextQuery,
|
||||
)
|
||||
from ._config import AzureAISearchConfig
|
||||
|
||||
@ -14,4 +15,5 @@ __all__ = [
|
||||
"SearchResult",
|
||||
"SearchResults",
|
||||
"AzureAISearchConfig",
|
||||
"VectorizableTextQuery",
|
||||
]
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -6,173 +6,180 @@ settings for authentication, search behavior, retry policies, and caching.
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential, TokenCredential
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
# Add explicit ignore for the specific model validator error
|
||||
# pyright: reportArgumentType=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
T = TypeVar("T", bound="AzureAISearchConfig")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
|
||||
DEFAULT_API_VERSION = "2023-10-01-preview"
|
||||
|
||||
|
||||
class AzureAISearchConfig(BaseModel):
|
||||
"""Configuration for Azure AI Search tool.
|
||||
"""Configuration for Azure AI Search with validation.
|
||||
|
||||
This class defines the configuration parameters for :class:`AzureAISearchTool`.
|
||||
It provides options for customizing search behavior including query types,
|
||||
field selection, authentication, retry policies, and caching strategies.
|
||||
This class defines the configuration parameters for Azure AI Search tools, including
|
||||
authentication, search behavior, caching, and embedding settings.
|
||||
|
||||
.. note::
|
||||
|
||||
This class requires the :code:`azure` extra for the :code:`autogen-ext` package.
|
||||
This class requires the ``azure`` extra for the ``autogen-ext`` package.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U "autogen-ext[azure]"
|
||||
|
||||
Example:
|
||||
.. note::
|
||||
**Prerequisites:**
|
||||
|
||||
1. An Azure AI Search service must be created in your Azure subscription.
|
||||
2. The search index must be properly configured for your use case:
|
||||
|
||||
- For vector search: Index must have vector fields
|
||||
- For semantic search: Index must have semantic configuration
|
||||
- For hybrid search: Both vector fields and text fields must be configured
|
||||
3. Required packages:
|
||||
|
||||
- Base functionality: ``azure-search-documents>=11.4.0``
|
||||
- For Azure OpenAI embeddings: ``openai azure-identity``
|
||||
- For OpenAI embeddings: ``openai``
|
||||
|
||||
Example Usage:
|
||||
.. code-block:: python
|
||||
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from autogen_ext.tools.azure import AzureAISearchConfig
|
||||
|
||||
# Basic configuration for full-text search
|
||||
config = AzureAISearchConfig(
|
||||
name="doc_search",
|
||||
endpoint="https://my-search.search.windows.net",
|
||||
index_name="my-index",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="vector",
|
||||
vector_fields=["embedding"],
|
||||
name="doc-search",
|
||||
endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint
|
||||
index_name="<your-index>", # Name of your search index
|
||||
credential=AzureKeyCredential("<your-key>"), # Your Azure AI Search admin key
|
||||
query_type="simple",
|
||||
search_fields=["content", "title"], # Update with your searchable fields
|
||||
top=5,
|
||||
)
|
||||
|
||||
For more details, see:
|
||||
* `Azure AI Search Overview <https://learn.microsoft.com/azure/search/search-what-is-azure-search>`_
|
||||
* `Vector Search <https://learn.microsoft.com/azure/search/vector-search-overview>`_
|
||||
# Configuration for vector search with Azure OpenAI embeddings
|
||||
vector_config = AzureAISearchConfig(
|
||||
name="vector-search",
|
||||
endpoint="https://your-search.search.windows.net",
|
||||
index_name="<your-index>",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="vector",
|
||||
vector_fields=["embedding"], # Update with your vector field name
|
||||
embedding_provider="azure_openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint
|
||||
openai_api_key="<your-openai-key>", # Your Azure OpenAI key
|
||||
top=5,
|
||||
)
|
||||
|
||||
Args:
|
||||
name (str): Name for the tool instance, used to identify it in the agent's toolkit.
|
||||
description (Optional[str]): Human-readable description of what this tool does and how to use it.
|
||||
endpoint (str): The full URL of your Azure AI Search service, in the format
|
||||
'https://<service-name>.search.windows.net'.
|
||||
index_name (str): Name of the target search index in your Azure AI Search service.
|
||||
The index must be pre-created and properly configured.
|
||||
api_version (str): Azure AI Search REST API version to use. Defaults to '2023-11-01'.
|
||||
Only change if you need specific features from a different API version.
|
||||
credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential:
|
||||
- AzureKeyCredential: For API key authentication (admin/query key)
|
||||
- TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential)
|
||||
query_type (Literal["keyword", "fulltext", "vector", "semantic"]): The search query mode to use:
|
||||
- 'keyword': Basic keyword search (default)
|
||||
- 'fulltext': Full Lucene query syntax
|
||||
- 'vector': Vector similarity search
|
||||
- 'semantic': Semantic search using semantic configuration
|
||||
search_fields (Optional[List[str]]): List of index fields to search within. If not specified,
|
||||
searches all searchable fields. Example: ['title', 'content'].
|
||||
select_fields (Optional[List[str]]): Fields to return in search results. If not specified,
|
||||
returns all fields. Use to optimize response size.
|
||||
vector_fields (Optional[List[str]]): Vector field names for vector search. Must be configured
|
||||
in your search index as vector fields. Required for vector search.
|
||||
top (Optional[int]): Maximum number of documents to return in search results.
|
||||
Helps control response size and processing time.
|
||||
retry_enabled (bool): Whether to enable retry policy for transient errors. Defaults to True.
|
||||
retry_max_attempts (Optional[int]): Maximum number of retry attempts for failed requests. Defaults to 3.
|
||||
retry_mode (Literal["fixed", "exponential"]): Retry backoff strategy: fixed or exponential. Defaults to "exponential".
|
||||
enable_caching (bool): Whether to enable client-side caching of search results. Defaults to False.
|
||||
cache_ttl_seconds (int): Time-to-live for cached search results in seconds. Defaults to 300 (5 minutes).
|
||||
filter (Optional[str]): OData filter expression to refine search results.
|
||||
# Configuration for hybrid search with semantic ranking
|
||||
hybrid_config = AzureAISearchConfig(
|
||||
name="hybrid-search",
|
||||
endpoint="https://your-search.search.windows.net",
|
||||
index_name="<your-index>",
|
||||
credential=AzureKeyCredential("<your-key>"),
|
||||
query_type="semantic",
|
||||
semantic_config_name="<your-semantic-config>", # Name of your semantic configuration
|
||||
search_fields=["content", "title"], # Update with your search fields
|
||||
vector_fields=["embedding"], # Update with your vector field name
|
||||
embedding_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
openai_api_key="<your-openai-key>", # Your OpenAI API key
|
||||
top=5,
|
||||
)
|
||||
"""
|
||||
|
||||
name: str = Field(description="The name of the tool")
|
||||
description: Optional[str] = Field(default=None, description="A description of the tool")
|
||||
endpoint: str = Field(description="The endpoint URL for your Azure AI Search service")
|
||||
index_name: str = Field(description="The name of the search index to query")
|
||||
api_version: str = Field(default="2023-11-01", description="API version to use")
|
||||
credential: Union[AzureKeyCredential, TokenCredential] = Field(
|
||||
description="The credential to use for authentication"
|
||||
name: str = Field(description="The name of this tool instance")
|
||||
description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
|
||||
endpoint: str = Field(description="The full URL of your Azure AI Search service")
|
||||
index_name: str = Field(description="Name of the search index to query")
|
||||
credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
|
||||
description="Azure credential for authentication (API key or token)"
|
||||
)
|
||||
query_type: Literal["keyword", "fulltext", "vector", "semantic"] = Field(
|
||||
default="keyword",
|
||||
description="Type of query to perform (keyword for classic, fulltext for Lucene, vector for embedding, semantic for semantic/AI search)",
|
||||
api_version: str = Field(
|
||||
default=DEFAULT_API_VERSION,
|
||||
description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
|
||||
)
|
||||
search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in")
|
||||
select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results")
|
||||
vector_fields: Optional[List[str]] = Field(
|
||||
default=None, description="Optional list of vector fields for vector search"
|
||||
query_type: QueryTypeLiteral = Field(
|
||||
default="simple", description="Type of search to perform: simple, full, semantic, or vector"
|
||||
)
|
||||
top: Optional[int] = Field(default=None, description="Optional number of results to return")
|
||||
filter: Optional[str] = Field(default=None, description="Optional OData filter expression to refine search results")
|
||||
|
||||
retry_enabled: bool = Field(default=True, description="Whether to enable retry policy for transient errors")
|
||||
retry_max_attempts: Optional[int] = Field(
|
||||
default=3, description="Maximum number of retry attempts for failed requests"
|
||||
search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
|
||||
select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
|
||||
vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
|
||||
top: Optional[int] = Field(
|
||||
default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
|
||||
)
|
||||
retry_mode: Literal["fixed", "exponential"] = Field(
|
||||
default="exponential",
|
||||
description="Retry backoff strategy: fixed or exponential",
|
||||
filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
|
||||
semantic_config_name: Optional[str] = Field(
|
||||
default=None, description="Semantic configuration name for enhanced results"
|
||||
)
|
||||
|
||||
enable_caching: bool = Field(
|
||||
default=False,
|
||||
description="Whether to enable client-side caching of search results",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=300, # 5 minutes
|
||||
description="Time-to-live for cached search results in seconds",
|
||||
)
|
||||
enable_caching: bool = Field(default=False, description="Whether to cache search results")
|
||||
cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")
|
||||
|
||||
embedding_provider: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Name of embedding provider to use (e.g., 'azure_openai', 'openai')",
|
||||
)
|
||||
embedding_model: Optional[str] = Field(default=None, description="Model name to use for generating embeddings")
|
||||
embedding_dimension: Optional[int] = Field(
|
||||
default=None, description="Dimension of embedding vectors produced by the model"
|
||||
default=None, description="Name of embedding provider for client-side embeddings"
|
||||
)
|
||||
embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
|
||||
openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
|
||||
openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
|
||||
openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@classmethod
|
||||
@model_validator(mode="before")
|
||||
def validate_credentials(cls: Type[T], data: Any) -> Any:
|
||||
"""Validate and convert credential data."""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
@field_validator("endpoint")
|
||||
def validate_endpoint(cls, v: str) -> str:
|
||||
"""Validate that the endpoint is a valid URL."""
|
||||
if not v.startswith(("http://", "https://")):
|
||||
raise ValueError("endpoint must be a valid URL starting with http:// or https://")
|
||||
return v
|
||||
|
||||
result = {}
|
||||
@field_validator("query_type")
|
||||
def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral:
|
||||
"""Normalize query type to standard values."""
|
||||
if not v:
|
||||
return "simple"
|
||||
|
||||
for key, value in data.items():
|
||||
result[str(key)] = value
|
||||
if isinstance(v, str) and v.lower() == "fulltext":
|
||||
return "full"
|
||||
|
||||
if "credential" in result:
|
||||
credential = result["credential"]
|
||||
return v
|
||||
|
||||
if isinstance(credential, dict) and "api_key" in credential:
|
||||
api_key = str(credential["api_key"])
|
||||
result["credential"] = AzureKeyCredential(api_key)
|
||||
@field_validator("top")
|
||||
def validate_top(cls, v: Optional[int]) -> Optional[int]:
|
||||
"""Ensure top is a positive integer if provided."""
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError("top must be a positive integer")
|
||||
return v
|
||||
|
||||
return result
|
||||
@model_validator(mode="after")
|
||||
def validate_interdependent_fields(self) -> "AzureAISearchConfig":
|
||||
"""Validate interdependent fields after all fields have been parsed."""
|
||||
if self.query_type == "semantic" and not self.semantic_config_name:
|
||||
raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Custom model_dump to handle credentials."""
|
||||
result: Dict[str, Any] = super().model_dump(**kwargs)
|
||||
if self.query_type == "vector" and not self.vector_fields:
|
||||
raise ValueError("vector_fields must be provided for vector search")
|
||||
|
||||
if isinstance(self.credential, AzureKeyCredential):
|
||||
result["credential"] = {"type": "AzureKeyCredential"}
|
||||
elif isinstance(self.credential, TokenCredential):
|
||||
result["credential"] = {"type": "TokenCredential"}
|
||||
if (
|
||||
self.embedding_provider
|
||||
and self.embedding_provider.lower() == "azure_openai"
|
||||
and self.embedding_model
|
||||
and not self.openai_endpoint
|
||||
):
|
||||
raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")
|
||||
|
||||
return result
|
||||
return self
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Test fixtures for Azure AI Search tool tests."""
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Generator, List, Protocol, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -9,6 +9,17 @@ from autogen_core import ComponentModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AzureKeyCredential, TokenCredential
|
||||
|
||||
azure_sdk_available = True
|
||||
except ImportError:
|
||||
azure_sdk_available = False
|
||||
|
||||
skip_if_no_azure_sdk = pytest.mark.skipif(
|
||||
not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available"
|
||||
)
|
||||
|
||||
|
||||
class AccessTokenProtocol(Protocol):
|
||||
"""Protocol matching Azure AccessToken."""
|
||||
@ -47,18 +58,13 @@ class MockTokenCredential:
|
||||
return MockAccessToken("mock-token", 12345)
|
||||
|
||||
|
||||
try:
|
||||
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
|
||||
|
||||
_access_token_type: Type[AccessToken] = AccessToken
|
||||
azure_sdk_available = True
|
||||
except ImportError:
|
||||
AzureKeyCredential = MockAzureKeyCredential # type: ignore
|
||||
TokenCredential = MockTokenCredential # type: ignore
|
||||
_access_token_type = MockAccessToken # type: ignore
|
||||
azure_sdk_available = False
|
||||
|
||||
CredentialType = Union[AzureKeyCredential, TokenCredential, MockAzureKeyCredential, MockTokenCredential, Any]
|
||||
CredentialType = Union[
|
||||
AzureKeyCredential, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
TokenCredential, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
MockAzureKeyCredential,
|
||||
MockTokenCredential,
|
||||
Any,
|
||||
]
|
||||
|
||||
needs_azure_sdk = pytest.mark.skipif(not azure_sdk_available, reason="Azure SDK not available")
|
||||
|
||||
@ -70,10 +76,14 @@ warnings.filterwarnings(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vectorized_query() -> Generator[MagicMock, None, None]:
|
||||
def mock_vectorized_query() -> MagicMock:
|
||||
"""Create a mock VectorizedQuery for testing."""
|
||||
with patch("azure.search.documents.models.VectorizedQuery") as mock:
|
||||
yield mock
|
||||
if azure_sdk_available:
|
||||
from azure.search.documents.models import VectorizedQuery
|
||||
|
||||
return MagicMock(spec=VectorizedQuery)
|
||||
else:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -87,7 +97,7 @@ def test_config() -> ComponentModel:
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
@ -106,7 +116,7 @@ def keyword_config() -> ComponentModel:
|
||||
"description": "Keyword search tool",
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
@ -125,7 +135,7 @@ def vector_config() -> ComponentModel:
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
"query_type": "vector",
|
||||
"vector_fields": ["embedding"],
|
||||
"select_fields": ["id", "content", "title", "source"],
|
||||
@ -145,7 +155,7 @@ def hybrid_config() -> ComponentModel:
|
||||
"endpoint": "https://test-search-service.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"api_version": "2023-10-01-Preview",
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
|
||||
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
|
||||
"query_type": "keyword",
|
||||
"search_fields": ["content", "title"],
|
||||
"vector_fields": ["embedding"],
|
||||
@ -196,108 +206,18 @@ class AsyncIterator:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> tuple[MagicMock, Any]:
|
||||
"""Create a mock search client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> Iterator[MagicMock]:
|
||||
"""Create a mock search client for testing, with the patch active."""
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
|
||||
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
search_results = AsyncIterator(mock_search_response)
|
||||
mock_client.search = MagicMock(return_value=search_results)
|
||||
search_results_iterator = AsyncIterator(mock_search_response)
|
||||
mock_client_instance.search = MagicMock(return_value=search_results_iterator)
|
||||
|
||||
patcher = patch("azure.search.documents.aio.SearchClient", return_value=mock_client)
|
||||
patch_target = "autogen_ext.tools.azure._ai_search.SearchClient"
|
||||
patcher = patch(patch_target, return_value=mock_client_instance)
|
||||
|
||||
return mock_client, patcher
|
||||
|
||||
|
||||
def test_validate_credentials_scenarios() -> None:
|
||||
"""Test all validate_credentials scenarios to ensure full code coverage."""
|
||||
import sys
|
||||
|
||||
from autogen_ext.tools.azure._config import AzureAISearchConfig
|
||||
|
||||
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
|
||||
if module_path is not None:
|
||||
assert "autogen-ext" in module_path
|
||||
|
||||
data: Any = "not a dict"
|
||||
result: Any = AzureAISearchConfig.validate_credentials(data) # type: ignore
|
||||
assert result == data
|
||||
|
||||
data_empty: Dict[str, Any] = {}
|
||||
result_empty: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_empty) # type: ignore
|
||||
assert isinstance(result_empty, dict)
|
||||
|
||||
data_items: Dict[str, Any] = {"key1": "value1", "key2": "value2"}
|
||||
result_items: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_items) # type: ignore
|
||||
assert result_items["key1"] == "value1"
|
||||
assert result_items["key2"] == "value2"
|
||||
|
||||
data_with_api_key: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": {"api_key": "test-key"},
|
||||
}
|
||||
result_with_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_api_key) # type: ignore
|
||||
|
||||
cred = result_with_api_key["credential"] # type: ignore
|
||||
assert isinstance(cred, (AzureKeyCredential, MockAzureKeyCredential))
|
||||
assert hasattr(cred, "key")
|
||||
assert cred.key == "test-key" # type: ignore
|
||||
|
||||
credential: Any = AzureKeyCredential("test-key")
|
||||
data_with_credential: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": credential,
|
||||
}
|
||||
result_with_credential: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_credential) # type: ignore
|
||||
assert result_with_credential["credential"] is credential
|
||||
|
||||
data_without_api_key: Dict[str, Any] = {
|
||||
"name": "test",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": {"username": "test-user", "password": "test-pass"},
|
||||
}
|
||||
result_without_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_without_api_key) # type: ignore
|
||||
assert result_without_api_key["credential"] == {"username": "test-user", "password": "test-pass"}
|
||||
|
||||
|
||||
def test_model_dump_scenarios() -> None:
|
||||
"""Test all model_dump scenarios to ensure full code coverage."""
|
||||
import sys
|
||||
|
||||
from autogen_ext.tools.azure._config import AzureAISearchConfig
|
||||
|
||||
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
|
||||
if module_path is not None:
|
||||
assert "autogen-ext" in module_path
|
||||
|
||||
config = AzureAISearchConfig(
|
||||
name="test",
|
||||
endpoint="https://endpoint",
|
||||
index_name="index",
|
||||
credential=AzureKeyCredential("key"), # type: ignore
|
||||
)
|
||||
result = config.model_dump()
|
||||
assert result["credential"] == {"type": "AzureKeyCredential"}
|
||||
|
||||
if azure_sdk_available:
|
||||
from azure.core.credentials import AccessToken
|
||||
from azure.core.credentials import TokenCredential as RealTokenCredential
|
||||
|
||||
class TestTokenCredential(RealTokenCredential):
|
||||
def get_token(self, *args: Any, **kwargs: Any) -> AccessToken:
|
||||
"""Override of get_token method that returns proper type."""
|
||||
return AccessToken("test-token", 12345)
|
||||
|
||||
config = AzureAISearchConfig(
|
||||
name="test", endpoint="https://endpoint", index_name="index", credential=TestTokenCredential()
|
||||
)
|
||||
result = config.model_dump()
|
||||
assert result["credential"] == {"type": "TokenCredential"}
|
||||
else:
|
||||
pytest.skip("Skipping TokenCredential test - Azure SDK not available")
|
||||
patcher.start()
|
||||
yield mock_client_instance
|
||||
patcher.stop()
|
||||
|
@ -0,0 +1,297 @@
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import pytest
|
||||
from autogen_ext.tools.azure._config import AzureAISearchConfig, QueryTypeLiteral
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tests.tools.azure.conftest import azure_sdk_available
|
||||
|
||||
skip_if_no_azure_sdk = pytest.mark.skipif(
|
||||
not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available"
|
||||
)
|
||||
|
||||
# =====================================
|
||||
# Basic Configuration Tests
|
||||
# =====================================
|
||||
|
||||
|
||||
def test_basic_config_creation() -> None:
|
||||
"""Test that a basic valid configuration can be created."""
|
||||
config = AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint="https://test-search.search.windows.net",
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
)
|
||||
|
||||
assert config.name == "test_tool"
|
||||
assert config.endpoint == "https://test-search.search.windows.net"
|
||||
assert config.index_name == "test-index"
|
||||
assert isinstance(config.credential, AzureKeyCredential)
|
||||
assert config.query_type == "simple" # default value
|
||||
|
||||
|
||||
def test_endpoint_validation() -> None:
|
||||
"""Test that endpoint validation works correctly."""
|
||||
valid_endpoints = ["https://test.search.windows.net", "http://localhost:8080"]
|
||||
|
||||
for endpoint in valid_endpoints:
|
||||
config = AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint=endpoint,
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
)
|
||||
assert config.endpoint == endpoint
|
||||
|
||||
invalid_endpoints = [
|
||||
"test.search.windows.net",
|
||||
"ftp://test.search.windows.net",
|
||||
"",
|
||||
]
|
||||
|
||||
for endpoint in invalid_endpoints:
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint=endpoint,
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
)
|
||||
assert "endpoint must be a valid URL" in str(exc.value)
|
||||
|
||||
|
||||
def test_top_validation() -> None:
|
||||
"""Test validation of top parameter."""
|
||||
valid_tops = [1, 5, 10, 100]
|
||||
|
||||
for top in valid_tops:
|
||||
config = AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
top=top,
|
||||
)
|
||||
assert config.top == top
|
||||
|
||||
invalid_tops = [0, -1, -10]
|
||||
|
||||
for top in invalid_tops:
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
top=top,
|
||||
)
|
||||
assert "top must be a positive integer" in str(exc.value)
|
||||
|
||||
|
||||
# =====================================
|
||||
# Query Type Tests
|
||||
# =====================================
|
||||
|
||||
|
||||
def test_query_type_normalization() -> None:
|
||||
"""Test that query_type normalization works correctly."""
|
||||
standard_query_types = {
|
||||
"simple": "simple",
|
||||
"full": "full",
|
||||
"semantic": "semantic",
|
||||
"vector": "vector",
|
||||
}
|
||||
|
||||
for input_type, expected_type in standard_query_types.items():
|
||||
config_args: Dict[str, Any] = {
|
||||
"name": "test_tool",
|
||||
"endpoint": "https://test.search.windows.net",
|
||||
"index_name": "test-index",
|
||||
"credential": AzureKeyCredential("test-key"),
|
||||
"query_type": cast(QueryTypeLiteral, input_type),
|
||||
}
|
||||
|
||||
if input_type == "semantic":
|
||||
config_args["semantic_config_name"] = "my-semantic-config"
|
||||
elif input_type == "vector":
|
||||
config_args["vector_fields"] = ["content_vector"]
|
||||
|
||||
config = AzureAISearchConfig(**config_args)
|
||||
assert config.query_type == expected_type
|
||||
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
AzureAISearchConfig(
|
||||
name="test_tool",
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="test-index",
|
||||
credential=AzureKeyCredential("test-key"),
|
||||
query_type=cast(Any, "invalid_type"),
|
||||
)
|
||||
assert "Input should be" in str(exc.value)
|
||||
|
||||
|
||||
def test_semantic_config_validation() -> None:
    """Exercise the semantic-search settings of ``AzureAISearchConfig``.

    A semantic query type combined with a semantic configuration name is
    accepted; leaving the name out raises a ``ValidationError``.
    """

    def build_semantic(**extra: Any) -> AzureAISearchConfig:
        # Semantic-query config on the shared test settings, fresh credential each call.
        return AzureAISearchConfig(
            name="test_tool",
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            query_type=cast(QueryTypeLiteral, "semantic"),
            **extra,
        )

    accepted = build_semantic(semantic_config_name="my-semantic-config")
    assert accepted.query_type == "semantic"
    assert accepted.semantic_config_name == "my-semantic-config"

    # Without a semantic configuration name the model must refuse to build.
    with pytest.raises(ValidationError) as err:
        build_semantic()
    assert "semantic_config_name must be provided" in str(err.value)
|
||||
|
||||
|
||||
def test_vector_fields_validation() -> None:
    """Confirm that a vector-search config keeps the supplied vector fields."""
    vector_args: Dict[str, Any] = {
        "name": "test_tool",
        "endpoint": "https://test.search.windows.net",
        "index_name": "test-index",
        "credential": AzureKeyCredential("test-key"),
        "query_type": cast(QueryTypeLiteral, "vector"),
        "vector_fields": ["content_vector"],
    }

    built = AzureAISearchConfig(**vector_args)

    assert built.query_type == "vector"
    assert built.vector_fields == ["content_vector"]
|
||||
|
||||
|
||||
# =====================================
|
||||
# Embedding Configuration Tests
|
||||
# =====================================
|
||||
|
||||
|
||||
def test_azure_openai_endpoint_validation() -> None:
    """Check the ``openai_endpoint`` requirement for client-side embeddings.

    Three cases are covered, in order:

    * ``azure_openai`` with an endpoint is accepted and its values are stored;
    * ``azure_openai`` without an endpoint raises a ``ValidationError``;
    * ``openai`` without an endpoint is accepted and the endpoint stays ``None``.
    """

    def build_with_embeddings(**embedding_args: Any) -> AzureAISearchConfig:
        # Config on the shared search settings, fresh credential each call.
        return AzureAISearchConfig(
            name="test_tool",
            endpoint="https://test.search.windows.net",
            index_name="test-index",
            credential=AzureKeyCredential("test-key"),
            **embedding_args,
        )

    azure_config = build_with_embeddings(
        embedding_provider="azure_openai",
        embedding_model="text-embedding-ada-002",
        openai_endpoint="https://test.openai.azure.com",
    )
    assert azure_config.embedding_provider == "azure_openai"
    assert azure_config.embedding_model == "text-embedding-ada-002"
    assert azure_config.openai_endpoint == "https://test.openai.azure.com"

    # The Azure OpenAI provider without an endpoint is rejected.
    with pytest.raises(ValidationError) as err:
        build_with_embeddings(
            embedding_provider="azure_openai",
            embedding_model="text-embedding-ada-002",
        )
    assert "openai_endpoint must be provided for azure_openai" in str(err.value)

    # The plain OpenAI provider is accepted without any endpoint.
    openai_config = build_with_embeddings(
        embedding_provider="openai",
        embedding_model="text-embedding-ada-002",
    )
    assert openai_config.embedding_provider == "openai"
    assert openai_config.embedding_model == "text-embedding-ada-002"
    assert openai_config.openai_endpoint is None
|
||||
|
||||
|
||||
# =====================================
|
||||
# Credential and Serialization Tests
|
||||
# =====================================
|
||||
|
||||
|
||||
def test_credential_validation() -> None:
    """Check that key-based and async token credentials are both accepted.

    The key-credential assertions always run. The token-credential part only
    runs when the module-level ``azure_sdk_available`` flag is truthy.
    """
    key_config = AzureAISearchConfig(
        name="test_tool",
        endpoint="https://test.search.windows.net",
        index_name="test-index",
        credential=AzureKeyCredential("test-key"),
    )
    assert isinstance(key_config.credential, AzureKeyCredential)
    assert key_config.credential.key == "test-key"

    if not azure_sdk_available:
        return

    # Imported lazily so the key-credential checks above do not need these names.
    from azure.core.credentials import AccessToken
    from azure.core.credentials_async import AsyncTokenCredential

    class StubTokenCredential(AsyncTokenCredential):
        # Async credential stand-in that always returns the same fixed token.

        async def __aenter__(self) -> "StubTokenCredential":
            return self

        async def __aexit__(self, *args: Any) -> None:
            await self.close()

        async def close(self) -> None:
            pass

        async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
            return AccessToken("test-token", 12345)

    token_config = AzureAISearchConfig(
        name="test",
        endpoint="https://endpoint",
        index_name="index",
        credential=StubTokenCredential(),
    )
    assert isinstance(token_config.credential, AsyncTokenCredential)
|
||||
|
||||
|
||||
def test_model_dump_scenarios() -> None:
    """Check that ``model_dump`` returns the credential object itself.

    The key-credential case always runs. When the module-level
    ``azure_sdk_available`` flag is falsy the test is then reported as skipped;
    otherwise the same check is repeated with an async token credential.
    """
    key_dump = AzureAISearchConfig(
        name="test",
        endpoint="https://endpoint",
        index_name="index",
        credential=AzureKeyCredential("key"),
    ).model_dump()
    assert isinstance(key_dump["credential"], AzureKeyCredential)
    assert key_dump["credential"].key == "key"

    if not azure_sdk_available:
        # pytest.skip raises, so nothing below this line runs in that case.
        pytest.skip("Skipping TokenCredential test - Azure SDK not available")

    from azure.core.credentials import AccessToken
    from azure.core.credentials_async import AsyncTokenCredential

    class StubTokenCredential(AsyncTokenCredential):
        # Async credential stand-in that always returns the same fixed token.

        async def __aenter__(self) -> "StubTokenCredential":
            return self

        async def __aexit__(self, *args: Any) -> None:
            await self.close()

        async def close(self) -> None:
            pass

        async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
            return AccessToken("test-token", 12345)

    token_dump = AzureAISearchConfig(
        name="test",
        endpoint="https://endpoint",
        index_name="index",
        credential=StubTokenCredential(),
    ).model_dump()
    assert isinstance(token_dump["credential"], AsyncTokenCredential)
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user