mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-07-05 16:10:36 +00:00
856 lines
28 KiB
Python
856 lines
28 KiB
Python
"""
LightRAG Ollama Compatibility Interface Test Script

This script tests LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid, mix)
3. Error handling tests (including streaming and non-streaming scenarios)

All responses use the JSON Lines format, complying with the Ollama API specification.
"""
|
||
|
||
import requests
|
||
import json
|
||
import argparse
|
||
import time
|
||
from typing import Dict, Any, Optional, List, Callable
|
||
from dataclasses import dataclass, asdict
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from enum import Enum, auto
|
||
|
||
|
||
class ErrorCode(Enum):
    """Failure categories attached to ``McpError`` instances."""

    # Explicit values match what ``auto()`` assigns: 1, 2, ... in declaration order.
    InvalidRequest = 1
    InternalError = 2
|
||
|
||
|
||
class McpError(Exception):
    """Exception carrying an error code alongside its message.

    Attributes:
        code: The error code passed to the constructor.
        message: The message passed to the constructor; it is also handed to
            ``Exception`` so ``str(error)`` returns it.
    """

    def __init__(self, code: ErrorCode, message: str):
        super().__init__(message)
        self.message = message
        self.code = code
|
||
|
||
|
||
# Built-in settings; load_config() falls back to these when config.json is absent.
DEFAULT_CONFIG = dict(
    # Connection and request settings read by get_base_url() and make_request().
    server=dict(
        host="localhost",
        port=9621,
        model="lightrag:latest",
        timeout=300,
        max_retries=1,
        retry_delay=1,
    ),
    # Query text sent by the test functions.
    test_cases=dict(
        basic=dict(query="唐僧有几个徒弟"),
        generate=dict(query="电视剧西游记导演是谁"),
    ),
)
|
||
|
||
# Canned chat history that the chat tests send ahead of the actual query.
EXAMPLE_CONVERSATION = [
    {"role": speaker, "content": text}
    for speaker, text in (
        ("user", "你好"),
        ("assistant", "你好!我是一个AI助手,很高兴为你服务。"),
        ("user", "Who are you?"),
        ("assistant", "I'm a Knowledge base query assistant."),
    )
]
|
||
|
||
|
||
class OutputControl:
    """Class-level switch that tells the tests whether to print detailed output."""

    # The flag lives on the class, so every caller sees the same value; off by default.
    _verbose: bool = False

    @classmethod
    def is_verbose(cls) -> bool:
        """Return the verbosity flag currently stored on the class."""
        return cls._verbose

    @classmethod
    def set_verbose(cls, verbose: bool) -> None:
        """Store ``verbose`` as the new class-level verbosity flag."""
        cls._verbose = verbose
|
||
|
||
|
||
@dataclass
class TestResult:
    """Record describing the outcome of a single test run."""

    # Display name of the test.
    name: str
    # True when the test function returned without raising.
    success: bool
    # Elapsed time as measured by the caller.
    duration: float
    # Text of the exception for a failed test; None otherwise.
    error: Optional[str] = None
    # ISO-8601 timestamp; filled in automatically when left empty.
    timestamp: str = ""

    def __post_init__(self):
        # An empty timestamp means "not supplied": stamp the record with the current time.
        self.timestamp = self.timestamp or datetime.now().isoformat()
|
||
|
||
|
||
class TestStats:
    """Collects test results and reports them as a console summary or a JSON file."""

    def __init__(self):
        # Creation time of the collector, reported as the start of the run.
        self.start_time = datetime.now()
        self.results: List[TestResult] = []

    def add_result(self, result: TestResult):
        """Append one result record to the collection."""
        self.results.append(result)

    def export_results(self, path: str = "test_results.json"):
        """Write all collected results plus a summary to a JSON file.

        Args:
            path: Output file path
        """
        started = self.start_time.isoformat()
        finished = datetime.now().isoformat()
        records = [asdict(entry) for entry in self.results]
        summary = {
            "total": len(self.results),
            "passed": len([entry for entry in self.results if entry.success]),
            "failed": len([entry for entry in self.results if not entry.success]),
            "total_duration": sum(entry.duration for entry in self.results),
        }
        results_data = {
            "start_time": started,
            "end_time": finished,
            "results": records,
            "summary": summary,
        }

        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        with open(path, "w", encoding="utf-8") as out_file:
            json.dump(results_data, out_file, ensure_ascii=False, indent=2)
        print(f"\nTest results saved to: {path}")

    def print_summary(self):
        """Print totals for the run, followed by the name and error of each failed test."""
        failures = [entry for entry in self.results if not entry.success]
        total = len(self.results)
        failed = len(failures)
        passed = total - failed
        duration = sum(entry.duration for entry in self.results)

        print("\n=== Test Summary ===")
        print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Total duration: {duration:.2f} seconds")
        print(f"Total tests: {total}")
        print(f"Passed: {passed}")
        print(f"Failed: {failed}")

        if failures:
            print("\nFailed tests:")
            for result in failures:
                print(f"- {result.name}: {result.error}")
|
||
|
||
|
||
def make_request(
    url: str, data: Dict[str, Any], stream: bool = False, check_status: bool = True
) -> requests.Response:
    """Send an HTTP POST request with a retry mechanism

    Args:
        url: Request URL
        data: Request data, sent as the JSON body
        stream: Whether to use streaming response
        check_status: Whether to check HTTP status code (default: True)

    Returns:
        requests.Response: Response object

    Raises:
        requests.exceptions.RequestException: Request failed after all retries
        requests.exceptions.HTTPError: The server answered with an error status
            (when check_status is True)
    """
    # Overlay the loaded settings on the built-in defaults so that a config.json
    # which omits timeout/max_retries/retry_delay does not raise KeyError here.
    server_config = {**DEFAULT_CONFIG["server"], **CONFIG["server"]}
    # Always make at least one attempt; with max_retries < 1 the loop would
    # otherwise never run and the function would silently return None.
    attempts = max(1, int(server_config["max_retries"]))
    retry_delay = server_config["retry_delay"]
    timeout = server_config["timeout"]

    for attempt in range(1, attempts + 1):
        try:
            response = requests.post(url, json=data, stream=stream, timeout=timeout)
            if check_status and response.status_code != 200:
                try:
                    response.raise_for_status()
                except requests.exceptions.HTTPError:
                    # Release the connection before retrying or propagating.
                    response.close()
                    raise
            return response
        except requests.exceptions.RequestException as e:
            if attempt == attempts:  # Last attempt: give up and propagate
                raise
            print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
            time.sleep(retry_delay)
|
||
|
||
|
||
def load_config() -> Dict[str, Any]:
    """Load the test configuration

    Starts from a private copy of DEFAULT_CONFIG and, when config.json exists
    in the current directory, overlays its values on top. Keys missing from
    the file therefore keep their default values, and changes made to the
    returned dictionary never alter DEFAULT_CONFIG.

    Returns:
        Configuration dictionary

    Raises:
        ValueError: config.json exists but its top level is not a JSON object
    """
    import copy  # Local import: only this function needs it

    def _overlay(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively write the values of ``override`` into ``base`` and return it."""
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                _overlay(base[key], value)
            else:
                base[key] = value
        return base

    # Deep copy so callers can mutate the result without touching the defaults.
    config = copy.deepcopy(DEFAULT_CONFIG)

    config_path = Path("config.json")
    if not config_path.exists():
        return config

    with open(config_path, "r", encoding="utf-8") as f:
        loaded = json.load(f)
    if not isinstance(loaded, dict):
        raise ValueError(f"{config_path} must contain a JSON object at the top level")
    return _overlay(config, loaded)
|
||
|
||
|
||
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
    """Pretty-print a dictionary as JSON when verbose output is enabled

    Nothing is printed when OutputControl reports non-verbose mode.

    Args:
        data: Data dictionary to print
        title: Optional heading printed above the JSON
        indent: Number of spaces for JSON indentation
    """
    if not OutputControl.is_verbose():
        return

    if title:
        print(f"\n=== {title} ===")
    # ensure_ascii=False keeps non-ASCII text readable instead of \u-escaped.
    rendered = json.dumps(data, ensure_ascii=False, indent=indent)
    print(rendered)
|
||
|
||
|
||
# Settings shared by the helpers and tests below; loaded once when the module is imported.
CONFIG: Dict[str, Any] = load_config()
|
||
|
||
|
||
def get_base_url(endpoint: str = "chat") -> str:
    """Build the URL of an API endpoint on the configured server

    Args:
        endpoint: API endpoint name (chat or generate)

    Returns:
        URL of the form http://<host>:<port>/api/<endpoint>
    """
    server_settings = CONFIG["server"]
    host, port = server_settings["host"], server_settings["port"]
    return f"http://{host}:{port}/api/{endpoint}"
|
||
|
||
|
||
def create_chat_request_data(
    content: str,
    stream: bool = False,
    model: Optional[str] = None,
    conversation_history: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Create chat request data

    Args:
        content: User message content, appended as the last message
        stream: Whether to use streaming response
        model: Model name; falls back to the configured server model when empty
        conversation_history: Previous conversation messages to send before
            the new user message. The list is copied, never modified.

    Returns:
        Dictionary containing complete chat request data
    """
    # Copy the history: appending to the caller's list would grow shared
    # histories (such as a module-level example conversation) on every call.
    messages = list(conversation_history or [])
    messages.append({"role": "user", "content": content})

    return {
        "model": model or CONFIG["server"]["model"],
        "messages": messages,
        "stream": stream,
    }
|
||
|
||
|
||
def create_generate_request_data(
    prompt: str,
    system: str = None,
    stream: bool = False,
    model: str = None,
    options: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Assemble the request body for the generate endpoint

    Args:
        prompt: Generation prompt
        system: System prompt; left out of the request when empty
        stream: Whether to use streaming response
        model: Model name; falls back to the configured server model when empty
        options: Additional options; left out of the request when empty

    Returns:
        Dictionary containing complete generate request data
    """
    payload = {
        "model": model or CONFIG["server"]["model"],
        "prompt": prompt,
        "stream": stream,
    }
    # Only truthy optional fields are included, in this fixed order.
    optional_fields = (("system", system), ("options", options))
    payload.update({key: value for key, value in optional_fields if value})
    return payload
|
||
|
||
|
||
# Module-wide collector; run_test() appends one TestResult here per executed test.
STATS: TestStats = TestStats()
|
||
|
||
|
||
def run_test(func: Callable, name: str) -> None:
    """Execute one test function and record its outcome in STATS

    A failing test is recorded together with its error text, after which the
    original exception is re-raised to the caller.

    Args:
        func: Test function, called without arguments
        name: Test name stored in the result record
    """
    started = time.time()
    try:
        func()
        elapsed = time.time() - started
        STATS.add_result(TestResult(name, True, elapsed))
    except Exception as exc:
        elapsed = time.time() - started
        failure = TestResult(name, False, elapsed, str(exc))
        STATS.add_result(failure)
        raise
|
||
|
||
|
||
def test_non_stream_chat() -> None:
    """Send one non-streaming request to /api/chat and print the reply

    The request carries the example conversation history followed by the
    configured basic query.
    """
    url = get_base_url()

    query = CONFIG["test_cases"]["basic"]["query"]
    payload = create_chat_request_data(
        query, stream=False, conversation_history=EXAMPLE_CONVERSATION
    )
    response = make_request(url, payload)

    if OutputControl.is_verbose():
        print("\n=== Non-streaming call response ===")
    body = response.json()

    # Show only the fields of interest rather than the whole response.
    excerpt = {field: body[field] for field in ("model", "message")}
    print_json_response(excerpt, "Response content")
|
||
|
||
|
||
def test_stream_chat() -> None:
    """Test streaming call to /api/chat endpoint

    Use JSON Lines format to process streaming responses, each line is a complete JSON object.
    Response format:
    {
        "model": "lightrag:latest",
        "created_at": "2024-01-15T00:00:00Z",
        "message": {
            "role": "assistant",
            "content": "Partial response content",
            "images": null
        },
        "done": false
    }

    The last message will contain performance statistics, with done set to true.

    Streamed content is echoed only in verbose mode, so quiet mode stays silent.
    """
    url = get_base_url()
    verbose = OutputControl.is_verbose()

    # Send request with conversation history
    data = create_chat_request_data(
        CONFIG["test_cases"]["basic"]["query"],
        stream=True,
        conversation_history=EXAMPLE_CONVERSATION,
    )
    response = make_request(url, data, stream=True)

    if verbose:
        print("\n=== Streaming call response ===")
    try:
        for line in response.iter_lines():
            if not line:  # Skip empty lines
                continue
            try:
                chunk = json.loads(line.decode("utf-8"))
            except json.JSONDecodeError:
                print("Error decoding JSON from response line")
                continue

            # A line without a "done" field is treated as a completion marker.
            if chunk.get("done", True):
                if "total_duration" in chunk:  # Final performance statistics message
                    break
                continue

            content = chunk.get("message", {}).get("content", "")
            if content and verbose:
                print(content, end="", flush=True)  # Print content in real-time
    finally:
        response.close()  # Ensure the response connection is closed

    if verbose:
        print()  # Terminate the streamed line
|
||
|
||
|
||
def test_query_modes() -> None:
    """Send the basic query once per query-mode prefix

    Supported query modes:
    - /local: Local retrieval mode, searches only in highly relevant documents
    - /global: Global retrieval mode, searches across all documents
    - /naive: Naive mode, does not use any optimization strategies
    - /hybrid: Hybrid mode (default), combines multiple strategies
    - /mix: Mix mode

    Each mode will return responses in the same format, but with different retrieval strategies.
    """
    url = get_base_url()

    for mode in ("local", "global", "naive", "hybrid", "mix"):
        if OutputControl.is_verbose():
            print(f"\n=== Testing /{mode} mode ===")

        # The mode is selected by prefixing the query text with "/<mode> ".
        prefixed_query = f"/{mode} {CONFIG['test_cases']['basic']['query']}"
        payload = create_chat_request_data(prefixed_query, stream=False)

        body = make_request(url, payload).json()
        print_json_response({"model": body["model"], "message": body["message"]})
|
||
|
||
|
||
def create_error_test_data(error_type: str) -> Dict[str, Any]:
    """Build a deliberately malformed chat request for error testing

    Args:
        error_type: Error type, supported:
            - empty_messages: Empty message list
            - invalid_role: Message with an "invalid_role" key instead of "role"
            - missing_content: Message without a content field
            Any other value produces the empty_messages request.

    Returns:
        Request dictionary containing error data (streaming enabled)
    """

    def _request(messages):
        # Every variant shares the same model and streaming flag.
        return {"model": "lightrag:latest", "messages": messages, "stream": True}

    if error_type == "invalid_role":
        return _request([{"invalid_role": "user", "content": "Test message"}])
    if error_type == "missing_content":
        return _request([{"role": "user"}])
    return _request([])
|
||
|
||
|
||
def test_stream_error_handling() -> None:
    """Test error handling for streaming responses

    Test scenarios:
    1. Empty message list
    2. Invalid role field
    3. Missing content field

    Error responses should be returned immediately without establishing a streaming connection.
    The status code should be 4xx, and detailed error information should be returned.

    The status code and error body are printed only in verbose mode.
    """
    url = get_base_url()
    verbose = OutputControl.is_verbose()

    if verbose:
        print("\n=== Testing streaming response error handling ===")

    # (error type understood by create_error_test_data, label used in the heading)
    scenarios = (
        ("empty_messages", "empty message list"),
        ("invalid_role", "invalid role field"),
        ("missing_content", "missing content field"),
    )
    for error_type, label in scenarios:
        if verbose:
            print(f"\n--- Testing {label} (streaming) ---")
        data = create_error_test_data(error_type)
        # check_status=False: an error status is the expected outcome here.
        response = make_request(url, data, stream=True, check_status=False)
        try:
            if verbose:
                print(f"Status code: {response.status_code}")
            if response.status_code != 200:
                print_json_response(response.json(), "Error message")
        finally:
            # Close even when decoding the error body fails.
            response.close()
|
||
|
||
|
||
def test_error_handling() -> None:
    """Test error handling for non-streaming responses

    Test scenarios:
    1. Empty message list
    2. Invalid role field
    3. Missing content field

    Error response format:
    {
        "detail": "Error description"
    }

    All errors should return appropriate HTTP status codes and clear error messages.

    The status code and error body are printed only in verbose mode.
    """
    url = get_base_url()
    verbose = OutputControl.is_verbose()

    if verbose:
        print("\n=== Testing error handling ===")

    # (error type understood by create_error_test_data, label used in the heading)
    scenarios = (
        ("empty_messages", "empty message list"),
        ("invalid_role", "invalid role field"),
        ("missing_content", "missing content field"),
    )
    for error_type, label in scenarios:
        if verbose:
            print(f"\n--- Testing {label} ---")
        data = create_error_test_data(error_type)
        data["stream"] = False  # Change to non-streaming mode
        # check_status=False: an error status is the expected outcome here.
        response = make_request(url, data, check_status=False)
        if verbose:
            print(f"Status code: {response.status_code}")
        print_json_response(response.json(), "Error message")
|
||
|
||
|
||
def test_non_stream_generate() -> None:
    """Test non-streaming call to /api/generate endpoint

    The complete response is printed only in verbose mode.
    """
    url = get_base_url("generate")
    data = create_generate_request_data(
        CONFIG["test_cases"]["generate"]["query"], stream=False
    )

    # Send request
    response = make_request(url, data)

    if OutputControl.is_verbose():
        print("\n=== Non-streaming generate response ===")
    response_json = response.json()

    # Go through print_json_response (like the other tests) so quiet mode is
    # respected; the JSON formatting is the same as the previous direct print.
    print_json_response(response_json)
|
||
|
||
|
||
def test_stream_generate() -> None:
    """Test streaming call to /api/generate endpoint

    Each non-empty response line is parsed as one JSON object. Text in the
    "response" field of lines with done == false is echoed (verbose mode
    only); reading stops at the first completion line that contains
    "total_duration".
    """
    url = get_base_url("generate")
    verbose = OutputControl.is_verbose()
    data = create_generate_request_data(
        CONFIG["test_cases"]["generate"]["query"], stream=True
    )

    # Send request and get streaming response
    response = make_request(url, data, stream=True)

    if verbose:
        print("\n=== Streaming generate response ===")
    try:
        for line in response.iter_lines():
            if not line:  # Skip empty lines
                continue
            try:
                chunk = json.loads(line.decode("utf-8"))
            except json.JSONDecodeError:
                print("Error decoding JSON from response line")
                continue

            # A line without a "done" field is treated as a completion marker.
            if chunk.get("done", True):
                if "total_duration" in chunk:  # Final performance statistics message
                    break
                continue

            content = chunk.get("response", "")
            if content and verbose:
                print(content, end="", flush=True)  # Print content in real-time
    finally:
        response.close()  # Ensure the response connection is closed

    if verbose:
        print()  # Terminate the streamed line
|
||
|
||
|
||
def test_generate_with_system() -> None:
    """Send a non-streaming generate request that includes a system prompt"""
    url = get_base_url("generate")
    prompt = CONFIG["test_cases"]["generate"]["query"]
    payload = create_generate_request_data(
        prompt, system="你是一个知识渊博的助手", stream=False
    )

    response = make_request(url, payload)

    if OutputControl.is_verbose():
        print("\n=== Generate with system prompt response ===")
    body = response.json()

    # Show only the fields of interest rather than the whole response.
    excerpt = {field: body[field] for field in ("model", "response", "done")}
    print_json_response(excerpt, "Response content")
|
||
|
||
|
||
def test_generate_error_handling() -> None:
    """Test error handling for generate endpoint

    Test scenarios:
    1. Empty prompt
    2. Unknown entry in the options dictionary

    The status code and response body are printed only in verbose mode.
    """
    url = get_base_url("generate")
    verbose = OutputControl.is_verbose()

    # Test empty prompt
    if verbose:
        print("\n=== Testing empty prompt ===")
    data = create_generate_request_data("", stream=False)
    # check_status=False: an error status must not raise in this test.
    response = make_request(url, data, check_status=False)
    if verbose:
        print(f"Status code: {response.status_code}")
    print_json_response(response.json(), "Error message")

    # Test invalid options
    if verbose:
        print("\n=== Testing invalid options ===")
    data = create_generate_request_data(
        CONFIG["test_cases"]["basic"]["query"],
        options={"invalid_option": "value"},
        stream=False,
    )
    response = make_request(url, data, check_status=False)
    if verbose:
        print(f"Status code: {response.status_code}")
    print_json_response(response.json(), "Error message")
|
||
|
||
|
||
def test_generate_concurrent() -> None:
    """Test concurrent generate requests

    Sends five non-streaming generate requests at the same time over one
    aiohttp session and waits for all of them.

    Raises:
        McpError: At least one request failed (non-200 status, an "error" key
            in the response body, or an exception while sending/decoding).
            The message lists every failed request.
    """
    import asyncio
    import aiohttp

    async def send_one(session, prompt: str, request_id: int):
        """POST one generate request and return the decoded JSON body."""
        url = get_base_url("generate")
        data = create_generate_request_data(prompt, stream=False)
        try:
            async with session.post(url, json=data) as response:
                if response.status != 200:
                    raise McpError(
                        ErrorCode.InternalError,
                        f"Request {request_id} failed with status {response.status}",
                    )
                result = await response.json()
        except McpError:
            # Already carries a complete message; do not wrap it a second time.
            raise
        except Exception as e:
            raise McpError(
                ErrorCode.InternalError, f"Request {request_id} failed: {str(e)}"
            ) from e

        if "error" in result:
            raise McpError(
                ErrorCode.InternalError,
                f"Request {request_id} returned error: {result['error']}",
            )
        return result

    async def run_concurrent_requests():
        """Run all requests concurrently; raise McpError if any of them failed."""
        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]

        async with aiohttp.ClientSession() as session:
            tasks = [
                send_one(session, prompt, i + 1) for i, prompt in enumerate(prompts)
            ]
            # return_exceptions=True: collect every outcome instead of
            # stopping at the first failure.
            results = await asyncio.gather(*tasks, return_exceptions=True)

        success_results = []
        error_messages = []
        for request_id, result in enumerate(results, 1):
            if isinstance(result, McpError):
                error_messages.append(str(result))
            elif isinstance(result, BaseException):
                # BaseException also covers cancellation, which is not an
                # Exception subclass and must not be counted as a success.
                error_messages.append(f"Request {request_id} failed: {result!r}")
            else:
                success_results.append((request_id, result))

        if error_messages:
            if OutputControl.is_verbose():
                for message in error_messages:
                    print(f"\n{message}")
                for request_id, result in success_results:
                    print(f"\nRequest {request_id} succeeded:")
                    print_json_response(result)

            error_summary = "\n".join(error_messages)
            raise McpError(
                ErrorCode.InternalError,
                f"Some concurrent requests failed:\n{error_summary}",
            )

        return results

    verbose = OutputControl.is_verbose()
    if verbose:
        print("\n=== Testing concurrent generate requests ===")

    # Run concurrent requests; an McpError propagates to the caller.
    results = asyncio.run(run_concurrent_requests())

    # All requests succeeded: print the results (verbose mode only)
    for i, result in enumerate(results, 1):
        if verbose:
            print(f"\nRequest {i} result:")
        print_json_response(result)
|
||
|
||
|
||
def get_test_cases() -> Dict[str, Callable]:
    """Return the registry of runnable test cases

    Returns:
        A dictionary mapping test names (as accepted by --tests) to test functions
    """
    registry = (
        ("non_stream", test_non_stream_chat),
        ("stream", test_stream_chat),
        ("modes", test_query_modes),
        ("errors", test_error_handling),
        ("stream_errors", test_stream_error_handling),
        ("non_stream_generate", test_non_stream_generate),
        ("stream_generate", test_stream_generate),
        ("generate_with_system", test_generate_with_system),
        ("generate_errors", test_generate_error_handling),
        ("generate_concurrent", test_generate_concurrent),
    )
    return dict(registry)
|
||
|
||
|
||
def create_default_config():
    """Write DEFAULT_CONFIG to config.json unless that file already exists"""
    config_path = Path("config.json")
    if config_path.exists():
        # Never overwrite an existing configuration file.
        return

    with config_path.open("w", encoding="utf-8") as out_file:
        json.dump(DEFAULT_CONFIG, out_file, ensure_ascii=False, indent=2)
    print(f"Default configuration file created: {config_path}")
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse command line arguments

    Returns:
        Namespace with the attributes quiet, ask, init_config, output and tests
    """
    # The epilog shows a complete, valid JSON example that mirrors
    # DEFAULT_CONFIG; JSON has no comment syntax, so the keys are explained
    # below the example instead of inline.
    parser = argparse.ArgumentParser(
        description="LightRAG Ollama Compatibility Interface Testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Configuration file (config.json in the current directory):
{
  "server": {
    "host": "localhost",
    "port": 9621,
    "model": "lightrag:latest",
    "timeout": 300,
    "max_retries": 1,
    "retry_delay": 1
  },
  "test_cases": {
    "basic": {"query": "Test query"},
    "generate": {"query": "Test query"}
  }
}

  server.host, server.port    Server address and port
  server.model                Default model name
  server.timeout              Request timeout in seconds
  server.max_retries          Number of attempts per request
  server.retry_delay          Seconds to wait between attempts
  test_cases.basic.query      Basic query text (overridden by --ask)
  test_cases.generate.query   Query text for the generate tests
""",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Silent mode, only display test result summary",
    )
    parser.add_argument(
        "-a",
        "--ask",
        type=str,
        help="Specify query content, which will override the query settings in the configuration file",
    )
    parser.add_argument(
        "--init-config", action="store_true", help="Create default configuration file"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="Test result output file path, default is not to output to a file",
    )
    parser.add_argument(
        "--tests",
        nargs="+",
        choices=list(get_test_cases().keys()) + ["all"],
        default=["all"],
        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)",
    )
    return parser.parse_args()
|
||
|
||
|
||
if __name__ == "__main__":
    args = parse_args()

    # Handle --init-config first, before --ask can modify CONFIG, so the file
    # written always contains the untouched defaults. SystemExit is raised
    # directly because the exit() builtin is only provided by the site module.
    if args.init_config:
        create_default_config()
        raise SystemExit(0)

    # Set output mode
    OutputControl.set_verbose(not args.quiet)

    # If query content is specified, update the configuration
    if args.ask:
        CONFIG["test_cases"]["basic"]["query"] = args.ask

    test_cases = get_test_cases()

    try:
        if "all" in args.tests:
            # Run all tests except error handling tests
            if OutputControl.is_verbose():
                print("\n【Chat API Tests】")
            run_test(test_non_stream_chat, "Non-streaming Chat Test")
            run_test(test_stream_chat, "Streaming Chat Test")
            run_test(test_query_modes, "Chat Query Mode Test")

            if OutputControl.is_verbose():
                print("\n【Generate API Tests】")
            run_test(test_non_stream_generate, "Non-streaming Generate Test")
            run_test(test_stream_generate, "Streaming Generate Test")
            run_test(test_generate_with_system, "Generate with System Prompt Test")
            run_test(test_generate_concurrent, "Generate Concurrent Test")
        else:
            # Run specified tests
            for test_name in args.tests:
                if OutputControl.is_verbose():
                    print(f"\n【Running Test: {test_name}】")
                run_test(test_cases[test_name], test_name)
    except Exception as e:
        # run_test re-raises after recording the failure, so the first failing
        # test ends the run; the summary below still reports what was executed.
        print(f"\nAn error occurred: {str(e)}")
    finally:
        # Print test statistics
        STATS.print_summary()
        # If an output file path is specified, export the results
        if args.output:
            STATS.export_results(args.output)
|