From ce1c9c9ddb9ec75a9f61ca00b8bd541a5128354e Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 14 Jun 2023 09:53:00 +0200 Subject: [PATCH] fix: Relax ChatGPT model name check to support gpt-3.5-turbo-0613 (#5142) * relax model name checking for chatgpt * add unit tests --- haystack/nodes/prompt/invocation_layer/chatgpt.py | 2 +- test/prompt/invocation_layer/test_chatgpt.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py index cd49cfcf6..b973e505d 100644 --- a/haystack/nodes/prompt/invocation_layer/chatgpt.py +++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py @@ -144,4 +144,4 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer): @classmethod def supports(cls, model_name_or_path: str, **kwargs) -> bool: - return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"] + return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path) diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py index 17799c655..ccbc0b1bc 100644 --- a/test/prompt/invocation_layer/test_chatgpt.py +++ b/test/prompt/invocation_layer/test_chatgpt.py @@ -27,3 +27,15 @@ def test_custom_api_base(mock_request): invocation_layer.invoke(prompt="dummy_prompt") assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions" + + +@pytest.mark.unit +def test_supports_correct_model_names(): + for model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613"]: + assert ChatGPTInvocationLayer.supports(model_name) + + +@pytest.mark.unit +def test_does_not_support_wrong_model_names(): + for model_name in ["got-3.5-turbo", "wrong_model_name"]: + assert not ChatGPTInvocationLayer.supports(model_name)