feat: Add support for OpenAI's gpt-3.5-turbo-instruct model (#5837)

* support gpt-3.5-turbo-instruct

* add release note
This commit is contained in:
Malte Pietsch 2023-09-19 16:06:43 +02:00 committed by GitHub
parent 41126397d6
commit aa3cc3d5ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 11 additions and 3 deletions

View File

@ -136,5 +136,8 @@ class ChatGPTInvocationLayer(OpenAIInvocationLayer):
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
    """
    Check whether this invocation layer handles the given model.

    :param model_name_or_path: Name of the model to check.
    :param kwargs: Additional keyword arguments; they are forwarded unchanged to
        `has_azure_parameters`.
    :return: True if the name contains "gpt-3.5-turbo" or "gpt-4", does not contain
        "gpt-3.5-turbo-instruct", and `has_azure_parameters(**kwargs)` is falsy.
    """
    # "gpt-3.5-turbo-instruct" contains "gpt-3.5-turbo" as a substring, so the plain
    # substring match below would accept it; it has to be ruled out explicitly.
    is_chat_model = any(m in model_name_or_path for m in ["gpt-3.5-turbo", "gpt-4"])
    is_instruct_model = "gpt-3.5-turbo-instruct" in model_name_or_path
    valid_model = is_chat_model and not is_instruct_model
    return valid_model and not has_azure_parameters(**kwargs)

View File

@ -231,7 +231,7 @@ class OpenAIInvocationLayer(PromptModelInvocationLayer):
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
    """
    Check whether this invocation layer handles the given model.

    :param model_name_or_path: Name of the model to check.
    :param kwargs: Additional keyword arguments; they are forwarded unchanged to
        `has_azure_parameters`.
    :return: True if the name is one of the exact names listed below or contains one
        of the listed infixes, and `has_azure_parameters(**kwargs)` is falsy.
    """
    # Exact matches only: bare names are not matched as substrings, so an unrelated
    # model whose name merely contains e.g. "ada" is not accepted.
    # NOTE(review): suffixed variants of "gpt-3.5-turbo-instruct" are not matched by
    # this exact-name list - confirm whether such names need to be supported.
    exact_names = ["ada", "babbage", "davinci", "curie", "gpt-3.5-turbo-instruct"]
    # Infixes cover names such as "text-ada-001" or "text-davinci-002".
    infixes = ["-ada-", "-babbage-", "-davinci-", "-curie-"]
    valid_model = model_name_or_path in exact_names or any(m in model_name_or_path for m in infixes)
    return valid_model and not has_azure_parameters(**kwargs)

View File

@ -0,0 +1,4 @@
---
enhancements:
- |
Support OpenAI's new `gpt-3.5-turbo-instruct` model

View File

@ -38,7 +38,7 @@ def test_supports_correct_model_names():
@pytest.mark.unit
def test_does_not_support_wrong_model_names():
    """ChatGPTInvocationLayer.supports must return a falsy value for each name below."""
    # "got-3.5-turbo" is an intentionally misspelled name; "gpt-3.5-turbo-instruct"
    # contains "gpt-3.5-turbo" as a substring but must still be rejected.
    unsupported_names = ["got-3.5-turbo", "wrong_model_name", "gpt-3.5-turbo-instruct"]
    for model_name in unsupported_names:
        assert not ChatGPTInvocationLayer.supports(model_name)

View File

@ -132,6 +132,7 @@ def test_supports(load_openai_tokenizer):
assert layer.supports("davinci")
assert layer.supports("text-ada-001")
assert layer.supports("text-davinci-002")
assert layer.supports("gpt-3.5-turbo-instruct")
# the following model contains "ada" in the name, but it's not from OpenAI
assert not layer.supports("ybelkada/mpt-7b-bf16-sharded")