Simplify Azure Ai Search Tool (#6511)

## Why are these changes needed?

Simplified the azure ai search tool and fixed bugs in the code


## Related issue number

"Closes #6430 " 

## Checks

- [X] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [X] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [X] I've made sure all auto checks have passed.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Jay Prakash Thakur 2025-05-13 13:42:11 -07:00 committed by GitHub
parent 978cbd2e89
commit 87cf4f07dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 2111 additions and 2346 deletions

View File

@ -4,6 +4,7 @@ from ._ai_search import (
SearchQuery,
SearchResult,
SearchResults,
VectorizableTextQuery,
)
from ._config import AzureAISearchConfig
@ -14,4 +15,5 @@ __all__ = [
"SearchResult",
"SearchResults",
"AzureAISearchConfig",
"VectorizableTextQuery",
]

View File

@ -6,173 +6,180 @@ settings for authentication, search behavior, retry policies, and caching.
import logging
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
from azure.core.credentials import AzureKeyCredential, TokenCredential
from pydantic import BaseModel, Field, model_validator
# Add explicit ignore for the specific model validator error
# pyright: reportArgumentType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownVariableType=false
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from pydantic import BaseModel, Field, field_validator, model_validator
T = TypeVar("T", bound="AzureAISearchConfig")
logger = logging.getLogger(__name__)
QueryTypeLiteral = Literal["simple", "full", "semantic", "vector"]
DEFAULT_API_VERSION = "2023-10-01-preview"
class AzureAISearchConfig(BaseModel):
"""Configuration for Azure AI Search tool.
"""Configuration for Azure AI Search with validation.
This class defines the configuration parameters for :class:`AzureAISearchTool`.
It provides options for customizing search behavior including query types,
field selection, authentication, retry policies, and caching strategies.
This class defines the configuration parameters for Azure AI Search tools, including
authentication, search behavior, caching, and embedding settings.
.. note::
This class requires the :code:`azure` extra for the :code:`autogen-ext` package.
This class requires the ``azure`` extra for the ``autogen-ext`` package.
.. code-block:: bash
pip install -U "autogen-ext[azure]"
Example:
.. note::
**Prerequisites:**
1. An Azure AI Search service must be created in your Azure subscription.
2. The search index must be properly configured for your use case:
- For vector search: Index must have vector fields
- For semantic search: Index must have semantic configuration
- For hybrid search: Both vector fields and text fields must be configured
3. Required packages:
- Base functionality: ``azure-search-documents>=11.4.0``
- For Azure OpenAI embeddings: ``openai azure-identity``
- For OpenAI embeddings: ``openai``
Example Usage:
.. code-block:: python
from azure.core.credentials import AzureKeyCredential
from autogen_ext.tools.azure import AzureAISearchConfig
# Basic configuration for full-text search
config = AzureAISearchConfig(
name="doc_search",
endpoint="https://my-search.search.windows.net",
index_name="my-index",
credential=AzureKeyCredential("<your-key>"),
query_type="vector",
vector_fields=["embedding"],
name="doc-search",
endpoint="https://your-search.search.windows.net", # Your Azure AI Search endpoint
index_name="<your-index>", # Name of your search index
credential=AzureKeyCredential("<your-key>"), # Your Azure AI Search admin key
query_type="simple",
search_fields=["content", "title"], # Update with your searchable fields
top=5,
)
For more details, see:
* `Azure AI Search Overview <https://learn.microsoft.com/azure/search/search-what-is-azure-search>`_
* `Vector Search <https://learn.microsoft.com/azure/search/vector-search-overview>`_
# Configuration for vector search with Azure OpenAI embeddings
vector_config = AzureAISearchConfig(
name="vector-search",
endpoint="https://your-search.search.windows.net",
index_name="<your-index>",
credential=AzureKeyCredential("<your-key>"),
query_type="vector",
vector_fields=["embedding"], # Update with your vector field name
embedding_provider="azure_openai",
embedding_model="text-embedding-ada-002",
openai_endpoint="https://your-openai.openai.azure.com", # Your Azure OpenAI endpoint
openai_api_key="<your-openai-key>", # Your Azure OpenAI key
top=5,
)
Args:
name (str): Name for the tool instance, used to identify it in the agent's toolkit.
description (Optional[str]): Human-readable description of what this tool does and how to use it.
endpoint (str): The full URL of your Azure AI Search service, in the format
'https://<service-name>.search.windows.net'.
index_name (str): Name of the target search index in your Azure AI Search service.
The index must be pre-created and properly configured.
api_version (str): Azure AI Search REST API version to use. Defaults to '2023-11-01'.
Only change if you need specific features from a different API version.
credential (Union[AzureKeyCredential, TokenCredential]): Azure authentication credential:
- AzureKeyCredential: For API key authentication (admin/query key)
- TokenCredential: For Azure AD authentication (e.g., DefaultAzureCredential)
query_type (Literal["keyword", "fulltext", "vector", "semantic"]): The search query mode to use:
- 'keyword': Basic keyword search (default)
- 'fulltext': Full Lucene query syntax
- 'vector': Vector similarity search
- 'semantic': Semantic search using semantic configuration
search_fields (Optional[List[str]]): List of index fields to search within. If not specified,
searches all searchable fields. Example: ['title', 'content'].
select_fields (Optional[List[str]]): Fields to return in search results. If not specified,
returns all fields. Use to optimize response size.
vector_fields (Optional[List[str]]): Vector field names for vector search. Must be configured
in your search index as vector fields. Required for vector search.
top (Optional[int]): Maximum number of documents to return in search results.
Helps control response size and processing time.
retry_enabled (bool): Whether to enable retry policy for transient errors. Defaults to True.
retry_max_attempts (Optional[int]): Maximum number of retry attempts for failed requests. Defaults to 3.
retry_mode (Literal["fixed", "exponential"]): Retry backoff strategy: fixed or exponential. Defaults to "exponential".
enable_caching (bool): Whether to enable client-side caching of search results. Defaults to False.
cache_ttl_seconds (int): Time-to-live for cached search results in seconds. Defaults to 300 (5 minutes).
filter (Optional[str]): OData filter expression to refine search results.
# Configuration for hybrid search with semantic ranking
hybrid_config = AzureAISearchConfig(
name="hybrid-search",
endpoint="https://your-search.search.windows.net",
index_name="<your-index>",
credential=AzureKeyCredential("<your-key>"),
query_type="semantic",
semantic_config_name="<your-semantic-config>", # Name of your semantic configuration
search_fields=["content", "title"], # Update with your search fields
vector_fields=["embedding"], # Update with your vector field name
embedding_provider="openai",
embedding_model="text-embedding-ada-002",
openai_api_key="<your-openai-key>", # Your OpenAI API key
top=5,
)
"""
name: str = Field(description="The name of the tool")
description: Optional[str] = Field(default=None, description="A description of the tool")
endpoint: str = Field(description="The endpoint URL for your Azure AI Search service")
index_name: str = Field(description="The name of the search index to query")
api_version: str = Field(default="2023-11-01", description="API version to use")
credential: Union[AzureKeyCredential, TokenCredential] = Field(
description="The credential to use for authentication"
name: str = Field(description="The name of this tool instance")
description: Optional[str] = Field(default=None, description="Description explaining the tool's purpose")
endpoint: str = Field(description="The full URL of your Azure AI Search service")
index_name: str = Field(description="Name of the search index to query")
credential: Union[AzureKeyCredential, AsyncTokenCredential] = Field(
description="Azure credential for authentication (API key or token)"
)
query_type: Literal["keyword", "fulltext", "vector", "semantic"] = Field(
default="keyword",
description="Type of query to perform (keyword for classic, fulltext for Lucene, vector for embedding, semantic for semantic/AI search)",
api_version: str = Field(
default=DEFAULT_API_VERSION,
description=f"Azure AI Search API version to use. Defaults to {DEFAULT_API_VERSION}.",
)
search_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to search in")
select_fields: Optional[List[str]] = Field(default=None, description="Optional list of fields to return in results")
vector_fields: Optional[List[str]] = Field(
default=None, description="Optional list of vector fields for vector search"
query_type: QueryTypeLiteral = Field(
default="simple", description="Type of search to perform: simple, full, semantic, or vector"
)
top: Optional[int] = Field(default=None, description="Optional number of results to return")
filter: Optional[str] = Field(default=None, description="Optional OData filter expression to refine search results")
retry_enabled: bool = Field(default=True, description="Whether to enable retry policy for transient errors")
retry_max_attempts: Optional[int] = Field(
default=3, description="Maximum number of retry attempts for failed requests"
search_fields: Optional[List[str]] = Field(default=None, description="Fields to search within documents")
select_fields: Optional[List[str]] = Field(default=None, description="Fields to return in search results")
vector_fields: Optional[List[str]] = Field(default=None, description="Fields to use for vector search")
top: Optional[int] = Field(
default=None, description="Maximum number of results to return. For vector searches, acts as k in k-NN."
)
retry_mode: Literal["fixed", "exponential"] = Field(
default="exponential",
description="Retry backoff strategy: fixed or exponential",
filter: Optional[str] = Field(default=None, description="OData filter expression to refine search results")
semantic_config_name: Optional[str] = Field(
default=None, description="Semantic configuration name for enhanced results"
)
enable_caching: bool = Field(
default=False,
description="Whether to enable client-side caching of search results",
)
cache_ttl_seconds: int = Field(
default=300, # 5 minutes
description="Time-to-live for cached search results in seconds",
)
enable_caching: bool = Field(default=False, description="Whether to cache search results")
cache_ttl_seconds: int = Field(default=300, description="How long to cache results in seconds")
embedding_provider: Optional[str] = Field(
default=None,
description="Name of embedding provider to use (e.g., 'azure_openai', 'openai')",
)
embedding_model: Optional[str] = Field(default=None, description="Model name to use for generating embeddings")
embedding_dimension: Optional[int] = Field(
default=None, description="Dimension of embedding vectors produced by the model"
default=None, description="Name of embedding provider for client-side embeddings"
)
embedding_model: Optional[str] = Field(default=None, description="Model name for client-side embeddings")
openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI/Azure OpenAI embeddings")
openai_api_version: Optional[str] = Field(default=None, description="API version for Azure OpenAI embeddings")
openai_endpoint: Optional[str] = Field(default=None, description="Endpoint URL for Azure OpenAI embeddings")
model_config = {"arbitrary_types_allowed": True}
@classmethod
@model_validator(mode="before")
def validate_credentials(cls: Type[T], data: Any) -> Any:
"""Validate and convert credential data."""
if not isinstance(data, dict):
return data
@field_validator("endpoint")
def validate_endpoint(cls, v: str) -> str:
"""Validate that the endpoint is a valid URL."""
if not v.startswith(("http://", "https://")):
raise ValueError("endpoint must be a valid URL starting with http:// or https://")
return v
result = {}
@field_validator("query_type")
def normalize_query_type(cls, v: QueryTypeLiteral) -> QueryTypeLiteral:
"""Normalize query type to standard values."""
if not v:
return "simple"
for key, value in data.items():
result[str(key)] = value
if isinstance(v, str) and v.lower() == "fulltext":
return "full"
if "credential" in result:
credential = result["credential"]
return v
if isinstance(credential, dict) and "api_key" in credential:
api_key = str(credential["api_key"])
result["credential"] = AzureKeyCredential(api_key)
@field_validator("top")
def validate_top(cls, v: Optional[int]) -> Optional[int]:
"""Ensure top is a positive integer if provided."""
if v is not None and v <= 0:
raise ValueError("top must be a positive integer")
return v
return result
@model_validator(mode="after")
def validate_interdependent_fields(self) -> "AzureAISearchConfig":
"""Validate interdependent fields after all fields have been parsed."""
if self.query_type == "semantic" and not self.semantic_config_name:
raise ValueError("semantic_config_name must be provided when query_type is 'semantic'")
def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""Custom model_dump to handle credentials."""
result: Dict[str, Any] = super().model_dump(**kwargs)
if self.query_type == "vector" and not self.vector_fields:
raise ValueError("vector_fields must be provided for vector search")
if isinstance(self.credential, AzureKeyCredential):
result["credential"] = {"type": "AzureKeyCredential"}
elif isinstance(self.credential, TokenCredential):
result["credential"] = {"type": "TokenCredential"}
if (
self.embedding_provider
and self.embedding_provider.lower() == "azure_openai"
and self.embedding_model
and not self.openai_endpoint
):
raise ValueError("openai_endpoint must be provided for azure_openai embedding provider")
return result
return self

View File

@ -1,7 +1,7 @@
"""Test fixtures for Azure AI Search tool tests."""
import warnings
from typing import Any, Dict, Generator, List, Protocol, Type, TypeVar, Union
from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -9,6 +9,17 @@ from autogen_core import ComponentModel
T = TypeVar("T")
try:
from azure.core.credentials import AzureKeyCredential, TokenCredential
azure_sdk_available = True
except ImportError:
azure_sdk_available = False
skip_if_no_azure_sdk = pytest.mark.skipif(
not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available"
)
class AccessTokenProtocol(Protocol):
"""Protocol matching Azure AccessToken."""
@ -47,18 +58,13 @@ class MockTokenCredential:
return MockAccessToken("mock-token", 12345)
try:
from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential
_access_token_type: Type[AccessToken] = AccessToken
azure_sdk_available = True
except ImportError:
AzureKeyCredential = MockAzureKeyCredential # type: ignore
TokenCredential = MockTokenCredential # type: ignore
_access_token_type = MockAccessToken # type: ignore
azure_sdk_available = False
CredentialType = Union[AzureKeyCredential, TokenCredential, MockAzureKeyCredential, MockTokenCredential, Any]
CredentialType = Union[
AzureKeyCredential, # pyright: ignore [reportPossiblyUnboundVariable]
TokenCredential, # pyright: ignore [reportPossiblyUnboundVariable]
MockAzureKeyCredential,
MockTokenCredential,
Any,
]
needs_azure_sdk = pytest.mark.skipif(not azure_sdk_available, reason="Azure SDK not available")
@ -70,10 +76,14 @@ warnings.filterwarnings(
@pytest.fixture
def mock_vectorized_query() -> Generator[MagicMock, None, None]:
def mock_vectorized_query() -> MagicMock:
"""Create a mock VectorizedQuery for testing."""
with patch("azure.search.documents.models.VectorizedQuery") as mock:
yield mock
if azure_sdk_available:
from azure.search.documents.models import VectorizedQuery
return MagicMock(spec=VectorizedQuery)
else:
return MagicMock()
@pytest.fixture
@ -87,7 +97,7 @@ def test_config() -> ComponentModel:
"endpoint": "https://test-search-service.search.windows.net",
"index_name": "test-index",
"api_version": "2023-10-01-Preview",
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
"query_type": "keyword",
"search_fields": ["content", "title"],
"select_fields": ["id", "content", "title", "source"],
@ -106,7 +116,7 @@ def keyword_config() -> ComponentModel:
"description": "Keyword search tool",
"endpoint": "https://test-search-service.search.windows.net",
"index_name": "test-index",
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
"query_type": "keyword",
"search_fields": ["content", "title"],
"select_fields": ["id", "content", "title", "source"],
@ -125,7 +135,7 @@ def vector_config() -> ComponentModel:
"endpoint": "https://test-search-service.search.windows.net",
"index_name": "test-index",
"api_version": "2023-10-01-Preview",
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
"query_type": "vector",
"vector_fields": ["embedding"],
"select_fields": ["id", "content", "title", "source"],
@ -145,7 +155,7 @@ def hybrid_config() -> ComponentModel:
"endpoint": "https://test-search-service.search.windows.net",
"index_name": "test-index",
"api_version": "2023-10-01-Preview",
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"},
"credential": AzureKeyCredential("test-key") if azure_sdk_available else {"api_key": "test-key"}, # pyright: ignore [reportPossiblyUnboundVariable]
"query_type": "keyword",
"search_fields": ["content", "title"],
"vector_fields": ["embedding"],
@ -196,108 +206,18 @@ class AsyncIterator:
@pytest.fixture
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> tuple[MagicMock, Any]:
"""Create a mock search client for testing."""
mock_client = MagicMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
def mock_search_client(mock_search_response: List[Dict[str, Any]]) -> Iterator[MagicMock]:
"""Create a mock search client for testing, with the patch active."""
mock_client_instance = MagicMock()
mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance)
mock_client_instance.__aexit__ = AsyncMock(return_value=None)
search_results = AsyncIterator(mock_search_response)
mock_client.search = MagicMock(return_value=search_results)
search_results_iterator = AsyncIterator(mock_search_response)
mock_client_instance.search = MagicMock(return_value=search_results_iterator)
patcher = patch("azure.search.documents.aio.SearchClient", return_value=mock_client)
patch_target = "autogen_ext.tools.azure._ai_search.SearchClient"
patcher = patch(patch_target, return_value=mock_client_instance)
return mock_client, patcher
def test_validate_credentials_scenarios() -> None:
"""Test all validate_credentials scenarios to ensure full code coverage."""
import sys
from autogen_ext.tools.azure._config import AzureAISearchConfig
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
if module_path is not None:
assert "autogen-ext" in module_path
data: Any = "not a dict"
result: Any = AzureAISearchConfig.validate_credentials(data) # type: ignore
assert result == data
data_empty: Dict[str, Any] = {}
result_empty: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_empty) # type: ignore
assert isinstance(result_empty, dict)
data_items: Dict[str, Any] = {"key1": "value1", "key2": "value2"}
result_items: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_items) # type: ignore
assert result_items["key1"] == "value1"
assert result_items["key2"] == "value2"
data_with_api_key: Dict[str, Any] = {
"name": "test",
"endpoint": "https://test.search.windows.net",
"index_name": "test-index",
"credential": {"api_key": "test-key"},
}
result_with_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_api_key) # type: ignore
cred = result_with_api_key["credential"] # type: ignore
assert isinstance(cred, (AzureKeyCredential, MockAzureKeyCredential))
assert hasattr(cred, "key")
assert cred.key == "test-key" # type: ignore
credential: Any = AzureKeyCredential("test-key")
data_with_credential: Dict[str, Any] = {
"name": "test",
"endpoint": "https://test.search.windows.net",
"index_name": "test-index",
"credential": credential,
}
result_with_credential: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_with_credential) # type: ignore
assert result_with_credential["credential"] is credential
data_without_api_key: Dict[str, Any] = {
"name": "test",
"endpoint": "https://test.search.windows.net",
"index_name": "test-index",
"credential": {"username": "test-user", "password": "test-pass"},
}
result_without_api_key: Dict[str, Any] = AzureAISearchConfig.validate_credentials(data_without_api_key) # type: ignore
assert result_without_api_key["credential"] == {"username": "test-user", "password": "test-pass"}
def test_model_dump_scenarios() -> None:
"""Test all model_dump scenarios to ensure full code coverage."""
import sys
from autogen_ext.tools.azure._config import AzureAISearchConfig
module_path = sys.modules[AzureAISearchConfig.__module__].__file__
if module_path is not None:
assert "autogen-ext" in module_path
config = AzureAISearchConfig(
name="test",
endpoint="https://endpoint",
index_name="index",
credential=AzureKeyCredential("key"), # type: ignore
)
result = config.model_dump()
assert result["credential"] == {"type": "AzureKeyCredential"}
if azure_sdk_available:
from azure.core.credentials import AccessToken
from azure.core.credentials import TokenCredential as RealTokenCredential
class TestTokenCredential(RealTokenCredential):
def get_token(self, *args: Any, **kwargs: Any) -> AccessToken:
"""Override of get_token method that returns proper type."""
return AccessToken("test-token", 12345)
config = AzureAISearchConfig(
name="test", endpoint="https://endpoint", index_name="index", credential=TestTokenCredential()
)
result = config.model_dump()
assert result["credential"] == {"type": "TokenCredential"}
else:
pytest.skip("Skipping TokenCredential test - Azure SDK not available")
patcher.start()
yield mock_client_instance
patcher.stop()

View File

@ -0,0 +1,297 @@
from typing import Any, Dict, cast
import pytest
from autogen_ext.tools.azure._config import AzureAISearchConfig, QueryTypeLiteral
from azure.core.credentials import AzureKeyCredential
from pydantic import ValidationError
from tests.tools.azure.conftest import azure_sdk_available
skip_if_no_azure_sdk = pytest.mark.skipif(
not azure_sdk_available, reason="Azure SDK components (azure-search-documents, azure-identity) not available"
)
# =====================================
# Basic Configuration Tests
# =====================================
def test_basic_config_creation() -> None:
"""Test that a basic valid configuration can be created."""
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test-search.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
)
assert config.name == "test_tool"
assert config.endpoint == "https://test-search.search.windows.net"
assert config.index_name == "test-index"
assert isinstance(config.credential, AzureKeyCredential)
assert config.query_type == "simple" # default value
def test_endpoint_validation() -> None:
"""Test that endpoint validation works correctly."""
valid_endpoints = ["https://test.search.windows.net", "http://localhost:8080"]
for endpoint in valid_endpoints:
config = AzureAISearchConfig(
name="test_tool",
endpoint=endpoint,
index_name="test-index",
credential=AzureKeyCredential("test-key"),
)
assert config.endpoint == endpoint
invalid_endpoints = [
"test.search.windows.net",
"ftp://test.search.windows.net",
"",
]
for endpoint in invalid_endpoints:
with pytest.raises(ValidationError) as exc:
AzureAISearchConfig(
name="test_tool",
endpoint=endpoint,
index_name="test-index",
credential=AzureKeyCredential("test-key"),
)
assert "endpoint must be a valid URL" in str(exc.value)
def test_top_validation() -> None:
"""Test validation of top parameter."""
valid_tops = [1, 5, 10, 100]
for top in valid_tops:
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
top=top,
)
assert config.top == top
invalid_tops = [0, -1, -10]
for top in invalid_tops:
with pytest.raises(ValidationError) as exc:
AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
top=top,
)
assert "top must be a positive integer" in str(exc.value)
# =====================================
# Query Type Tests
# =====================================
def test_query_type_normalization() -> None:
"""Test that query_type normalization works correctly."""
standard_query_types = {
"simple": "simple",
"full": "full",
"semantic": "semantic",
"vector": "vector",
}
for input_type, expected_type in standard_query_types.items():
config_args: Dict[str, Any] = {
"name": "test_tool",
"endpoint": "https://test.search.windows.net",
"index_name": "test-index",
"credential": AzureKeyCredential("test-key"),
"query_type": cast(QueryTypeLiteral, input_type),
}
if input_type == "semantic":
config_args["semantic_config_name"] = "my-semantic-config"
elif input_type == "vector":
config_args["vector_fields"] = ["content_vector"]
config = AzureAISearchConfig(**config_args)
assert config.query_type == expected_type
with pytest.raises(ValidationError) as exc:
AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type=cast(Any, "invalid_type"),
)
assert "Input should be" in str(exc.value)
def test_semantic_config_validation() -> None:
"""Test validation of semantic configuration."""
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type=cast(QueryTypeLiteral, "semantic"),
semantic_config_name="my-semantic-config",
)
assert config.query_type == "semantic"
assert config.semantic_config_name == "my-semantic-config"
with pytest.raises(ValidationError) as exc:
AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type=cast(QueryTypeLiteral, "semantic"),
)
assert "semantic_config_name must be provided" in str(exc.value)
def test_vector_fields_validation() -> None:
"""Test validation of vector fields for vector search."""
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
query_type=cast(QueryTypeLiteral, "vector"),
vector_fields=["content_vector"],
)
assert config.query_type == "vector"
assert config.vector_fields == ["content_vector"]
# =====================================
# Embedding Configuration Tests
# =====================================
def test_azure_openai_endpoint_validation() -> None:
"""Test validation of Azure OpenAI endpoint for client-side embeddings."""
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
embedding_provider="azure_openai",
embedding_model="text-embedding-ada-002",
openai_endpoint="https://test.openai.azure.com",
)
assert config.embedding_provider == "azure_openai"
assert config.embedding_model == "text-embedding-ada-002"
assert config.openai_endpoint == "https://test.openai.azure.com"
with pytest.raises(ValidationError) as exc:
AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
embedding_provider="azure_openai",
embedding_model="text-embedding-ada-002",
)
assert "openai_endpoint must be provided for azure_openai" in str(exc.value)
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
embedding_provider="openai",
embedding_model="text-embedding-ada-002",
)
assert config.embedding_provider == "openai"
assert config.embedding_model == "text-embedding-ada-002"
assert config.openai_endpoint is None
# =====================================
# Credential and Serialization Tests
# =====================================
def test_credential_validation() -> None:
"""Test credential validation scenarios."""
config = AzureAISearchConfig(
name="test_tool",
endpoint="https://test.search.windows.net",
index_name="test-index",
credential=AzureKeyCredential("test-key"),
)
assert isinstance(config.credential, AzureKeyCredential)
assert config.credential.key == "test-key"
if azure_sdk_available:
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
class TestTokenCredential(AsyncTokenCredential):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return AccessToken("test-token", 12345)
async def close(self) -> None:
pass
async def __aenter__(self) -> "TestTokenCredential":
return self
async def __aexit__(self, *args: Any) -> None:
await self.close()
config = AzureAISearchConfig(
name="test",
endpoint="https://endpoint",
index_name="index",
credential=TestTokenCredential(),
)
assert isinstance(config.credential, AsyncTokenCredential)
def test_model_dump_scenarios() -> None:
"""Test all model_dump scenarios to ensure full code coverage."""
config = AzureAISearchConfig(
name="test",
endpoint="https://endpoint",
index_name="index",
credential=AzureKeyCredential("key"),
)
result = config.model_dump()
assert isinstance(result["credential"], AzureKeyCredential)
assert result["credential"].key == "key"
if azure_sdk_available:
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
class TestTokenCredential(AsyncTokenCredential):
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return AccessToken("test-token", 12345)
async def close(self) -> None:
pass
async def __aenter__(self) -> "TestTokenCredential":
return self
async def __aexit__(self, *args: Any) -> None:
await self.close()
config = AzureAISearchConfig(
name="test",
endpoint="https://endpoint",
index_name="index",
credential=TestTokenCredential(),
)
result = config.model_dump()
assert isinstance(result["credential"], AsyncTokenCredential)
else:
pytest.skip("Skipping TokenCredential test - Azure SDK not available")