"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing

import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient
from .config import LLMConfig
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-4o-mini'


class OpenAIClient(LLMClient):
    """
    OpenAIClient is a client class for interacting with OpenAI's language models.

    This class extends LLMClient and provides methods to initialize the client
    and to generate dictionary-shaped (structured / JSON) responses from the
    language model.

    Attributes:
        client (AsyncOpenAI): The OpenAI client used to interact with the API.
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int): The maximum number of tokens to generate in a response.

    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
            Initializes the OpenAIClient with the provided configuration, cache setting, and client.

        _generate_response(messages: list[Message], response_model: type[BaseModel] | None = None) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.
    """

    def __init__(
        self, config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None
    ):
        """
        Initialize the OpenAIClient with the provided configuration, cache setting, and client.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key,
                model, base URL, temperature, and max tokens. A default LLMConfig() is used
                when None.
            cache (bool): Whether to use caching for responses. Must be False; caching is
                not supported by this client.
            client (Any | None): An optional async client instance to use. If not provided,
                a new AsyncOpenAI client is created from config.api_key and config.base_url.

        Raises:
            NotImplementedError: If cache is True.
        """
        # Caching was removed to keep the `generate_response` override simple.
        if cache:
            raise NotImplementedError('Caching is not implemented for OpenAI')

        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        if client is None:
            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
        else:
            self.client = client

    @staticmethod
    def _to_openai_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
        """Convert internal Message objects to OpenAI chat params ('user'/'system' only)."""
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
            # NOTE(review): any other role (e.g. 'assistant') is silently dropped;
            # confirm callers never rely on forwarding such messages.
        return openai_messages

    async def _generate_structured_response(
        self,
        openai_messages: list[ChatCompletionMessageParam],
        model: str,
        response_model: type[BaseModel],
    ) -> dict[str, typing.Any]:
        """Request a reply parsed into response_model via the SDK's structured-output helper."""
        response = await self.client.beta.chat.completions.parse(
            model=model,
            messages=openai_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            response_format=response_model,  # type: ignore
        )
        response_object = response.choices[0].message

        if response_object.parsed is not None:
            return response_object.parsed.model_dump()
        if response_object.refusal:
            raise RefusalError(response_object.refusal)
        raise Exception('No response from LLM')

    async def _generate_json_response(
        self, openai_messages: list[ChatCompletionMessageParam], model: str
    ) -> dict[str, typing.Any]:
        """Request a free-form JSON object reply (JSON mode) and decode it."""
        # NOTE: OpenAI's JSON mode rejects requests whose messages never mention
        # "JSON"; prompts used without a response_model must ask for JSON output.
        response = await self.client.chat.completions.create(
            model=model,
            messages=openai_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            response_format={'type': 'json_object'},
        )
        choice = response.choices[0]

        # Unlike `parse`, `create` does not raise on truncation; a cut-off reply
        # would otherwise surface as a confusing JSONDecodeError.
        if choice.finish_reason == 'length':
            raise Exception(f'Output length exceeded max tokens {self.max_tokens}')

        content = choice.message.content
        if not content:
            if choice.message.refusal:
                raise RefusalError(choice.message.refusal)
            raise Exception('No response from LLM')
        return json.loads(content)

    async def _generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        """
        Send messages to the OpenAI chat API and return the model's answer as a dict.

        Args:
            messages: Conversation to send. Only messages whose role is 'user' or
                'system' are forwarded; messages with any other role are dropped.
            response_model: Pydantic model describing the expected output. When given,
                the SDK's structured-output helper parses the reply into this model and
                its model_dump() is returned. When None, the API's JSON mode is used and
                the reply text is decoded with json.loads.

        Returns:
            The parsed response as a dictionary.

        Raises:
            RefusalError: If the model refused to answer.
            RateLimitError: If OpenAI reported a rate limit (chained from
                openai.RateLimitError).
            Exception: If the output exceeded max_tokens, the model returned no
                content, or any other error occurred while generating.
        """
        openai_messages = self._to_openai_messages(messages)
        model = self.model or DEFAULT_MODEL

        try:
            # Passing response_format=None to `parse` is rejected by the SDK
            # (it only skips NOT_GIVEN), so the no-model case needs its own path.
            if response_model is None:
                return await self._generate_json_response(openai_messages, model)
            return await self._generate_structured_response(
                openai_messages, model, response_model
            )
        except openai.LengthFinishReasonError as e:
            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except Exception as e:
            logger.error('Error in generating LLM response: %s', e)
            raise

    async def generate_response(
        self, messages: list[Message], response_model: type[BaseModel] | None = None
    ) -> dict[str, typing.Any]:
        """
        Generate a response for messages, optionally shaped by response_model.

        This override calls _generate_response directly; no caching is applied.
        See _generate_response for argument, return, and error details.
        """
        return await self._generate_response(messages, response_model)