refactor: simplify registration of PromptModelInvocationLayer (#4339)

* use __init_subclass__ and remove registering functions
ZanSara 2023-03-07 20:53:48 +01:00 committed by GitHub
parent 7d5e7c089c
commit 024332f98f
3 changed files with 29 additions and 48 deletions
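
For context: this refactor replaces an explicit register()/known_providers() mechanism with Python's __init_subclass__ hook, which the interpreter calls once for every subclass at class-definition time. A minimal, self-contained sketch of the pattern (the names Registry, LayerA, and LayerB are illustrative, not Haystack's):

    from typing import List, Type

    class Registry:
        # Shared, class-level list of all registered subclasses.
        providers: List[Type["Registry"]] = []

        def __init_subclass__(cls, **kwargs):
            # Runs when a subclass is *defined*, so importing a module
            # that defines a subclass is enough to register it.
            super().__init_subclass__(**kwargs)
            cls.providers.append(cls)

    class LayerA(Registry): ...
    class LayerB(Registry): ...

    assert Registry.providers == [LayerA, LayerB]

Because registration happens at class-definition time, no caller ever has to remember to register a new layer: subclassing is the registration.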

View File

@@ -1,16 +1,15 @@
 import copy
 import logging
-import pydoc
 import re
 from abc import ABC
 from string import Template
-from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
+from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type

 import torch

 from haystack import MultiLabel
 from haystack.nodes.base import BaseComponent
-from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer
 from haystack.schema import Document
 from haystack.telemetry_2 import send_event

@@ -186,7 +185,7 @@ class PromptModel(BaseComponent):
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
         devices: Optional[List[Union[str, torch.device]]] = None,
-        invocation_layer_class: Optional[str] = None,
+        invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
         model_kwargs: Optional[Dict] = None,
     ):
         """
@@ -198,8 +197,7 @@ class PromptModel(BaseComponent):
         :param use_auth_token: The Hugging Face token to use.
         :param use_gpu: Whether to use GPU or not.
         :param devices: The devices to use where the model is loaded.
-        :param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
-            path from a modules global scope to the class. If None, known invocation layers are used.
+        :param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
         :param model_kwargs: Additional keyword arguments passed to the underlying model.
             Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
@@ -216,33 +214,11 @@ class PromptModel(BaseComponent):
         self.devices = devices
         self.model_kwargs = model_kwargs if model_kwargs else {}

-        self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
-        if invocation_layer_class:
-            klass: Optional[Type[PromptModelInvocationLayer]] = None
-            if isinstance(invocation_layer_class, str):
-                # try to find the invocation_layer_class provider class
-                search_path: List[str] = [
-                    f"haystack.nodes.prompt.providers.{invocation_layer_class}",
-                    invocation_layer_class,
-                ]
-                klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None)  # type: ignore
-            if not klass:
-                raise ValueError(
-                    f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
-                    f"Make sure to pass the full path to the class."
-                )
-            if not issubclass(klass, PromptModelInvocationLayer):
-                raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
-            logger.info("Registering custom invocation layer class %s", klass)
-            self.register(klass)
-
-        self.model_invocation_layer = self.create_invocation_layer()
+        self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)

-    def create_invocation_layer(self) -> PromptModelInvocationLayer:
+    def create_invocation_layer(
+        self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
+    ) -> PromptModelInvocationLayer:
         kwargs = {
             "api_key": self.api_key,
             "use_auth_token": self.use_auth_token,
@@ -251,26 +227,24 @@ class PromptModel(BaseComponent):
         }
         all_kwargs = {**self.model_kwargs, **kwargs}

+        if invocation_layer_class:
+            return invocation_layer_class(
+                model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
+            )
         # search all invocation layer classes and find the first one that supports the model,
         # then create an instance of that invocation layer
-        for invocation_layer in self.invocation_layer_classes:
+        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                 return invocation_layer(
                     model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                 )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
-            f" Currently supported invocation layers are: {self.invocation_layer_classes}"
-            f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
+            f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
+            f" You can implement and provide custom invocation layer for {self.model_name_or_path} by subclassing "
+            "PromptModelInvocationLayer."
         )

-    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
-        """
-        Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
-        matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
-        """
-        self.invocation_layer_classes.append(invocation_layer)
-
     def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
         """
         It takes in a prompt, and returns a list of responses using the underlying invocation layer.
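
The net effect on callers: instead of a dotted-path string resolved at runtime with pydoc.locate, you now pass the invocation layer class itself. A hedged before/after sketch (my_package.MyInvocationLayer is a hypothetical user-defined layer, not part of Haystack):

    from haystack.nodes.prompt import PromptModel
    from my_package import MyInvocationLayer  # hypothetical user module

    # Before this commit: dotted-path string, located and validated at runtime
    model = PromptModel("some-model", invocation_layer_class="my_package.MyInvocationLayer")

    # After this commit: pass the class object directly
    model = PromptModel("some-model", invocation_layer_class=MyInvocationLayer)

Passing the class also lets static type checkers catch a wrong argument, whereas the string variant could only fail at runtime.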

View File

@@ -35,6 +35,8 @@ class PromptModelInvocationLayer:
     could be even remote, for example, a call to a remote API endpoint.
     """

+    invocation_layer_providers: List[Type["PromptModelInvocationLayer"]] = []
+
     def __init__(self, model_name_or_path: str, **kwargs):
         """
         Creates a new PromptModelInvocationLayer instance.
@@ -47,6 +49,15 @@ class PromptModelInvocationLayer:
         self.model_name_or_path = model_name_or_path

+    def __init_subclass__(cls, **kwargs):
+        """
+        Used to register user-defined invocation layers.
+        Called when a subclass of PromptModelInvocationLayer is imported.
+        """
+        super().__init_subclass__(**kwargs)
+        cls.invocation_layer_providers.append(cls)
+
     @abstractmethod
     def invoke(self, *args, **kwargs):
         """
@@ -76,10 +87,6 @@ class PromptModelInvocationLayer:
         pass


-def known_providers() -> List[Type[PromptModelInvocationLayer]]:
-    return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer]
-
-
 class StopWordsCriteria(StoppingCriteria):
     """
     Stops text generation if any one of the stop words is generated.
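
With the hook in place, a custom layer registers itself just by being defined. A sketch of a user-defined layer, assuming only what the diff shows: supports is called on the class in PromptModel, so it is written here as a classmethod, and invoke is the abstract method to override. EchoInvocationLayer and "echo-model" are illustrative names:

    from typing import List

    from haystack.nodes.prompt.providers import PromptModelInvocationLayer

    class EchoInvocationLayer(PromptModelInvocationLayer):
        # No explicit registration call needed: defining this class triggers
        # __init_subclass__, which appends it to invocation_layer_providers.

        def invoke(self, *args, **kwargs) -> List[str]:
            # Toy behavior: return the prompt unchanged instead of calling a model.
            return [kwargs.get("prompt", "")]

        @classmethod
        def supports(cls, model_name_or_path: str, **kwargs) -> bool:
            # Claim only our illustrative model name so PromptModel's search
            # loop selects this layer for it and nothing else.
            return model_name_or_path == "echo-model"

    assert EchoInvocationLayer in PromptModelInvocationLayer.invocation_layer_providers

This is exactly why the test below can drop the dotted-path argument: importing the test module defines its CustomInvocationLayer, and the definition alone registers it.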

View File

@@ -88,8 +88,8 @@ def test_prompt_template_repr():

 @pytest.mark.unit
-def test_prompt_node_with_custom_invocation_layer_from_string():
-    model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer")
+def test_prompt_node_with_custom_invocation_layer():
+    model = PromptModel("fake_model")
     pn = PromptNode(model_name_or_path=model)
     output = pn("Some fake invocation")