Added fix when using Azure OpenAI with gpt-4 (#5105)

erwanlc 2023-06-19 10:17:58 +02:00 committed by GitHub
parent f52477d31b
commit 97f136b901

@@ -1,5 +1,6 @@
-from typing import Dict, List, Optional, Tuple, Union, Any, Type, overload
 import logging
+import re
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload

 from haystack.nodes.base import BaseComponent
 from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
@@ -84,9 +85,16 @@ class PromptModel(BaseComponent):
             return invocation_layer_class(
                 model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
             )
-        # search all invocation layer classes and find the first one that supports the model,
+
+        potential_invocation_layer = PromptModelInvocationLayer.invocation_layer_providers
+        # if azure_base_url exists as an argument, filter the invocation layer classes to keep only the Azure-related ones
+        if "azure_base_url" in self.model_kwargs:
+            potential_invocation_layer = [
+                layer for layer in potential_invocation_layer if re.search(r"azure", layer.__name__, re.IGNORECASE)
+            ]
+        # search all invocation layer class candidates and find the first one that supports the model,
         # then create an instance of that invocation layer
-        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
+        for invocation_layer in potential_invocation_layer:
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                 return invocation_layer(
                     model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
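
For illustration, here is a minimal, self-contained sketch of the filtering pattern this diff introduces. The classes below are placeholder stand-ins, not Haystack's actual invocation layer providers: the point is only that when azure_base_url is present in model_kwargs, providers whose class name does not contain "azure" are dropped before the supports() loop runs, so a non-Azure layer can no longer claim a gpt-4 model ahead of the Azure one.

import re


class OpenAIInvocationLayer:  # placeholder, not the real Haystack class
    pass


class AzureOpenAIInvocationLayer:  # placeholder, not the real Haystack class
    pass


providers = [OpenAIInvocationLayer, AzureOpenAIInvocationLayer]
model_kwargs = {"azure_base_url": "https://example-resource.openai.azure.com"}  # illustrative URL

# Same idea as the added code: keep only Azure-related layers when azure_base_url is passed.
candidates = providers
if "azure_base_url" in model_kwargs:
    candidates = [layer for layer in candidates if re.search(r"azure", layer.__name__, re.IGNORECASE)]

print([layer.__name__ for layer in candidates])  # ['AzureOpenAIInvocationLayer']

Matching on the class name with a case-insensitive regex keeps the filter decoupled from importing any specific Azure class, at the cost of relying on a naming convention.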