LightRAG/lightrag/llm/azure_openai.py

189 lines
5.4 KiB
Python
Raw Normal View History

"""
Azure OpenAI LLM Interface Module
==========================
This module provides interfaces for interacting with aure openai's language models,
including text generation and embedding capabilities.
Author: Lightrag team
Created: 2024-01-24
License: MIT License
Copyright (c) 2024 Lightrag
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Version: 1.0.0
Change Log:
- 1.0.0 (2024-01-24): Initial release
* Added async chat completion support
* Added embedding generation
* Added stream response capability
Dependencies:
- openai
- numpy
- pipmaster
- Python >= 3.10
Usage:
from llm_interfaces.azure_openai import azure_openai_model_complete, azure_openai_embed
"""
__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"
import os
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("openai"):
pm.install("openai")
if not pm.is_installed("tenacity"):
pm.install("tenacity")
from openai import (
AsyncAzureOpenAI,
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
locate_json_string_body_from_string,
safe_unicode_decode,
)
import numpy as np
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APIConnectionError)
),
)
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
if "response_format" in kwargs:
response = await openai_async_client.beta.chat.completions.parse(
model=model, messages=messages, **kwargs
)
else:
response = await openai_async_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
if hasattr(response, "__aiter__"):
async def inner():
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
return content
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
keyword_extraction = kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
os.getenv("LLM_MODEL", "gpt-4o-mini"),
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
if keyword_extraction: # TODO: use JSON API
return locate_json_string_body_from_string(result)
return result
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def azure_openai_embed(
texts: list[str],
model: str = "text-embedding-3-small",
base_url: str = None,
api_key: str = None,
api_version: str = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
)
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
return np.array([dp.embedding for dp in response.data])