Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-11-23 13:36:10 +00:00)

Merge pull request #1987 from danielaskdd/llm-optimization
feat: Support turning off thinking on OpenRouter/vLLM

Commit 10bcf1479f
@@ -142,6 +142,8 @@ LightRAG places far higher demands on LLM capability than traditional RAG, because
- **LLM selection**:
  - An LLM with at least 32B parameters is recommended.
  - The context length should be at least 32KB, with 64KB recommended.
  - Reasoning models are not recommended for the document indexing stage.
  - For the query stage, a model more capable than the one used for indexing is recommended, to achieve better query results.
- **Embedding model**:
  - A high-performance embedding model is essential for RAG.
  - Mainstream multilingual embedding models are recommended, for example BAAI/bge-m3 and text-embedding-3-large.
@@ -265,7 +267,7 @@ if __name__ == "__main__":
| **embedding_func_max_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm_model_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
| **llm_model_name** | `str` | LLM model name used for generation | `meta-llama/Llama-3.2-1B-Instruct` |
| **summary_max_tokens** | `int` | Maximum tokens sent to the LLM when generating entity/relation summaries | `32000` (default can be overridden by env var MAX_TOKENS) |
| **summary_max_tokens** | `int` | Maximum tokens sent to the LLM when generating entity/relation summaries | `32000` (configured by env var SUMMARY_MAX_TOKENS) |
| **llm_model_max_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `4` (default can be overridden by env var MAX_ASYNC) |
| **llm_model_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector_db_storage_cls_kwargs** | `dict` | Additional parameters for the vector database, such as the threshold for node and relation retrieval | cosine_better_than_threshold: 0.2 (default can be overridden by env var COSINE_THRESHOLD) |
@@ -141,6 +141,8 @@ LightRAG's demands on the capabilities of Large Language Models (LLMs) are significantly
- **LLM Selection**:
  - It is recommended to use an LLM with at least 32 billion parameters.
  - The context length should be at least 32KB, with 64KB being recommended.
  - It is not recommended to choose reasoning models during the document indexing stage.
  - During the query stage, it is recommended to choose models with stronger capabilities than those used in the indexing stage to achieve better query results.
- **Embedding Model**:
  - A high-performance Embedding model is essential for RAG.
  - We recommend using mainstream multilingual Embedding models, such as `BAAI/bge-m3` and `text-embedding-3-large`.
@@ -272,7 +274,7 @@ A full list of LightRAG init parameters:
| **embedding_func_max_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
| **llm_model_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
| **llm_model_name** | `str` | LLM model name for generation | `meta-llama/Llama-3.2-1B-Instruct` |
| **summary_max_tokens** | `int` | Maximum tokens sent to the LLM to generate entity and relation summaries | `32000` (default can be overridden by env var MAX_TOKENS) |
| **summary_max_tokens** | `int` | Maximum tokens sent to the LLM to generate entity and relation summaries | `32000` (configured by env var SUMMARY_MAX_TOKENS) |
| **llm_model_max_async** | `int` | Maximum number of concurrent asynchronous LLM processes | `4` (default can be overridden by env var MAX_ASYNC) |
| **llm_model_kwargs** | `dict` | Additional parameters for LLM generation | |
| **vector_db_storage_cls_kwargs** | `dict` | Additional parameters for the vector database, such as the threshold for node and relation retrieval | cosine_better_than_threshold: 0.2 (default can be overridden by env var COSINE_THRESHOLD) |
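For reference, a minimal initialization sketch built only from the parameters in the table above; the working directory is a placeholder, the embedding setup is omitted, and the exact constructor signature should be checked against the installed LightRAG version:

```python
import os

from lightrag import LightRAG
from lightrag.llm.openai import gpt_4o_mini_complete  # default llm_model_func named in the table

rag = LightRAG(
    working_dir="./rag_storage",          # placeholder path
    llm_model_func=gpt_4o_mini_complete,  # function used for LLM generation
    llm_model_max_async=4,                # matches the MAX_ASYNC default
    summary_max_tokens=int(os.getenv("SUMMARY_MAX_TOKENS", "32000")),
    vector_db_storage_cls_kwargs={"cosine_better_than_threshold": 0.2},
    # embedding_func=...                  # an embedding function is also required; omitted here
)
```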
@@ -1287,8 +1289,10 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/
        ),
    )
)

# Initialize storage (this will load existing data if available)
await lightrag_instance.initialize_storages()

# Now initialize RAGAnything with the existing LightRAG instance
rag = RAGAnything(
    lightrag=lightrag_instance,  # Pass the existing LightRAG instance
@@ -1317,12 +1321,14 @@ LightRAG now seamlessly integrates with [RAG-Anything](https://github.com/HKUDS/
    )
    # Note: working_dir, llm_model_func, embedding_func, etc. are inherited from lightrag_instance
)

# Query the existing knowledge base
result = await rag.query_with_multimodal(
    "What data has been processed in this LightRAG instance?",
    mode="hybrid"
)
print("Query result:", result)

# Add new multimodal documents to the existing LightRAG instance
await rag.process_document_complete(
    file_path="path/to/new/multimodal_document.pdf",
env.example
@@ -8,6 +8,8 @@ PORT=9621
WEBUI_TITLE='My Graph KB'
WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System"
# WORKERS=2
### gunicorn worker timeout (as default LLM request timeout if LLM_TIMEOUT is not set)
# TIMEOUT=150
# CORS_ORIGINS=http://localhost:3000,http://localhost:8080

### Optional SSL Configuration
@@ -105,7 +107,7 @@ ENABLE_LLM_CACHE_FOR_EXTRACT=true
### Entity and relation summarization configuration
### Number of duplicated entities/edges to trigger LLM re-summary on merge (at least 3 is recommended), and max tokens sent to LLM
# FORCE_LLM_SUMMARY_ON_MERGE=4
# MAX_TOKENS=10000
# SUMMARY_MAX_TOKENS=10000
### Maximum number of entity extraction attempts for ambiguous content
# MAX_GLEANING=1

@@ -125,8 +127,8 @@ MAX_PARALLEL_INSERT=2
### LLM Configuration
### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock
###########################################################
### LLM temperature setting for all LLM bindings (openai, azure_openai, ollama)
# TEMPERATURE=1.0
### LLM request timeout setting for all LLMs (set to TIMEOUT if not specified; 0 means no timeout for Ollama)
# LLM_TIMEOUT=150
### Some models like o1-mini require temperature to be set to 1; some LLMs can fall into output loops with a low temperature

LLM_BINDING=openai
@@ -145,23 +147,33 @@ LLM_BINDING_API_KEY=your_api_key
# LLM_BINDING=openai

### OpenAI Specific Parameters
### Apply frequency penalty to prevent the LLM from generating repetitive or looping outputs
# OPENAI_LLM_FREQUENCY_PENALTY=1.1
### Use the following command to see all supported options for openai and azure_openai
# OPENAI_LLM_TEMPERATURE=1.0
# OPENAI_LLM_REASONING_EFFORT=low
### For models like Qwen3 with fewer than 32B parameters, it is recommended to set the presence penalty to 1.5
# OPENAI_LLM_PRESENCE_PENALTY=1.5
### If the presence penalty still cannot stop the model from generating repetitive or unconstrained output
# OPENAI_LLM_MAX_COMPLETION_TOKENS=16384

### OpenRouter Specific Parameters
# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}'
### Qwen3 Specific Parameters (deployed with vLLM)
# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}'

### Use the following command to see all supported options for OpenAI, azure_openai or OpenRouter
### lightrag-server --llm-binding openai --help

### Ollama Server Specific Parameters
### Timeout in seconds, None for infinite timeout
TIMEOUT=240
### OLLAMA_LLM_NUM_CTX must be larger than MAX_TOTAL_TOKENS + 2000
### OLLAMA_LLM_NUM_CTX must be provided, and should be at least MAX_TOTAL_TOKENS + 2000
OLLAMA_LLM_NUM_CTX=32768
# OLLAMA_LLM_TEMPERATURE=1.0
### Stop sequences for Ollama LLM
# OLLAMA_LLM_STOP='["</s>", "Assistant:", "\n\n"]'
### If OLLAMA_LLM_TEMPERATURE is not specified, the system will default to the value defined by TEMPERATURE
# OLLAMA_LLM_TEMPERATURE=0.85
### Use the following command to see all supported options for Ollama LLM
### lightrag-server --llm-binding ollama --help

### Bedrock Specific Parameters
# BEDROCK_LLM_TEMPERATURE=1.0

####################################################################################
### Embedding Configuration (Should not be changed after the first file is processed)
####################################################################################

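The `OPENAI_LLM_EXTRA_BODY` value above is parsed as a JSON object and forwarded to the OpenAI-compatible endpoint as extra request-body fields. A rough illustration of the idea (a sketch, not LightRAG's actual code path; the model name and endpoint values are placeholders):

```python
import json
import os

from openai import AsyncOpenAI  # OpenAI-compatible client, also usable with OpenRouter or vLLM


async def complete(prompt: str) -> str:
    client = AsyncOpenAI(
        base_url=os.getenv("LLM_BINDING_HOST"),      # e.g. an OpenRouter or vLLM endpoint
        api_key=os.getenv("LLM_BINDING_API_KEY"),
    )
    # e.g. '{"reasoning": {"enabled": false}}' or '{"chat_template_kwargs": {"enable_thinking": false}}'
    extra_body = json.loads(os.getenv("OPENAI_LLM_EXTRA_BODY", "{}"))
    response = await client.chat.completions.create(
        model=os.getenv("LLM_MODEL", "qwen3-32b"),   # placeholder model name
        messages=[{"role": "user", "content": prompt}],
        extra_body=extra_body,                       # merged into the request body by the SDK
    )
    return response.choices[0].message.content
```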
@@ -357,7 +357,7 @@ The API server can be configured in three ways (in descending order of priority):
LightRAG supports binding to a variety of LLM/embedding backends:

* ollama
* openai & openai compatible
* openai (including openai compatible)
* azure_openai
* lollms
* aws_bedrock
@@ -372,7 +372,10 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```

> Please use the openai-compatible binding to access LLMs served by OpenRouter or vLLM. Additional parameters can be passed to OpenRouter or vLLM through the `OPENAI_LLM_EXTRA_BODY` environment variable, for example to turn off reasoning mode or apply other custom controls.

### Entity Extraction Configuration

* ENABLE_LLM_CACHE_FOR_EXTRACT: Enable LLM cache for entity extraction (default: true)

Setting `ENABLE_LLM_CACHE_FOR_EXTRACT` to true in test environments to reduce LLM call costs is common practice.
@@ -478,7 +481,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2

### LLM Configuration (Use a valid host. For local services installed with Docker, you can use host.docker.internal)
TIMEOUT=200
TIMEOUT=150
MAX_ASYNC=4

LLM_BINDING=openai

@@ -360,7 +360,7 @@ Most of the configurations come with default settings; check out the details in
LightRAG supports binding to various LLM/Embedding backends:

* ollama
* openai & openai compatible
* openai (including openai compatible)
* azure_openai
* lollms
* aws_bedrock
@@ -374,6 +374,8 @@ lightrag-server --llm-binding ollama --help
lightrag-server --embedding-binding ollama --help
```

> Please use the OpenAI-compatible method to access LLMs deployed via OpenRouter or vLLM. You can pass additional parameters to OpenRouter or vLLM through the `OPENAI_LLM_EXTRA_BODY` environment variable to disable reasoning mode or apply other custom controls.

### Entity Extraction Configuration
* ENABLE_LLM_CACHE_FOR_EXTRACT: Enable LLM cache for entity extraction (default: true)

@@ -485,7 +487,7 @@ SUMMARY_LANGUAGE=Chinese
MAX_PARALLEL_INSERT=2

### LLM Configuration (Use a valid host. For local services installed with Docker, you can use host.docker.internal)
TIMEOUT=200
TIMEOUT=150
MAX_ASYNC=4

LLM_BINDING=openai

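For example, a hypothetical invocation that points the openai binding at OpenRouter and turns off reasoning through the extra-body mechanism (host, key, and model values are placeholders):

```bash
LLM_BINDING=openai \
LLM_BINDING_HOST=https://openrouter.ai/api/v1 \
LLM_BINDING_API_KEY=your_api_key \
LLM_MODEL=qwen/qwen3-32b \
OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}' \
lightrag-server
```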
@@ -35,7 +35,6 @@ from lightrag.constants import (
    DEFAULT_EMBEDDING_BATCH_NUM,
    DEFAULT_OLLAMA_MODEL_NAME,
    DEFAULT_OLLAMA_MODEL_TAG,
    DEFAULT_TEMPERATURE,
)

# use the .env that is inside the current folder
@@ -123,7 +122,7 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=get_env_value("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
        default=get_env_value("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS, int),
        help=f"Maximum token size (default: from env or {DEFAULT_SUMMARY_MAX_TOKENS})",
    )

@@ -264,14 +263,6 @@ def parse_args() -> argparse.Namespace:
    elif os.environ.get("LLM_BINDING") in ["openai", "azure_openai"]:
        OpenAILLMOptions.add_args(parser)

    # Add global temperature command line argument
    parser.add_argument(
        "--temperature",
        type=float,
        default=get_env_value("TEMPERATURE", DEFAULT_TEMPERATURE, float),
        help="Global temperature setting for LLM (default: from env TEMPERATURE or 0.1)",
    )

    args = parser.parse_args()

    # convert relative path to absolute path
@@ -330,32 +321,6 @@ def parse_args() -> argparse.Namespace:
    )
    args.enable_llm_cache = get_env_value("ENABLE_LLM_CACHE", True, bool)

    # Handle Ollama LLM temperature with priority cascade when llm-binding is ollama
    if args.llm_binding == "ollama":
        # Priority order (highest to lowest):
        # 1. --ollama-llm-temperature command argument
        # 2. OLLAMA_LLM_TEMPERATURE environment variable
        # 3. --temperature command argument
        # 4. TEMPERATURE environment variable

        # Check if --ollama-llm-temperature was explicitly provided in command line
        if "--ollama-llm-temperature" not in sys.argv:
            # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
            args.ollama_llm_temperature = args.temperature

    # Handle OpenAI LLM temperature with priority cascade when llm-binding is openai or azure_openai
    if args.llm_binding in ["openai", "azure_openai"]:
        # Priority order (highest to lowest):
        # 1. --openai-llm-temperature command argument
        # 2. OPENAI_LLM_TEMPERATURE environment variable
        # 3. --temperature command argument
        # 4. TEMPERATURE environment variable

        # Check if --openai-llm-temperature was explicitly provided in command line
        if "--openai-llm-temperature" not in sys.argv:
            # Use args.temperature which handles --temperature command arg and TEMPERATURE env var priority
            args.openai_llm_temperature = args.temperature

    # Select Document loading tool (DOCLING, DEFAULT)
    args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")

@@ -254,6 +254,8 @@ def create_app(args):
    if args.embedding_binding == "jina":
        from lightrag.llm.jina import jina_embed

    llm_timeout = get_env_value("LLM_TIMEOUT", args.timeout, int)

    async def openai_alike_model_complete(
        prompt,
        system_prompt=None,
@@ -267,12 +269,10 @@ def create_app(args):
        if history_messages is None:
            history_messages = []

        # Use OpenAI LLM options if available, otherwise fallback to global temperature
        if args.llm_binding == "openai":
            # Use OpenAI LLM options if available
            openai_options = OpenAILLMOptions.options_dict(args)
            kwargs["timeout"] = llm_timeout
            kwargs.update(openai_options)
        else:
            kwargs["temperature"] = args.temperature

        return await openai_complete_if_cache(
            args.llm_model,
@@ -297,12 +297,10 @@ def create_app(args):
        if history_messages is None:
            history_messages = []

        # Use OpenAI LLM options if available, otherwise fallback to global temperature
        if args.llm_binding == "azure_openai":
            # Use OpenAI LLM options
            openai_options = OpenAILLMOptions.options_dict(args)
            kwargs["timeout"] = llm_timeout
            kwargs.update(openai_options)
        else:
            kwargs["temperature"] = args.temperature

        return await azure_openai_complete_if_cache(
            args.llm_model,
@@ -329,7 +327,7 @@ def create_app(args):
            history_messages = []

        # Use global temperature for Bedrock
        kwargs["temperature"] = args.temperature
        kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)

        return await bedrock_complete_if_cache(
            args.llm_model,
@@ -451,7 +449,7 @@ def create_app(args):
            llm_model_kwargs=(
                {
                    "host": args.llm_binding_host,
                    "timeout": args.timeout,
                    "timeout": llm_timeout,
                    "options": OllamaLLMOptions.options_dict(args),
                    "api_key": args.llm_binding_api_key,
                }
@@ -481,9 +479,6 @@ def create_app(args):
            llm_model_func=azure_openai_model_complete,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs={
                "timeout": args.timeout,
            },
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            summary_max_tokens=args.max_tokens,

@@ -153,7 +153,7 @@ def main():

    # Timeout configuration prioritizes command line arguments
    gunicorn_config.timeout = (
        global_args.timeout * 2
        global_args.timeout + 30
        if global_args.timeout is not None
        else get_env_value(
            "TIMEOUT", DEFAULT_TIMEOUT + 30, int, special_none=True

@@ -201,6 +201,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.yellow(f"{args.port}")
    ASCIIColors.white(" ├─ Workers: ", end="")
    ASCIIColors.yellow(f"{args.workers}")
    ASCIIColors.white(" ├─ Timeout: ", end="")
    ASCIIColors.yellow(f"{args.timeout}")
    ASCIIColors.white(" ├─ CORS Origins: ", end="")
    ASCIIColors.yellow(f"{args.cors_origins}")
    ASCIIColors.white(" ├─ SSL Enabled: ", end="")
@@ -238,14 +240,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
    ASCIIColors.yellow(f"{args.llm_binding_host}")
    ASCIIColors.white(" ├─ Model: ", end="")
    ASCIIColors.yellow(f"{args.llm_model}")
    ASCIIColors.white(" ├─ Temperature: ", end="")
    ASCIIColors.yellow(f"{args.temperature}")
    ASCIIColors.white(" ├─ Max Async for LLM: ", end="")
    ASCIIColors.yellow(f"{args.max_async}")
    ASCIIColors.white(" ├─ Max Tokens: ", end="")
    ASCIIColors.yellow(f"{args.max_tokens}")
    ASCIIColors.white(" ├─ Timeout: ", end="")
    ASCIIColors.yellow(f"{args.timeout if args.timeout else 'None (infinite)'}")
    ASCIIColors.white(" ├─ LLM Cache Enabled: ", end="")
    ASCIIColors.yellow(f"{args.enable_llm_cache}")
    ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")

@@ -49,7 +49,7 @@ DEFAULT_MAX_PARALLEL_INSERT = 2  # Default maximum parallel insert operations
DEFAULT_EMBEDDING_FUNC_MAX_ASYNC = 8  # Default max async for embedding functions
DEFAULT_EMBEDDING_BATCH_NUM = 10  # Default batch size for embedding computations

# Ollama Server Timeout in seconds
# gunicorn worker timeout (as default LLM request timeout if LLM_TIMEOUT is not set)
DEFAULT_TIMEOUT = 150

# Logging configuration defaults

@@ -283,7 +283,7 @@ class LightRAG:
    """Name of the LLM model used for generating responses."""

    summary_max_tokens: int = field(
        default=int(os.getenv("MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
        default=int(os.getenv("SUMMARY_MAX_TOKENS", DEFAULT_SUMMARY_MAX_TOKENS))
    )
    """Maximum number of tokens allowed per LLM response."""

@@ -36,7 +36,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs
    llm_instance = OpenAI(
        model="gpt-4",
        api_key="your-openai-key",
        temperature=0.7,
    )
    kwargs['llm_instance'] = llm_instance

@@ -91,7 +90,6 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs
        model=f"openai/{settings.LLM_MODEL}",  # Format: "provider/model_name"
        api_base=settings.LITELLM_URL,
        api_key=settings.LITELLM_KEY,
        temperature=0.7,
    )
    kwargs['llm_instance'] = llm_instance

@@ -77,14 +77,23 @@ async def anthropic_complete_if_cache(
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("anthropic").setLevel(logging.INFO)

    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)
    timeout = kwargs.pop("timeout", None)

    anthropic_async_client = (
        AsyncAnthropic(default_headers=default_headers, api_key=api_key)
        AsyncAnthropic(
            default_headers=default_headers, api_key=api_key, timeout=timeout
        )
        if base_url is None
        else AsyncAnthropic(
            base_url=base_url, default_headers=default_headers, api_key=api_key
            base_url=base_url,
            default_headers=default_headers,
            api_key=api_key,
            timeout=timeout,
        )
    )
    kwargs.pop("hashing_kv", None)

    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

@@ -59,13 +59,17 @@ async def azure_openai_complete_if_cache(
        or os.getenv("OPENAI_API_VERSION")
    )

    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)
    timeout = kwargs.pop("timeout", None)

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=base_url,
        azure_deployment=deployment,
        api_key=api_key,
        api_version=api_version,
        timeout=timeout,
    )
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

@@ -99,7 +99,7 @@ class BindingOptions:
        group = parser.add_argument_group(f"{cls._binding_name} binding options")
        for arg_item in cls.args_env_name_type_value():
            # Handle JSON parsing for list types
            if arg_item["type"] == List[str]:
            if arg_item["type"] is List[str]:

                def json_list_parser(value):
                    try:
@@ -126,6 +126,34 @@ class BindingOptions:
                    default=env_value,
                    help=arg_item["help"],
                )
            # Handle JSON parsing for dict types
            elif arg_item["type"] is dict:

                def json_dict_parser(value):
                    try:
                        parsed = json.loads(value)
                        if not isinstance(parsed, dict):
                            raise argparse.ArgumentTypeError(
                                f"Expected JSON object, got {type(parsed).__name__}"
                            )
                        return parsed
                    except json.JSONDecodeError as e:
                        raise argparse.ArgumentTypeError(f"Invalid JSON: {e}")

                # Get environment variable with JSON parsing
                env_value = get_env_value(f"{arg_item['env_name']}", argparse.SUPPRESS)
                if env_value is not argparse.SUPPRESS:
                    try:
                        env_value = json_dict_parser(env_value)
                    except argparse.ArgumentTypeError:
                        env_value = argparse.SUPPRESS

                group.add_argument(
                    f"--{arg_item['argname']}",
                    type=json_dict_parser,
                    default=env_value,
                    help=arg_item["help"],
                )
            else:
                group.add_argument(
                    f"--{arg_item['argname']}",
@@ -234,8 +262,8 @@ class BindingOptions:
            if arg_item["help"]:
                sample_stream.write(f"# {arg_item['help']}\n")

            # Handle JSON formatting for list types
            if arg_item["type"] == List[str]:
            # Handle JSON formatting for list and dict types
            if arg_item["type"] is List[str] or arg_item["type"] is dict:
                default_value = json.dumps(arg_item["default"])
            else:
                default_value = arg_item["default"]
@@ -431,6 +459,8 @@ class OpenAILLMOptions(BindingOptions):
    stop: List[str] = field(default_factory=list)  # Stop sequences
    temperature: float = DEFAULT_TEMPERATURE  # Controls randomness (0.0 to 2.0)
    top_p: float = 1.0  # Nucleus sampling parameter (0.0 to 1.0)
    max_tokens: int = None  # Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)
    extra_body: dict = None  # Extra body parameters for OpenRouter or vLLM

    # Help descriptions
    _help: ClassVar[dict[str, str]] = {
@@ -443,6 +473,8 @@ class OpenAILLMOptions(BindingOptions):
        "stop": 'Stop sequences (JSON array of strings, e.g., \'["</s>", "\\n\\n"]\')',
        "temperature": "Controls randomness (0.0-2.0, higher = more creative)",
        "top_p": "Nucleus sampling parameter (0.0-1.0, lower = more focused)",
        "max_tokens": "Maximum number of tokens to generate (deprecated, use max_completion_tokens instead)",
        "extra_body": 'Extra body parameters for OpenRouter or vLLM (JSON dict, e.g., \'{"reasoning": {"enabled": false}}\')',
    }

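As a usage sketch, the same setting can also be supplied on the command line; the `--openai-llm-extra-body` flag name below is assumed to follow the `--openai-llm-<field>` pattern that `BindingOptions` generates from field names, and it is equivalent to setting `OPENAI_LLM_EXTRA_BODY`:

```bash
# Assumed flag name; disables Qwen3 thinking on a vLLM deployment (value taken from env.example above)
lightrag-server --llm-binding openai \
  --openai-llm-extra-body '{"chat_template_kwargs": {"enable_thinking": false}}'
```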
@@ -493,6 +525,8 @@ if __name__ == "__main__":
            "1000",
            "--openai-llm-stop",
            '["</s>", "\\n\\n"]',
            "--openai-llm-reasoning",
            '{"effort": "high", "max_tokens": 2000, "exclude": false, "enabled": true}',
        ]
    )
    print("Final args for LLM and Embedding:")
@@ -518,5 +552,100 @@ if __name__ == "__main__":
    print("\nOpenAI LLM options instance:")
    print(openai_options.asdict())

    # Test creating OpenAI options instance with reasoning parameter
    openai_options_with_reasoning = OpenAILLMOptions(
        temperature=0.9,
        max_completion_tokens=2000,
        reasoning={
            "effort": "medium",
            "max_tokens": 1500,
            "exclude": True,
            "enabled": True,
        },
    )
    print("\nOpenAI LLM options instance with reasoning:")
    print(openai_options_with_reasoning.asdict())

    # Test dict parsing functionality
    print("\n" + "=" * 50)
    print("TESTING DICT PARSING FUNCTIONALITY")
    print("=" * 50)

    # Test valid JSON dict parsing
    test_parser = ArgumentParser(description="Test dict parsing")
    OpenAILLMOptions.add_args(test_parser)

    try:
        test_args = test_parser.parse_args(
            ["--openai-llm-reasoning", '{"effort": "low", "max_tokens": 1000}']
        )
        print("✓ Valid JSON dict parsing successful:")
        print(
            f"  Parsed reasoning: {OpenAILLMOptions.options_dict(test_args)['reasoning']}"
        )
    except Exception as e:
        print(f"✗ Valid JSON dict parsing failed: {e}")

    # Test invalid JSON dict parsing
    try:
        test_args = test_parser.parse_args(
            [
                "--openai-llm-reasoning",
                '{"effort": "low", "max_tokens": 1000',  # Missing closing brace
            ]
        )
        print("✗ Invalid JSON should have failed but didn't")
    except SystemExit:
        print("✓ Invalid JSON dict parsing correctly rejected")
    except Exception as e:
        print(f"✓ Invalid JSON dict parsing correctly rejected: {e}")

    # Test non-dict JSON parsing
    try:
        test_args = test_parser.parse_args(
            [
                "--openai-llm-reasoning",
                '["not", "a", "dict"]',  # Array instead of dict
            ]
        )
        print("✗ Non-dict JSON should have failed but didn't")
    except SystemExit:
        print("✓ Non-dict JSON parsing correctly rejected")
    except Exception as e:
        print(f"✓ Non-dict JSON parsing correctly rejected: {e}")

    print("\n" + "=" * 50)
    print("TESTING ENVIRONMENT VARIABLE SUPPORT")
    print("=" * 50)

    # Test environment variable support for dict
    import os

    os.environ["OPENAI_LLM_REASONING"] = (
        '{"effort": "high", "max_tokens": 3000, "exclude": false}'
    )

    env_parser = ArgumentParser(description="Test env var dict parsing")
    OpenAILLMOptions.add_args(env_parser)

    try:
        env_args = env_parser.parse_args(
            []
        )  # No command line args, should use env var
        reasoning_from_env = OpenAILLMOptions.options_dict(env_args).get(
            "reasoning"
        )
        if reasoning_from_env:
            print("✓ Environment variable dict parsing successful:")
            print(f"  Parsed reasoning from env: {reasoning_from_env}")
        else:
            print("✗ Environment variable dict parsing failed: No reasoning found")
    except Exception as e:
        print(f"✗ Environment variable dict parsing failed: {e}")
    finally:
        # Clean up environment variable
        if "OPENAI_LLM_REASONING" in os.environ:
            del os.environ["OPENAI_LLM_REASONING"]

    else:
        print(BindingOptions.generate_dot_env_sample())

@@ -59,7 +59,7 @@ async def lollms_model_if_cache(
        "personality": kwargs.get("personality", -1),
        "n_predict": kwargs.get("n_predict", None),
        "stream": stream,
        "temperature": kwargs.get("temperature", 0.8),
        "temperature": kwargs.get("temperature", 1.0),
        "top_k": kwargs.get("top_k", 50),
        "top_p": kwargs.get("top_p", 0.95),
        "repeat_penalty": kwargs.get("repeat_penalty", 0.8),

@@ -51,6 +51,8 @@ async def _ollama_model_if_cache(
    # kwargs.pop("response_format", None)  # allow json
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)
    if timeout == 0:
        timeout = None
    kwargs.pop("hashing_kv", None)
    api_key = kwargs.pop("api_key", None)
    headers = {

@@ -149,18 +149,20 @@ async def openai_complete_if_cache(
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
        api_key=api_key,
        base_url=base_url,
        client_configs=client_configs,
    )

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt: