mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00
fix: Relax ChatGPT model name check to support gpt-3.5-turbo-0613 (#5142)
* relax model name checking for chatgpt * add unit tests
This commit is contained in:
parent
4c8e0b9d4a
commit
ce1c9c9ddb
@ -144,4 +144,4 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
|
||||
|
||||
@classmethod
|
||||
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
|
||||
return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]
|
||||
return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)
|
||||
|
@ -27,3 +27,15 @@ def test_custom_api_base(mock_request):
|
||||
|
||||
invocation_layer.invoke(prompt="dummy_prompt")
|
||||
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports_correct_model_names():
|
||||
for model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613"]:
|
||||
assert ChatGPTInvocationLayer.supports(model_name)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_does_not_support_wrong_model_names():
|
||||
for model_name in ["got-3.5-turbo", "wrong_model_name"]:
|
||||
assert not ChatGPTInvocationLayer.supports(model_name)
|
||||
|
Loading…
x
Reference in New Issue
Block a user