LightRAG/lightrag/llm/openai.py
2025-02-06 23:12:35 +08:00

313 lines
9.0 KiB
Python

"""
OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with openai's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.openai import openai_model_complete, openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import sys
import os
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
from openai import (
AsyncOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__
import numpy as np
from typing import Union
class InvalidResponseError(Exception):
"""Custom exception class for triggering retry mechanism"""
pass
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError)
),
)
async def openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=None,
base_url=None,
api_key=None,
**kwargs,
) -> str:
if history_messages is None:
history_messages = []
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
default_headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
AsyncOpenAI(default_headers=default_headers)
if base_url is None
else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
)
kwargs.pop("hashing_kv", None)
kwargs.pop("keyword_extraction", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 添加日志输出
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Model: {model} Base URL: {base_url}")
logger.debug(f"Additional kwargs: {kwargs}")
logger.debug(f"Query: {prompt}")
logger.debug(f"System prompt: {system_prompt}")
# logger.debug(f"Messages: {messages}")
try:
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
except APIConnectionError as e:
logger.error(f"OpenAI API Connection Error: {str(e)}")
raise
except RateLimitError as e:
logger.error(f"OpenAI API Rate Limit Error: {str(e)}")
raise
except APITimeoutError as e:
logger.error(f"OpenAI API Timeout Error: {str(e)}")
raise
except Exception as e:
logger.error(f"OpenAI API Call Failed: {str(e)}")
logger.error(f"Model: {model}")
logger.error(f"Request parameters: {kwargs}")
raise
if hasattr(response, "__aiter__"):
async def inner():
try:
async for chunk in response:
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
raise
return inner()
else:
if (
not response
or not response.choices
or not hasattr(response.choices[0], "message")
or not hasattr(response.choices[0].message, "content")
):
logger.error("Invalid response from OpenAI API")
raise InvalidResponseError("Invalid response from OpenAI API")
content = response.choices[0].message.content
if not content or content.strip() == "":
logger.error("Received empty content from OpenAI API")
raise InvalidResponseError("Received empty content from OpenAI API")
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = "json"
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
return await openai_complete_if_cache(
model_name,
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def gpt_4o_mini_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
kwargs["response_format"] = GPTKeywordExtractionFormat
return await openai_complete_if_cache(
"gpt-4o-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
async def nvidia_openai_complete(
prompt,
system_prompt=None,
history_messages=None,
keyword_extraction=False,
**kwargs,
) -> str:
if history_messages is None:
history_messages = []
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await openai_complete_if_cache(
"nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
base_url="https://integrate.api.nvidia.com/v1",
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
) -> np.ndarray:
if api_key:
os.environ["OPENAI_API_KEY"] = api_key
default_headers = {
"User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
"Content-Type": "application/json",
}
openai_async_client = (
AsyncOpenAI(default_headers=default_headers)
if base_url is None
else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])