Promptnode timeout (#6282)

This commit is contained in:
x110 2023-11-19 19:32:09 +04:00 committed by GitHub
parent 9b11462bf8
commit d03bffab8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 4 deletions

View File

@@ -33,6 +33,7 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
model_name_or_path: str = "gpt-3.5-turbo", model_name_or_path: str = "gpt-3.5-turbo",
max_length: Optional[int] = 500, max_length: Optional[int] = 500,
api_base: str = "https://api.openai.com/v1", api_base: str = "https://api.openai.com/v1",
timeout: Optional[float] = None,
**kwargs, **kwargs,
): ):
""" """
@@ -48,7 +49,7 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation) sensitive content using the [OpenAI Moderation API](https://platform.openai.com/docs/guides/moderation)
if set. If the input or answers are flagged, an empty list is returned in place of the answers. if set. If the input or answers are flagged, an empty list is returned in place of the answers.
""" """
super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs) super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, timeout=timeout, **kwargs)
def _extract_token(self, event_data: Dict[str, Any]): def _extract_token(self, event_data: Dict[str, Any]):
delta = event_data["choices"][0]["delta"] delta = event_data["choices"][0]["delta"]
@@ -192,12 +193,17 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
extra_payload = {"messages": messages} extra_payload = {"messages": messages}
payload = {**base_payload, **extra_payload} payload = {**base_payload, **extra_payload}
if not stream: if not stream:
response = openai_request(url=self.url, headers=self.headers, payload=payload) response = openai_request(url=self.url, headers=self.headers, payload=payload, timeout=self.timeout)
_check_openai_finish_reason(result=response, payload=payload) _check_openai_finish_reason(result=response, payload=payload)
assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]] assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
else: else:
response = openai_request( response = openai_request(
url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True url=self.url,
headers=self.headers,
payload=payload,
timeout=self.timeout,
read_response=False,
stream=True,
) )
handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler()) handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
assistant_response = self._process_streaming_response(response=response, stream_handler=handler) assistant_response = self._process_streaming_response(response=response, stream_handler=handler)

View File

@@ -37,6 +37,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
max_length: Optional[int] = 100, max_length: Optional[int] = 100,
api_base: str = "https://api.openai.com/v1", api_base: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None, openai_organization: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs, **kwargs,
): ):
""" """
@@ -66,6 +67,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
self.api_key = api_key self.api_key = api_key
self.api_base = api_base self.api_base = api_base
self.openai_organization = openai_organization self.openai_organization = openai_organization
self.timeout = timeout
# 16 is the default length for answers from OpenAI shown in the docs # 16 is the default length for answers from OpenAI shown in the docs
# here, https://platform.openai.com/docs/api-reference/completions/create. # here, https://platform.openai.com/docs/api-reference/completions/create.

View File

@@ -36,6 +36,7 @@ class PromptModel(BaseComponent):
model_name_or_path: str = "google/flan-t5-base", model_name_or_path: str = "google/flan-t5-base",
max_length: Optional[int] = 100, max_length: Optional[int] = 100,
api_key: Optional[str] = None, api_key: Optional[str] = None,
timeout: Optional[float] = None,
use_auth_token: Optional[Union[str, bool]] = None, use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = None, use_gpu: Optional[bool] = None,
devices: Optional[List[Union[str, "torch.device"]]] = None, devices: Optional[List[Union[str, "torch.device"]]] = None,
@@ -63,6 +64,7 @@ class PromptModel(BaseComponent):
self.model_name_or_path = model_name_or_path self.model_name_or_path = model_name_or_path
self.max_length = max_length self.max_length = max_length
self.api_key = api_key self.api_key = api_key
self.timeout = timeout
self.use_auth_token = use_auth_token self.use_auth_token = use_auth_token
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.devices = devices self.devices = devices
@@ -75,6 +77,7 @@ class PromptModel(BaseComponent):
) -> PromptModelInvocationLayer: ) -> PromptModelInvocationLayer:
kwargs = { kwargs = {
"api_key": self.api_key, "api_key": self.api_key,
"timeout": self.timeout,
"use_auth_token": self.use_auth_token, "use_auth_token": self.use_auth_token,
"use_gpu": self.use_gpu, "use_gpu": self.use_gpu,
"devices": self.devices, "devices": self.devices,

View File

@@ -57,6 +57,7 @@ class PromptNode(BaseComponent):
output_variable: Optional[str] = None, output_variable: Optional[str] = None,
max_length: Optional[int] = 100, max_length: Optional[int] = 100,
api_key: Optional[str] = None, api_key: Optional[str] = None,
timeout: Optional[float] = None,
use_auth_token: Optional[Union[str, bool]] = None, use_auth_token: Optional[Union[str, bool]] = None,
use_gpu: Optional[bool] = None, use_gpu: Optional[bool] = None,
devices: Optional[List[Union[str, "torch.device"]]] = None, devices: Optional[List[Union[str, "torch.device"]]] = None,
@@ -113,6 +114,7 @@ class PromptNode(BaseComponent):
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
max_length=max_length, max_length=max_length,
api_key=api_key, api_key=api_key,
timeout=timeout,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
use_gpu=use_gpu, use_gpu=use_gpu,
devices=devices, devices=devices,

View File

@@ -112,7 +112,7 @@ def openai_request(
url: str, url: str,
headers: Dict, headers: Dict,
payload: Dict, payload: Dict,
timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT, timeout: Optional[Union[float, Tuple[float, float]]] = None,
read_response: Optional[bool] = True, read_response: Optional[bool] = True,
**kwargs, **kwargs,
): ):
@@ -124,6 +124,8 @@ def openai_request(
:param timeout: The timeout length of the request. The default is 30s. :param timeout: The timeout length of the request. The default is 30s.
:param read_response: Whether to read the response as JSON. The default is True. :param read_response: Whether to read the response as JSON. The default is True.
""" """
if timeout is None:
timeout = OPENAI_TIMEOUT
response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs) response = requests.request("POST", url, headers=headers, data=json.dumps(payload), timeout=timeout, **kwargs)
if read_response: if read_response:
json_response = json.loads(response.text) json_response = json.loads(response.text)

View File

@@ -0,0 +1,4 @@
---
enhancements:
- |
Introduces a new `timeout` keyword argument in PromptNode, fixing issue #5380 and giving users finer control over the timeout of individual calls to OpenAI.