fix: Relax ChatGPT model name check to support gpt-3.5-turbo-0613 (#5142)

* relax model name checking for chatgpt

* add unit tests
This commit is contained in:
Julian Risch 2023-06-14 09:53:00 +02:00 committed by GitHub
parent 4c8e0b9d4a
commit ce1c9c9ddb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View File

@ -144,4 +144,4 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
return model_name_or_path in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]
return any(m for m in ["gpt-3.5-turbo", "gpt-4"] if m in model_name_or_path)

View File

@ -27,3 +27,15 @@ def test_custom_api_base(mock_request):
invocation_layer.invoke(prompt="dummy_prompt")
assert mock_request.call_args.kwargs["url"] == "https://fake_api_base.com/chat/completions"
@pytest.mark.unit
def test_supports_correct_model_names():
for model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613"]:
assert ChatGPTInvocationLayer.supports(model_name)
@pytest.mark.unit
def test_does_not_support_wrong_model_names():
for model_name in ["got-3.5-turbo", "wrong_model_name"]:
assert not ChatGPTInvocationLayer.supports(model_name)