2025-01-15 22:15:46 +08:00
|
|
|
|
"""
|
|
|
|
|
LightRAG Ollama 兼容接口测试脚本
|
|
|
|
|
|
|
|
|
|
这个脚本测试 LightRAG 的 Ollama 兼容接口,包括:
|
|
|
|
|
1. 基本功能测试(流式和非流式响应)
|
|
|
|
|
2. 查询模式测试(local、global、naive、hybrid)
|
|
|
|
|
3. 错误处理测试(包括流式和非流式场景)
|
|
|
|
|
|
|
|
|
|
所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
|
|
|
|
|
"""
|
|
|
|
|
|
2025-01-15 20:18:17 +08:00
|
|
|
|
import requests
|
|
|
|
|
import json
|
2025-01-15 22:15:46 +08:00
|
|
|
|
import argparse
|
|
|
|
|
import time
|
2025-01-15 22:39:41 +08:00
|
|
|
|
from typing import Dict, Any, Optional, List, Callable
|
2025-01-15 22:15:46 +08:00
|
|
|
|
from dataclasses import dataclass, asdict
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from pathlib import Path
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
class OutputControl:
|
|
|
|
|
"""输出控制类,管理测试输出的详细程度"""
|
|
|
|
|
_verbose: bool = False
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def set_verbose(cls, verbose: bool) -> None:
|
|
|
|
|
"""设置输出详细程度
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
verbose: True 为详细模式,False 为静默模式
|
|
|
|
|
"""
|
|
|
|
|
cls._verbose = verbose
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def is_verbose(cls) -> bool:
|
|
|
|
|
"""获取当前输出模式
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
当前是否为详细模式
|
|
|
|
|
"""
|
|
|
|
|
return cls._verbose
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class TestResult:
|
|
|
|
|
"""测试结果数据类"""
|
|
|
|
|
name: str
|
|
|
|
|
success: bool
|
|
|
|
|
duration: float
|
|
|
|
|
error: Optional[str] = None
|
|
|
|
|
timestamp: str = ""
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
"""初始化后设置时间戳"""
|
|
|
|
|
if not self.timestamp:
|
|
|
|
|
self.timestamp = datetime.now().isoformat()
|
|
|
|
|
|
|
|
|
|
class TestStats:
|
|
|
|
|
"""测试统计信息"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.results: List[TestResult] = []
|
|
|
|
|
self.start_time = datetime.now()
|
|
|
|
|
|
|
|
|
|
def add_result(self, result: TestResult):
|
|
|
|
|
"""添加测试结果"""
|
|
|
|
|
self.results.append(result)
|
|
|
|
|
|
|
|
|
|
def export_results(self, path: str = "test_results.json"):
|
|
|
|
|
"""导出测试结果到 JSON 文件
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
path: 输出文件路径
|
|
|
|
|
"""
|
|
|
|
|
results_data = {
|
|
|
|
|
"start_time": self.start_time.isoformat(),
|
|
|
|
|
"end_time": datetime.now().isoformat(),
|
|
|
|
|
"results": [asdict(r) for r in self.results],
|
|
|
|
|
"summary": {
|
|
|
|
|
"total": len(self.results),
|
|
|
|
|
"passed": sum(1 for r in self.results if r.success),
|
|
|
|
|
"failed": sum(1 for r in self.results if not r.success),
|
|
|
|
|
"total_duration": sum(r.duration for r in self.results)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
with open(path, "w", encoding="utf-8") as f:
|
|
|
|
|
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
|
|
|
|
print(f"\n测试结果已保存到: {path}")
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
def print_summary(self):
|
|
|
|
|
"""打印测试统计摘要"""
|
|
|
|
|
total = len(self.results)
|
|
|
|
|
passed = sum(1 for r in self.results if r.success)
|
|
|
|
|
failed = total - passed
|
|
|
|
|
duration = sum(r.duration for r in self.results)
|
|
|
|
|
|
|
|
|
|
print("\n=== 测试结果摘要 ===")
|
|
|
|
|
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
|
|
|
print(f"总用时: {duration:.2f}秒")
|
|
|
|
|
print(f"总计: {total} 个测试")
|
|
|
|
|
print(f"通过: {passed} 个")
|
|
|
|
|
print(f"失败: {failed} 个")
|
|
|
|
|
|
|
|
|
|
if failed > 0:
|
|
|
|
|
print("\n失败的测试:")
|
|
|
|
|
for result in self.results:
|
|
|
|
|
if not result.success:
|
|
|
|
|
print(f"- {result.name}: {result.error}")
|
|
|
|
|
|
|
|
|
|
# 默认配置
|
|
|
|
|
DEFAULT_CONFIG = {
|
|
|
|
|
"server": {
|
|
|
|
|
"host": "localhost",
|
|
|
|
|
"port": 9621,
|
2025-01-15 20:18:17 +08:00
|
|
|
|
"model": "lightrag:latest",
|
2025-01-15 22:15:46 +08:00
|
|
|
|
"timeout": 30, # 请求超时时间(秒)
|
|
|
|
|
"max_retries": 3, # 最大重试次数
|
|
|
|
|
"retry_delay": 1 # 重试间隔(秒)
|
|
|
|
|
},
|
|
|
|
|
"test_cases": {
|
|
|
|
|
"basic": {
|
2025-01-16 01:11:59 +08:00
|
|
|
|
"query": "唐僧有几个徒弟"
|
2025-01-15 22:15:46 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
|
|
|
|
"""发送 HTTP 请求,支持重试机制
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
url: 请求 URL
|
|
|
|
|
data: 请求数据
|
|
|
|
|
stream: 是否使用流式响应
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
requests.Response 对象
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
requests.exceptions.RequestException: 请求失败且重试次数用完
|
|
|
|
|
"""
|
|
|
|
|
server_config = CONFIG["server"]
|
|
|
|
|
max_retries = server_config["max_retries"]
|
|
|
|
|
retry_delay = server_config["retry_delay"]
|
|
|
|
|
timeout = server_config["timeout"]
|
|
|
|
|
|
|
|
|
|
for attempt in range(max_retries):
|
|
|
|
|
try:
|
|
|
|
|
response = requests.post(
|
|
|
|
|
url,
|
|
|
|
|
json=data,
|
|
|
|
|
stream=stream,
|
|
|
|
|
timeout=timeout
|
|
|
|
|
)
|
|
|
|
|
return response
|
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
|
|
|
if attempt == max_retries - 1: # 最后一次重试
|
|
|
|
|
raise
|
|
|
|
|
print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
|
|
|
|
|
time.sleep(retry_delay)
|
|
|
|
|
|
|
|
|
|
def load_config() -> Dict[str, Any]:
|
|
|
|
|
"""加载配置文件
|
|
|
|
|
|
|
|
|
|
首先尝试从当前目录的 config.json 加载,
|
|
|
|
|
如果不存在则使用默认配置
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
配置字典
|
|
|
|
|
"""
|
|
|
|
|
config_path = Path("config.json")
|
|
|
|
|
if config_path.exists():
|
|
|
|
|
with open(config_path, "r", encoding="utf-8") as f:
|
|
|
|
|
return json.load(f)
|
|
|
|
|
return DEFAULT_CONFIG
|
|
|
|
|
|
|
|
|
|
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
|
|
|
|
"""格式化打印 JSON 响应数据
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data: 要打印的数据字典
|
|
|
|
|
title: 打印的标题
|
|
|
|
|
indent: JSON 缩进空格数
|
|
|
|
|
"""
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
if title:
|
|
|
|
|
print(f"\n=== {title} ===")
|
|
|
|
|
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
|
|
|
|
|
|
|
|
|
# 全局配置
|
|
|
|
|
CONFIG = load_config()
|
|
|
|
|
|
|
|
|
|
def get_base_url() -> str:
|
|
|
|
|
"""返回基础 URL"""
|
|
|
|
|
server = CONFIG["server"]
|
|
|
|
|
return f"http://{server['host']}:{server['port']}/api/chat"
|
|
|
|
|
|
|
|
|
|
def create_request_data(
|
|
|
|
|
content: str,
|
|
|
|
|
stream: bool = False,
|
|
|
|
|
model: str = None
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
"""创建基本的请求数据
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
content: 用户消息内容
|
|
|
|
|
stream: 是否使用流式响应
|
|
|
|
|
model: 模型名称
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
包含完整请求数据的字典
|
|
|
|
|
"""
|
|
|
|
|
return {
|
|
|
|
|
"model": model or CONFIG["server"]["model"],
|
2025-01-15 20:18:17 +08:00
|
|
|
|
"messages": [
|
|
|
|
|
{
|
|
|
|
|
"role": "user",
|
2025-01-15 22:15:46 +08:00
|
|
|
|
"content": content
|
2025-01-15 20:18:17 +08:00
|
|
|
|
}
|
|
|
|
|
],
|
2025-01-15 22:15:46 +08:00
|
|
|
|
"stream": stream
|
2025-01-15 20:18:17 +08:00
|
|
|
|
}
|
2025-01-15 22:15:46 +08:00
|
|
|
|
|
|
|
|
|
# 全局测试统计
|
|
|
|
|
STATS = TestStats()
|
|
|
|
|
|
|
|
|
|
def run_test(func: Callable, name: str) -> None:
|
|
|
|
|
"""运行测试并记录结果
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func: 测试函数
|
|
|
|
|
name: 测试名称
|
|
|
|
|
"""
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
try:
|
|
|
|
|
func()
|
|
|
|
|
duration = time.time() - start_time
|
|
|
|
|
STATS.add_result(TestResult(name, True, duration))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
duration = time.time() - start_time
|
|
|
|
|
STATS.add_result(TestResult(name, False, duration, str(e)))
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def test_non_stream_chat():
|
|
|
|
|
"""测试非流式调用 /api/chat 接口"""
|
|
|
|
|
url = get_base_url()
|
|
|
|
|
data = create_request_data(
|
|
|
|
|
CONFIG["test_cases"]["basic"]["query"],
|
|
|
|
|
stream=False
|
|
|
|
|
)
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
|
|
|
|
# 发送请求
|
2025-01-15 22:15:46 +08:00
|
|
|
|
response = make_request(url, data)
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
|
|
|
|
# 打印响应
|
2025-01-15 22:15:46 +08:00
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n=== 非流式调用响应 ===")
|
2025-01-15 21:15:12 +08:00
|
|
|
|
response_json = response.json()
|
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
# 打印响应内容
|
|
|
|
|
print_json_response({
|
2025-01-15 21:15:12 +08:00
|
|
|
|
"model": response_json["model"],
|
|
|
|
|
"message": response_json["message"]
|
2025-01-15 22:15:46 +08:00
|
|
|
|
}, "响应内容")
|
2025-01-15 21:15:12 +08:00
|
|
|
|
|
2025-01-16 01:16:53 +08:00
|
|
|
|
# # 打印性能统计
|
|
|
|
|
# print_json_response({
|
|
|
|
|
# "total_duration": response_json["total_duration"],
|
|
|
|
|
# "load_duration": response_json["load_duration"],
|
|
|
|
|
# "prompt_eval_count": response_json["prompt_eval_count"],
|
|
|
|
|
# "prompt_eval_duration": response_json["prompt_eval_duration"],
|
|
|
|
|
# "eval_count": response_json["eval_count"],
|
|
|
|
|
# "eval_duration": response_json["eval_duration"]
|
|
|
|
|
# }, "性能统计")
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
|
|
|
|
def test_stream_chat():
|
2025-01-15 22:15:46 +08:00
|
|
|
|
"""测试流式调用 /api/chat 接口
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
|
|
|
|
响应格式:
|
|
|
|
|
{
|
2025-01-15 20:18:17 +08:00
|
|
|
|
"model": "lightrag:latest",
|
2025-01-15 22:15:46 +08:00
|
|
|
|
"created_at": "2024-01-15T00:00:00Z",
|
|
|
|
|
"message": {
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"content": "部分响应内容",
|
|
|
|
|
"images": null
|
|
|
|
|
},
|
|
|
|
|
"done": false
|
2025-01-15 20:18:17 +08:00
|
|
|
|
}
|
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
最后一条消息会包含性能统计信息,done 为 true。
|
|
|
|
|
"""
|
|
|
|
|
url = get_base_url()
|
|
|
|
|
data = create_request_data(
|
2025-01-16 01:11:59 +08:00
|
|
|
|
CONFIG["test_cases"]["basic"]["query"],
|
2025-01-15 22:15:46 +08:00
|
|
|
|
stream=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 发送请求并获取流式响应
|
|
|
|
|
response = make_request(url, data, stream=True)
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n=== 流式调用响应 ===")
|
2025-01-15 20:18:17 +08:00
|
|
|
|
output_buffer = []
|
|
|
|
|
try:
|
2025-01-15 22:15:46 +08:00
|
|
|
|
for line in response.iter_lines():
|
|
|
|
|
if line: # 跳过空行
|
|
|
|
|
try:
|
|
|
|
|
# 解码并解析 JSON
|
|
|
|
|
data = json.loads(line.decode('utf-8'))
|
|
|
|
|
if data.get("done", True): # 如果是完成标记
|
|
|
|
|
if "total_duration" in data: # 最终的性能统计消息
|
2025-01-16 01:16:53 +08:00
|
|
|
|
# print_json_response(data, "性能统计")
|
2025-01-15 22:15:46 +08:00
|
|
|
|
break
|
|
|
|
|
else: # 正常的内容消息
|
|
|
|
|
message = data.get("message", {})
|
|
|
|
|
content = message.get("content", "")
|
|
|
|
|
if content: # 只收集非空内容
|
|
|
|
|
output_buffer.append(content)
|
|
|
|
|
print(content, end="", flush=True) # 实时打印内容
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
print("Error decoding JSON from response line")
|
2025-01-15 20:18:17 +08:00
|
|
|
|
finally:
|
|
|
|
|
response.close() # 确保关闭响应连接
|
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
# 打印一个换行
|
|
|
|
|
print()
|
|
|
|
|
|
|
|
|
|
def test_query_modes():
|
|
|
|
|
"""测试不同的查询模式前缀
|
|
|
|
|
|
|
|
|
|
支持的查询模式:
|
|
|
|
|
- /local: 本地检索模式,只在相关度高的文档中搜索
|
|
|
|
|
- /global: 全局检索模式,在所有文档中搜索
|
|
|
|
|
- /naive: 朴素模式,不使用任何优化策略
|
|
|
|
|
- /hybrid: 混合模式(默认),结合多种策略
|
|
|
|
|
|
|
|
|
|
每个模式都会返回相同格式的响应,但检索策略不同。
|
|
|
|
|
"""
|
|
|
|
|
url = get_base_url()
|
2025-01-17 11:04:36 +08:00
|
|
|
|
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
2025-01-15 22:15:46 +08:00
|
|
|
|
|
|
|
|
|
for mode in modes:
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print(f"\n=== 测试 /{mode} 模式 ===")
|
|
|
|
|
data = create_request_data(
|
2025-01-16 01:11:59 +08:00
|
|
|
|
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
2025-01-15 22:15:46 +08:00
|
|
|
|
stream=False
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 发送请求
|
|
|
|
|
response = make_request(url, data)
|
|
|
|
|
response_json = response.json()
|
|
|
|
|
|
|
|
|
|
# 打印响应内容
|
|
|
|
|
print_json_response({
|
|
|
|
|
"model": response_json["model"],
|
|
|
|
|
"message": response_json["message"]
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|
|
|
|
"""创建用于错误测试的请求数据
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
error_type: 错误类型,支持:
|
|
|
|
|
- empty_messages: 空消息列表
|
|
|
|
|
- invalid_role: 无效的角色字段
|
|
|
|
|
- missing_content: 缺少内容字段
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
包含错误数据的请求字典
|
|
|
|
|
"""
|
|
|
|
|
error_data = {
|
|
|
|
|
"empty_messages": {
|
|
|
|
|
"model": "lightrag:latest",
|
|
|
|
|
"messages": [],
|
|
|
|
|
"stream": True
|
|
|
|
|
},
|
|
|
|
|
"invalid_role": {
|
|
|
|
|
"model": "lightrag:latest",
|
|
|
|
|
"messages": [
|
|
|
|
|
{
|
|
|
|
|
"invalid_role": "user",
|
|
|
|
|
"content": "测试消息"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"stream": True
|
|
|
|
|
},
|
|
|
|
|
"missing_content": {
|
|
|
|
|
"model": "lightrag:latest",
|
|
|
|
|
"messages": [
|
|
|
|
|
{
|
|
|
|
|
"role": "user"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"stream": True
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return error_data.get(error_type, error_data["empty_messages"])
|
|
|
|
|
|
|
|
|
|
def test_stream_error_handling():
|
|
|
|
|
"""测试流式响应的错误处理
|
|
|
|
|
|
|
|
|
|
测试场景:
|
|
|
|
|
1. 空消息列表
|
|
|
|
|
2. 消息格式错误(缺少必需字段)
|
|
|
|
|
|
|
|
|
|
错误响应会立即返回,不会建立流式连接。
|
|
|
|
|
状态码应该是 4xx,并返回详细的错误信息。
|
|
|
|
|
"""
|
|
|
|
|
url = get_base_url()
|
|
|
|
|
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n=== 测试流式响应错误处理 ===")
|
|
|
|
|
|
|
|
|
|
# 测试空消息列表
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试空消息列表(流式)---")
|
|
|
|
|
data = create_error_test_data("empty_messages")
|
|
|
|
|
response = make_request(url, data, stream=True)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
response.close()
|
|
|
|
|
|
|
|
|
|
# 测试无效角色字段
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试无效角色字段(流式)---")
|
|
|
|
|
data = create_error_test_data("invalid_role")
|
|
|
|
|
response = make_request(url, data, stream=True)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
response.close()
|
|
|
|
|
|
|
|
|
|
# 测试缺少内容字段
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试缺少内容字段(流式)---")
|
|
|
|
|
data = create_error_test_data("missing_content")
|
|
|
|
|
response = make_request(url, data, stream=True)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
if response.status_code != 200:
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
response.close()
|
|
|
|
|
|
|
|
|
|
def test_error_handling():
|
|
|
|
|
"""测试非流式响应的错误处理
|
|
|
|
|
|
|
|
|
|
测试场景:
|
|
|
|
|
1. 空消息列表
|
|
|
|
|
2. 消息格式错误(缺少必需字段)
|
|
|
|
|
|
|
|
|
|
错误响应格式:
|
|
|
|
|
{
|
|
|
|
|
"detail": "错误描述"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
|
|
|
|
"""
|
|
|
|
|
url = get_base_url()
|
|
|
|
|
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n=== 测试错误处理 ===")
|
|
|
|
|
|
|
|
|
|
# 测试空消息列表
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试空消息列表 ---")
|
|
|
|
|
data = create_error_test_data("empty_messages")
|
|
|
|
|
data["stream"] = False # 修改为非流式模式
|
|
|
|
|
response = make_request(url, data)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
|
|
|
|
|
# 测试无效角色字段
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试无效角色字段 ---")
|
|
|
|
|
data = create_error_test_data("invalid_role")
|
|
|
|
|
data["stream"] = False # 修改为非流式模式
|
|
|
|
|
response = make_request(url, data)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
|
|
|
|
|
# 测试缺少内容字段
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n--- 测试缺少内容字段 ---")
|
|
|
|
|
data = create_error_test_data("missing_content")
|
|
|
|
|
data["stream"] = False # 修改为非流式模式
|
|
|
|
|
response = make_request(url, data)
|
|
|
|
|
print(f"状态码: {response.status_code}")
|
|
|
|
|
print_json_response(response.json(), "错误信息")
|
|
|
|
|
|
|
|
|
|
def get_test_cases() -> Dict[str, Callable]:
|
|
|
|
|
"""获取所有可用的测试用例
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
测试名称到测试函数的映射字典
|
|
|
|
|
"""
|
|
|
|
|
return {
|
|
|
|
|
"non_stream": test_non_stream_chat,
|
|
|
|
|
"stream": test_stream_chat,
|
|
|
|
|
"modes": test_query_modes,
|
|
|
|
|
"errors": test_error_handling,
|
|
|
|
|
"stream_errors": test_stream_error_handling
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def create_default_config():
|
|
|
|
|
"""创建默认配置文件"""
|
|
|
|
|
config_path = Path("config.json")
|
|
|
|
|
if not config_path.exists():
|
|
|
|
|
with open(config_path, "w", encoding="utf-8") as f:
|
|
|
|
|
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
|
|
|
|
print(f"已创建默认配置文件: {config_path}")
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
|
|
|
"""解析命令行参数"""
|
|
|
|
|
parser = argparse.ArgumentParser(
|
|
|
|
|
description="LightRAG Ollama 兼容接口测试",
|
|
|
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
|
|
|
epilog="""
|
|
|
|
|
配置文件 (config.json):
|
|
|
|
|
{
|
|
|
|
|
"server": {
|
|
|
|
|
"host": "localhost", # 服务器地址
|
|
|
|
|
"port": 9621, # 服务器端口
|
|
|
|
|
"model": "lightrag:latest" # 默认模型名称
|
|
|
|
|
},
|
|
|
|
|
"test_cases": {
|
|
|
|
|
"basic": {
|
|
|
|
|
"query": "测试查询", # 基本查询文本
|
|
|
|
|
"stream_query": "流式查询" # 流式查询文本
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
2025-01-16 01:11:59 +08:00
|
|
|
|
"-q", "--quiet",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="静默模式,只显示测试结果摘要"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-a", "--ask",
|
|
|
|
|
type=str,
|
|
|
|
|
help="指定查询内容,会覆盖配置文件中的查询设置"
|
2025-01-15 22:15:46 +08:00
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--init-config",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="创建默认配置文件"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--output",
|
|
|
|
|
type=str,
|
2025-01-16 20:22:53 +08:00
|
|
|
|
default="",
|
|
|
|
|
help="测试结果输出文件路径,默认不输出到文件"
|
2025-01-15 22:15:46 +08:00
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
2025-01-16 01:11:59 +08:00
|
|
|
|
"--tests",
|
|
|
|
|
nargs="+",
|
|
|
|
|
choices=list(get_test_cases().keys()) + ["all"],
|
|
|
|
|
default=["all"],
|
|
|
|
|
help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
|
2025-01-15 22:15:46 +08:00
|
|
|
|
)
|
|
|
|
|
return parser.parse_args()
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2025-01-15 22:15:46 +08:00
|
|
|
|
args = parse_args()
|
2025-01-15 20:18:17 +08:00
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
# 设置输出模式
|
|
|
|
|
OutputControl.set_verbose(not args.quiet)
|
|
|
|
|
|
2025-01-16 01:11:59 +08:00
|
|
|
|
# 如果指定了查询内容,更新配置
|
|
|
|
|
if args.ask:
|
|
|
|
|
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
|
|
|
|
|
2025-01-15 22:15:46 +08:00
|
|
|
|
# 如果指定了创建配置文件
|
|
|
|
|
if args.init_config:
|
|
|
|
|
create_default_config()
|
|
|
|
|
exit(0)
|
|
|
|
|
|
|
|
|
|
test_cases = get_test_cases()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if "all" in args.tests:
|
|
|
|
|
# 运行所有测试
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n【基本功能测试】")
|
|
|
|
|
run_test(test_non_stream_chat, "非流式调用测试")
|
|
|
|
|
run_test(test_stream_chat, "流式调用测试")
|
|
|
|
|
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n【查询模式测试】")
|
|
|
|
|
run_test(test_query_modes, "查询模式测试")
|
|
|
|
|
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print("\n【错误处理测试】")
|
|
|
|
|
run_test(test_error_handling, "错误处理测试")
|
|
|
|
|
run_test(test_stream_error_handling, "流式错误处理测试")
|
|
|
|
|
else:
|
|
|
|
|
# 运行指定的测试
|
|
|
|
|
for test_name in args.tests:
|
|
|
|
|
if OutputControl.is_verbose():
|
|
|
|
|
print(f"\n【运行测试: {test_name}】")
|
|
|
|
|
run_test(test_cases[test_name], test_name)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"\n发生错误: {str(e)}")
|
|
|
|
|
finally:
|
2025-01-16 20:22:53 +08:00
|
|
|
|
# 打印测试统计
|
2025-01-15 22:15:46 +08:00
|
|
|
|
STATS.print_summary()
|
2025-01-16 20:22:53 +08:00
|
|
|
|
# 如果指定了输出文件路径,则导出结果
|
|
|
|
|
if args.output:
|
|
|
|
|
STATS.export_results(args.output)
|