# LightRAG/test_lightrag_ollama_chat.py
"""
LightRAG Ollama Compatibility Interface Test Script

This script tests LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid)
3. Error handling tests (including streaming and non-streaming scenarios)

All responses use the JSON Lines format, complying with the Ollama API specification.
"""
import requests
import json
import argparse
import time
from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path


class OutputControl:
"""Output control class, manages the verbosity of test output"""
_verbose: bool = False
@classmethod
def set_verbose(cls, verbose: bool) -> None:
cls._verbose = verbose
@classmethod
def is_verbose(cls) -> bool:
return cls._verbose


@dataclass
class TestResult:
"""Test result data class"""
name: str
success: bool
duration: float
error: Optional[str] = None
timestamp: str = ""
def __post_init__(self):
if not self.timestamp:
self.timestamp = datetime.now().isoformat()


class TestStats:
"""Test statistics"""
def __init__(self):
self.results: List[TestResult] = []
self.start_time = datetime.now()
def add_result(self, result: TestResult):
self.results.append(result)
def export_results(self, path: str = "test_results.json"):
"""Export test results to a JSON file
Args:
path: Output file path
"""
results_data = {
"start_time": self.start_time.isoformat(),
"end_time": datetime.now().isoformat(),
"results": [asdict(r) for r in self.results],
"summary": {
"total": len(self.results),
"passed": sum(1 for r in self.results if r.success),
"failed": sum(1 for r in self.results if not r.success),
"total_duration": sum(r.duration for r in self.results),
},
}
with open(path, "w", encoding="utf-8") as f:
json.dump(results_data, f, ensure_ascii=False, indent=2)
print(f"\nTest results saved to: {path}")
def print_summary(self):
total = len(self.results)
passed = sum(1 for r in self.results if r.success)
failed = total - passed
duration = sum(r.duration for r in self.results)
print("\n=== Test Summary ===")
print(f"Start time: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total duration: {duration:.2f} seconds")
print(f"Total tests: {total}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if failed > 0:
print("\nFailed tests:")
for result in self.results:
if not result.success:
print(f"- {result.name}: {result.error}")


DEFAULT_CONFIG = {
"server": {
"host": "localhost",
"port": 9621,
"model": "lightrag:latest",
"timeout": 30,
"max_retries": 3,
"retry_delay": 1,
},
"test_cases": {
"basic": {"query": "唐僧有几个徒弟"},
"generate": {"query": "电视剧西游记导演是谁"},
},
}
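# Note: load_config() below prefers a config.json in the working directory, so the
# defaults above only apply when that file is absent (use --init-config to create one).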


def make_request(
url: str, data: Dict[str, Any], stream: bool = False
) -> requests.Response:
"""Send an HTTP request with retry mechanism
Args:
url: Request URL
data: Request data
stream: Whether to use streaming response
Returns:
requests.Response: Response object
Raises:
requests.exceptions.RequestException: Request failed after all retries
"""
server_config = CONFIG["server"]
max_retries = server_config["max_retries"]
retry_delay = server_config["retry_delay"]
timeout = server_config["timeout"]
for attempt in range(max_retries):
try:
response = requests.post(url, json=data, stream=stream, timeout=timeout)
return response
except requests.exceptions.RequestException as e:
if attempt == max_retries - 1: # Last retry
raise
print(f"\nRequest failed, retrying in {retry_delay} seconds: {str(e)}")
time.sleep(retry_delay)


def load_config() -> Dict[str, Any]:
"""Load configuration file
    First tries to load config.json from the current directory;
    if it does not exist, the default configuration is used.
Returns:
Configuration dictionary
"""
config_path = Path("config.json")
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
return DEFAULT_CONFIG


def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
"""Format and print JSON response data
Args:
data: Data dictionary to print
title: Title to print
indent: Number of spaces for JSON indentation
"""
if OutputControl.is_verbose():
if title:
print(f"\n=== {title} ===")
print(json.dumps(data, ensure_ascii=False, indent=indent))


# Global configuration
CONFIG = load_config()


def get_base_url(endpoint: str = "chat") -> str:
"""Return the base URL for specified endpoint
Args:
endpoint: API endpoint name (chat or generate)
Returns:
Complete URL for the endpoint
"""
server = CONFIG["server"]
return f"http://{server['host']}:{server['port']}/api/{endpoint}"


def create_chat_request_data(
content: str, stream: bool = False, model: str = None
) -> Dict[str, Any]:
"""Create chat request data
Args:
content: User message content
stream: Whether to use streaming response
model: Model name
Returns:
Dictionary containing complete chat request data
"""
return {
"model": model or CONFIG["server"]["model"],
"messages": [{"role": "user", "content": content}],
"stream": stream,
}
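# Illustrative payload (assuming the default model name from DEFAULT_CONFIG):
# create_chat_request_data("hello", stream=True) ->
#   {"model": "lightrag:latest", "messages": [{"role": "user", "content": "hello"}], "stream": True}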


def create_generate_request_data(
prompt: str,
system: str = None,
stream: bool = False,
model: str = None,
options: Dict[str, Any] = None,
) -> Dict[str, Any]:
"""Create generate request data
Args:
prompt: Generation prompt
system: System prompt
stream: Whether to use streaming response
model: Model name
options: Additional options
Returns:
Dictionary containing complete generate request data
"""
data = {
"model": model or CONFIG["server"]["model"],
"prompt": prompt,
"stream": stream,
}
if system:
data["system"] = system
if options:
data["options"] = options
return data
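# Illustrative payload (again assuming the default model name):
# create_generate_request_data("hello", system="answer briefly") ->
#   {"model": "lightrag:latest", "prompt": "hello", "stream": False, "system": "answer briefly"}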


# Global test statistics
STATS = TestStats()


def run_test(func: Callable, name: str) -> None:
"""Run a test and record the results
Args:
func: Test function
name: Test name
"""
start_time = time.time()
try:
func()
duration = time.time() - start_time
STATS.add_result(TestResult(name, True, duration))
except Exception as e:
duration = time.time() - start_time
STATS.add_result(TestResult(name, False, duration, str(e)))
raise


def test_non_stream_chat() -> None:
"""Test non-streaming call to /api/chat endpoint"""
url = get_base_url()
data = create_chat_request_data(
CONFIG["test_cases"]["basic"]["query"], stream=False
)
    # Send request
response = make_request(url, data)
    # Print response
if OutputControl.is_verbose():
print("\n=== Non-streaming call response ===")
response_json = response.json()
    # Print response content
print_json_response(
{"model": response_json["model"], "message": response_json["message"]},
"Response content",
)


def test_stream_chat() -> None:
"""Test streaming call to /api/chat endpoint
    Use the JSON Lines format to process streaming responses; each line is a complete JSON object.
Response format:
{
"model": "lightrag:latest",
"created_at": "2024-01-15T00:00:00Z",
"message": {
"role": "assistant",
"content": "Partial response content",
"images": null
},
"done": false
}
The last message will contain performance statistics, with done set to true.
"""
url = get_base_url()
data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
    # Send request and get streaming response
response = make_request(url, data, stream=True)
if OutputControl.is_verbose():
print("\n=== Streaming call response ===")
output_buffer = []
try:
for line in response.iter_lines():
if line: # Skip empty lines
try:
# Decode and parse JSON
data = json.loads(line.decode("utf-8"))
if data.get("done", True): # If it's the completion marker
if (
"total_duration" in data
): # Final performance statistics message
# print_json_response(data, "Performance statistics")
break
else: # Normal content message
message = data.get("message", {})
content = message.get("content", "")
if content: # Only collect non-empty content
output_buffer.append(content)
print(
content, end="", flush=True
) # Print content in real-time
except json.JSONDecodeError:
print("Error decoding JSON from response line")
finally:
response.close() # Ensure the response connection is closed
    # Print a newline
print()


def test_query_modes() -> None:
"""Test different query mode prefixes
Supported query modes:
- /local: Local retrieval mode, searches only in highly relevant documents
- /global: Global retrieval mode, searches across all documents
- /naive: Naive mode, does not use any optimization strategies
- /hybrid: Hybrid mode (default), combines multiple strategies
- /mix: Mix mode
Each mode will return responses in the same format, but with different retrieval strategies.
"""
url = get_base_url()
modes = ["local", "global", "naive", "hybrid", "mix"]
for mode in modes:
if OutputControl.is_verbose():
print(f"\n=== Testing /{mode} mode ===")
data = create_chat_request_data(
f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
)
        # Send request
response = make_request(url, data)
response_json = response.json()
        # Print response content
print_json_response(
{"model": response_json["model"], "message": response_json["message"]}
)


def create_error_test_data(error_type: str) -> Dict[str, Any]:
"""Create request data for error testing
Args:
error_type: Error type, supported:
- empty_messages: Empty message list
- invalid_role: Invalid role field
- missing_content: Missing content field
Returns:
Request dictionary containing error data
"""
error_data = {
"empty_messages": {"model": "lightrag:latest", "messages": [], "stream": True},
"invalid_role": {
"model": "lightrag:latest",
"messages": [{"invalid_role": "user", "content": "Test message"}],
"stream": True,
},
"missing_content": {
"model": "lightrag:latest",
"messages": [{"role": "user"}],
"stream": True,
},
}
return error_data.get(error_type, error_data["empty_messages"])


def test_stream_error_handling() -> None:
"""Test error handling for streaming responses
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
Error responses should be returned immediately without establishing a streaming connection.
The status code should be 4xx, and detailed error information should be returned.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== Testing streaming response error handling ===")
    # Test empty message list
if OutputControl.is_verbose():
print("\n--- Testing empty message list (streaming) ---")
data = create_error_test_data("empty_messages")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()
    # Test invalid role field
if OutputControl.is_verbose():
print("\n--- Testing invalid role field (streaming) ---")
data = create_error_test_data("invalid_role")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()
# Test missing content field
if OutputControl.is_verbose():
print("\n--- Testing missing content field (streaming) ---")
data = create_error_test_data("missing_content")
response = make_request(url, data, stream=True)
print(f"Status code: {response.status_code}")
if response.status_code != 200:
print_json_response(response.json(), "Error message")
response.close()


def test_error_handling() -> None:
"""Test error handling for non-streaming responses
Test scenarios:
1. Empty message list
2. Message format error (missing required fields)
Error response format:
{
"detail": "Error description"
}
All errors should return appropriate HTTP status codes and clear error messages.
"""
url = get_base_url()
if OutputControl.is_verbose():
print("\n=== Testing error handling ===")
    # Test empty message list
if OutputControl.is_verbose():
print("\n--- Testing empty message list ---")
data = create_error_test_data("empty_messages")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
    # Test invalid role field
if OutputControl.is_verbose():
print("\n--- Testing invalid role field ---")
data = create_error_test_data("invalid_role")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
    # Test missing content field
if OutputControl.is_verbose():
print("\n--- Testing missing content field ---")
data = create_error_test_data("missing_content")
data["stream"] = False # Change to non-streaming mode
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")


def test_non_stream_generate() -> None:
"""Test non-streaming call to /api/generate endpoint"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"], stream=False
)
# Send request
response = make_request(url, data)
# Print response
if OutputControl.is_verbose():
print("\n=== Non-streaming generate response ===")
response_json = response.json()
# Print response content
print_json_response(
{
"model": response_json["model"],
"response": response_json["response"],
"done": response_json["done"],
},
"Response content",
)


def test_stream_generate() -> None:
"""Test streaming call to /api/generate endpoint"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"], stream=True
)
# Send request and get streaming response
response = make_request(url, data, stream=True)
if OutputControl.is_verbose():
print("\n=== Streaming generate response ===")
output_buffer = []
try:
for line in response.iter_lines():
if line: # Skip empty lines
try:
# Decode and parse JSON
data = json.loads(line.decode("utf-8"))
if data.get("done", True): # If it's the completion marker
if (
"total_duration" in data
): # Final performance statistics message
break
else: # Normal content message
content = data.get("response", "")
if content: # Only collect non-empty content
output_buffer.append(content)
print(
content, end="", flush=True
) # Print content in real-time
except json.JSONDecodeError:
print("Error decoding JSON from response line")
finally:
response.close() # Ensure the response connection is closed
# Print a newline
print()


def test_generate_with_system() -> None:
"""Test generate with system prompt"""
url = get_base_url("generate")
data = create_generate_request_data(
CONFIG["test_cases"]["generate"]["query"],
system="你是一个知识渊博的助手",
stream=False,
)
# Send request
response = make_request(url, data)
# Print response
if OutputControl.is_verbose():
print("\n=== Generate with system prompt response ===")
response_json = response.json()
# Print response content
print_json_response(
{
"model": response_json["model"],
"response": response_json["response"],
"done": response_json["done"],
},
"Response content",
)


def test_generate_error_handling() -> None:
"""Test error handling for generate endpoint"""
url = get_base_url("generate")
# Test empty prompt
if OutputControl.is_verbose():
print("\n=== Testing empty prompt ===")
data = create_generate_request_data("", stream=False)
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")
# Test invalid options
if OutputControl.is_verbose():
print("\n=== Testing invalid options ===")
data = create_generate_request_data(
CONFIG["test_cases"]["basic"]["query"],
options={"invalid_option": "value"},
stream=False,
)
response = make_request(url, data)
print(f"Status code: {response.status_code}")
print_json_response(response.json(), "Error message")


def test_generate_concurrent() -> None:
"""Test concurrent generate requests"""
import asyncio
import aiohttp
from contextlib import asynccontextmanager
@asynccontextmanager
async def get_session():
async with aiohttp.ClientSession() as session:
yield session
async def make_request(session, prompt: str):
url = get_base_url("generate")
data = create_generate_request_data(prompt, stream=False)
try:
async with session.post(url, json=data) as response:
return await response.json()
except Exception as e:
return {"error": str(e)}
async def run_concurrent_requests():
prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
async with get_session() as session:
tasks = [make_request(session, prompt) for prompt in prompts]
results = await asyncio.gather(*tasks)
return results
if OutputControl.is_verbose():
print("\n=== Testing concurrent generate requests ===")
# Run concurrent requests
results = asyncio.run(run_concurrent_requests())
# Print results
for i, result in enumerate(results, 1):
print(f"\nRequest {i} result:")
print_json_response(result)


def get_test_cases() -> Dict[str, Callable]:
"""Get all available test cases
Returns:
A dictionary mapping test names to test functions
"""
return {
"non_stream": test_non_stream_chat,
"stream": test_stream_chat,
"modes": test_query_modes,
"errors": test_error_handling,
"stream_errors": test_stream_error_handling,
"non_stream_generate": test_non_stream_generate,
"stream_generate": test_stream_generate,
"generate_with_system": test_generate_with_system,
"generate_errors": test_generate_error_handling,
"generate_concurrent": test_generate_concurrent,
}
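# These keys double as the valid values for the --tests command-line option (plus "all").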


def create_default_config():
"""Create a default configuration file"""
config_path = Path("config.json")
if not config_path.exists():
with open(config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
print(f"Default configuration file created: {config_path}")


def parse_args() -> argparse.Namespace:
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="LightRAG Ollama Compatibility Interface Testing",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Configuration file (config.json):
{
"server": {
"host": "localhost", # Server address
"port": 9621, # Server port
"model": "lightrag:latest" # Default model name
},
"test_cases": {
"basic": {
"query": "Test query", # Basic query text
"stream_query": "Stream query" # Stream query text
}
}
}
""",
)
parser.add_argument(
"-q",
"--quiet",
action="store_true",
help="Silent mode, only display test result summary",
)
parser.add_argument(
"-a",
"--ask",
type=str,
help="Specify query content, which will override the query settings in the configuration file",
)
parser.add_argument(
"--init-config", action="store_true", help="Create default configuration file"
)
parser.add_argument(
"--output",
type=str,
default="",
help="Test result output file path, default is not to output to a file",
)
parser.add_argument(
"--tests",
nargs="+",
choices=list(get_test_cases().keys()) + ["all"],
default=["all"],
help="Test cases to run, options: %(choices)s. Use 'all' to run all tests",
)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
    # Set output mode
OutputControl.set_verbose(not args.quiet)
# If query content is specified, update the configuration
if args.ask:
CONFIG["test_cases"]["basic"]["query"] = args.ask
    # If specified to create a configuration file
if args.init_config:
create_default_config()
exit(0)
test_cases = get_test_cases()
try:
if "all" in args.tests:
# Run all tests
if OutputControl.is_verbose():
print("\n【Chat API Tests】")
run_test(test_non_stream_chat, "Non-streaming Chat Test")
run_test(test_stream_chat, "Streaming Chat Test")
run_test(test_query_modes, "Chat Query Mode Test")
run_test(test_error_handling, "Chat Error Handling Test")
run_test(test_stream_error_handling, "Chat Streaming Error Test")
if OutputControl.is_verbose():
print("\n【Generate API Tests】")
run_test(test_non_stream_generate, "Non-streaming Generate Test")
run_test(test_stream_generate, "Streaming Generate Test")
run_test(test_generate_with_system, "Generate with System Prompt Test")
run_test(test_generate_error_handling, "Generate Error Handling Test")
run_test(test_generate_concurrent, "Generate Concurrent Test")
else:
# Run specified tests
for test_name in args.tests:
if OutputControl.is_verbose():
print(f"\n【Running Test: {test_name}")
run_test(test_cases[test_name], test_name)
except Exception as e:
print(f"\nAn error occurred: {str(e)}")
finally:
# Print test statistics
STATS.print_summary()
# If an output file path is specified, export the results
if args.output:
STATS.export_results(args.output)