import json
import logging
import typing

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig

logger = logging.getLogger(__name__)


class OpenAIClient(LLMClient):
    def __init__(self, config: LLMConfig):
        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        self.model = config.model
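
    # Expose the underlying AsyncOpenAI embeddings resource so callers can reuse
    # the already-configured client for embedding requests.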
    def get_embedder(self) -> typing.Any:
        return self.client.embeddings

    async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
        openai_messages: list[ChatCompletionMessageParam] = []
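
        # Map internal Message objects onto the OpenAI chat message format;
        # only 'user' and 'system' roles are forwarded, any other role is dropped.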
        for m in messages:
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
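
        # Ask the model for a JSON object so the reply can be handed straight to
        # json.loads; any failure is logged and re-raised to the caller.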
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=openai_messages,
                temperature=0.1,
                max_tokens=3000,
                response_format={'type': 'json_object'},
            )
            result = response.choices[0].message.content or ''
            return json.loads(result)
        except Exception as e:
            logger.error(f'Error in generating LLM response: {e}')
            raise
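

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library itself. It assumes LLMConfig
    # accepts api_key/model keyword arguments with a usable default base_url, that
    # Message takes role and content fields, that OPENAI_API_KEY is set, and that
    # 'gpt-4o-mini' is a valid model name for your account. Run it as a module
    # (python -m <package>.openai_client) so the relative imports resolve.
    import asyncio
    import os

    async def _example() -> None:
        config = LLMConfig(api_key=os.environ['OPENAI_API_KEY'], model='gpt-4o-mini')
        client = OpenAIClient(config)
        result = await client.generate_response(
            [
                Message(role='system', content='Reply with a single JSON object.'),
                Message(role='user', content='Return {"greeting": "<a short greeting>"}.'),
            ]
        )
        print(result)

    asyncio.run(_example())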