diff --git a/haystack/nodes/prompt/invocation_layer/chatgpt.py b/haystack/nodes/prompt/invocation_layer/chatgpt.py
index cd49cfcf6..b973e505d 100644
--- a/haystack/nodes/prompt/invocation_layer/chatgpt.py
+++ b/haystack/nodes/prompt/invocation_layer/chatgpt.py
@@ -144,4 +144,4 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
 
     @classmethod
     def supports(cls, model_name_or_path: str, **kwargs) -> bool:
-        return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]
+        return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
diff --git a/test/prompt/invocation_layer/test_chatgpt.py b/test/prompt/invocation_layer/test_chatgpt.py
index 17799c655..ccbc0b1bc 100644
--- a/test/prompt/invocation_layer/test_chatgpt.py
+++ b/test/prompt/invocation_layer/test_chatgpt.py
@@ -27,3 +27,15 @@ def test_custom_api_base(mock_request):
     invocation_layer.invoke(prompt="dummy_prompt")
     assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions"
+
+
+@pytest.mark.unit
+def test_supports_correct_model_names():
+    for model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613"]:
+        assert ChatGPTInvocationLayer.supports(model_name)
+
+
+@pytest.mark.unit
+def test_does_not_support_wrong_model_names():
+    for model_name in ["got-3.5-turbo", "wrong_model_name"]:
+        assert not ChatGPTInvocationLayer.supports(model_name)