From 1ac9ca7fac276f1cd299c8f2923f609a82d06b32 Mon Sep 17 00:00:00 2001
From: ZanSara
Date: Wed, 12 Apr 2023 09:38:04 +0200
Subject: [PATCH] merge (#4620)

---
 haystack/nodes/prompt/invocation_layer/chatgpt.py | 3 +--
 haystack/nodes/prompt/prompt_model.py             | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py
index f00b5526d..a88567f7b 100644
--- a/haystack/nodes/prompt/invocation_layer/chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -111,5 +111,4 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):

     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
-        valid_model = any(m for m in ["gpt-3.5-turbo"] if m in model_name_or_path)
-        return valid_model
+        return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]
diff --git a/haystack/nodes/prompt/prompt_model.py b/haystack/nodes/prompt/prompt_model.py
index b83fb2387..cb14b0942 100644
--- a/haystack/nodes/prompt/prompt_model.py
+++ b/haystack/nodes/prompt/prompt_model.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)


 def instruction_following_models() -> List[str]:
-    return ["flan", "mt0", "bloomz", "davinci", "opt-iml"]
+    return ["flan", "mt0", "bloomz", "davinci", "opt-iml", "gpt-3.5-turbo", "gpt-4"]


 class PromptModel(BaseComponent):
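
Note (not part of the patch): besides adding the GPT-4 models, this change also switches ChatGPTInvocationLayer.supports from substring matching to exact membership in an allow-list. The minimal sketch below illustrates that behavioural difference; the standalone function names are hypothetical stand-ins for the real classmethod shown in the diff.

    def supports_old(model_name_or_path: str) -> bool:
        # Pre-patch: substring match against a single model family,
        # so e.g. "gpt-3.5-turbo-0301" also matched.
        return any(m for m in ["gpt-3.5-turbo"] if m in model_name_or_path)

    def supports_new(model_name_or_path: str) -> bool:
        # Post-patch: exact membership in an explicit allow-list,
        # now including the GPT-4 models.
        return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]

    if __name__ == "__main__":
        print(supports_old("gpt-3.5-turbo-0301"))  # True  (substring match)
        print(supports_new("gpt-3.5-turbo-0301"))  # False (exact match only)
        print(supports_new("gpt-4"))               # True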