From 024332f98f97d487d695035003644b55708b0d62 Mon Sep 17 00:00:00 2001
From: ZanSara
Date: Tue, 7 Mar 2023 20:53:48 +0100
Subject: [PATCH] refactor: simplify registration of
 `PromptModelInvocationLayer` (#4339)

* use __init_subclass__ and remove registering functions
---
 haystack/nodes/prompt/prompt_node.py | 58 ++++++++--------------------
 haystack/nodes/prompt/providers.py   | 15 +++++--
 test/nodes/test_prompt_node.py       |  4 +-
 3 files changed, 29 insertions(+), 48 deletions(-)

diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 36cd2047a..bec4552cb 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -1,16 +1,15 @@
 import copy
 import logging
-import pydoc
 import re
 from abc import ABC
 from string import Template
-from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
+from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type
 
 import torch
 
 from haystack import MultiLabel
 from haystack.nodes.base import BaseComponent
-from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer
 from haystack.schema import Document
 from haystack.telemetry_2 import send_event
 
@@ -186,7 +185,7 @@ class PromptModel(BaseComponent):
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
         devices: Optional[List[Union[str, torch.device]]] = None,
-        invocation_layer_class: Optional[str] = None,
+        invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
         model_kwargs: Optional[Dict] = None,
     ):
         """
@@ -198,8 +197,7 @@ class PromptModel(BaseComponent):
         :param use_auth_token: The Hugging Face token to use.
         :param use_gpu: Whether to use GPU or not.
         :param devices: The devices to use where the model is loaded.
-        :param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
-            path from a module’s global scope to the class. If None, known invocation layers are used.
+        :param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
         :param model_kwargs: Additional keyword arguments passed to the underlying model.
 
         Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
@@ -216,33 +214,11 @@ class PromptModel(BaseComponent):
         self.devices = devices
 
         self.model_kwargs = model_kwargs if model_kwargs else {}
+        self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)
 
-        self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
-        if invocation_layer_class:
-            klass: Optional[Type[PromptModelInvocationLayer]] = None
-            if isinstance(invocation_layer_class, str):
-                # try to find the invocation_layer_class provider class
-                search_path: List[str] = [
-                    f"haystack.nodes.prompt.providers.{invocation_layer_class}",
-                    invocation_layer_class,
-                ]
-                klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None)  # type: ignore
-
-            if not klass:
-                raise ValueError(
-                    f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
-                    f"Make sure to pass the full path to the class."
-                )
-
-            if not issubclass(klass, PromptModelInvocationLayer):
-                raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
-
-            logger.info("Registering custom invocation layer class %s", klass)
-            self.register(klass)
-
-        self.model_invocation_layer = self.create_invocation_layer()
-
-    def create_invocation_layer(self) -> PromptModelInvocationLayer:
+    def create_invocation_layer(
+        self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
+    ) -> PromptModelInvocationLayer:
         kwargs = {
             "api_key": self.api_key,
             "use_auth_token": self.use_auth_token,
@@ -251,26 +227,24 @@ class PromptModel(BaseComponent):
         }
         all_kwargs = {**self.model_kwargs, **kwargs}
 
+        if invocation_layer_class:
+            return invocation_layer_class(
+                model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
+            )
         # search all invocation layer classes and find the first one that supports the model,
         # then create an instance of that invocation layer
-        for invocation_layer in self.invocation_layer_classes:
+        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                 return invocation_layer(
                     model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                 )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
-            f" Currently supported invocation layers are: {self.invocation_layer_classes}"
-            f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
+            f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
+            f" You can implement and provide a custom invocation layer for {self.model_name_or_path} by subclassing "
+            "PromptModelInvocationLayer."
         )
 
-    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
-        """
-        Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
-        matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
-        """
-        self.invocation_layer_classes.append(invocation_layer)
-
     def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
         """
         It takes in a prompt, and returns a list of responses using the underlying invocation layer.
diff --git a/haystack/nodes/prompt/providers.py b/haystack/nodes/prompt/providers.py
index c7dc867ee..315eb2680 100644
--- a/haystack/nodes/prompt/providers.py
+++ b/haystack/nodes/prompt/providers.py
@@ -35,6 +35,8 @@ class PromptModelInvocationLayer:
     could be even remote, for example, a call to a remote API endpoint.
     """
 
+    invocation_layer_providers: List[Type["PromptModelInvocationLayer"]] = []
+
     def __init__(self, model_name_or_path: str, **kwargs):
         """
         Creates a new PromptModelInvocationLayer instance.
@@ -47,6 +49,15 @@ class PromptModelInvocationLayer:
 
         self.model_name_or_path = model_name_or_path
 
+    def __init_subclass__(cls, **kwargs):
+        """
+        Used to register user-defined invocation layers.
+
+        Called automatically when a subclass of PromptModelInvocationLayer is defined.
+ """ + super().__init_subclass__(**kwargs) + cls.invocation_layer_providers.append(cls) + @abstractmethod def invoke(self, *args, **kwargs): """ @@ -76,10 +87,6 @@ class PromptModelInvocationLayer: pass -def known_providers() -> List[Type[PromptModelInvocationLayer]]: - return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer] - - class StopWordsCriteria(StoppingCriteria): """ Stops text generation if any one of the stop words is generated. diff --git a/test/nodes/test_prompt_node.py b/test/nodes/test_prompt_node.py index 83dc27634..1368034bb 100644 --- a/test/nodes/test_prompt_node.py +++ b/test/nodes/test_prompt_node.py @@ -88,8 +88,8 @@ def test_prompt_template_repr(): @pytest.mark.unit -def test_prompt_node_with_custom_invocation_layer_from_string(): - model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer") +def test_prompt_node_with_custom_invocation_layer(): + model = PromptModel("fake_model") pn = PromptNode(model_name_or_path=model) output = pn("Some fake invocation")