refactor: simplify registration of PromptModelInvocationLayer (#4339)

* use __init_subclass__ and remove registering functions
ZanSara 2023-03-07 20:53:48 +01:00 committed by GitHub
parent 7d5e7c089c
commit 024332f98f
3 changed files with 29 additions and 48 deletions
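
For context, the heart of this refactor is Python's __init_subclass__ hook: merely defining a subclass now registers it, which makes the explicit register() method and the known_providers() helper redundant. A minimal, self-contained sketch of the pattern (illustrative names, not the actual Haystack code):

    from typing import List, Type

    class Base:
        providers: List[Type["Base"]] = []

        def __init_subclass__(cls, **kwargs):
            # Runs once for every subclass, at class-creation time.
            super().__init_subclass__(**kwargs)
            cls.providers.append(cls)

    class ProviderA(Base): ...
    class ProviderB(Base): ...

    print(Base.providers)  # [<class 'ProviderA'>, <class 'ProviderB'>]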


@@ -1,16 +1,15 @@
 import copy
 import logging
-import pydoc
 import re
 from abc import ABC
 from string import Template
-from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
+from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type

 import torch
 from haystack import MultiLabel
 from haystack.nodes.base import BaseComponent
-from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer
 from haystack.schema import Document
 from haystack.telemetry_2 import send_event
@@ -186,7 +185,7 @@ class PromptModel(BaseComponent):
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
         devices: Optional[List[Union[str, torch.device]]] = None,
-        invocation_layer_class: Optional[str] = None,
+        invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
         model_kwargs: Optional[Dict] = None,
     ):
         """
@@ -198,8 +197,7 @@ class PromptModel(BaseComponent):
         :param use_auth_token: The Hugging Face token to use.
         :param use_gpu: Whether to use GPU or not.
         :param devices: The devices to use where the model is loaded.
-        :param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
-            path from a modules global scope to the class. If None, known invocation layers are used.
+        :param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
         :param model_kwargs: Additional keyword arguments passed to the underlying model.

         Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
@@ -216,33 +214,11 @@ class PromptModel(BaseComponent):
         self.devices = devices
         self.model_kwargs = model_kwargs if model_kwargs else {}
+        self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)
-        self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
-        if invocation_layer_class:
-            klass: Optional[Type[PromptModelInvocationLayer]] = None
-            if isinstance(invocation_layer_class, str):
-                # try to find the invocation_layer_class provider class
-                search_path: List[str] = [
-                    f"haystack.nodes.prompt.providers.{invocation_layer_class}",
-                    invocation_layer_class,
-                ]
-                klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None)  # type: ignore
-            if not klass:
-                raise ValueError(
-                    f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
-                    f"Make sure to pass the full path to the class."
-                )
-            if not issubclass(klass, PromptModelInvocationLayer):
-                raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
-            logger.info("Registering custom invocation layer class %s", klass)
-            self.register(klass)
-        self.model_invocation_layer = self.create_invocation_layer()

-    def create_invocation_layer(self) -> PromptModelInvocationLayer:
+    def create_invocation_layer(
+        self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
+    ) -> PromptModelInvocationLayer:
         kwargs = {
             "api_key": self.api_key,
             "use_auth_token": self.use_auth_token,
@@ -251,26 +227,24 @@ class PromptModel(BaseComponent):
         }
         all_kwargs = {**self.model_kwargs, **kwargs}

+        if invocation_layer_class:
+            return invocation_layer_class(
+                model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
+            )
         # search all invocation layer classes and find the first one that supports the model,
         # then create an instance of that invocation layer
-        for invocation_layer in self.invocation_layer_classes:
+        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                 return invocation_layer(
                     model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                 )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
-            f" Currently supported invocation layers are: {self.invocation_layer_classes}"
-            f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
+            f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
+            f" You can implement and provide a custom invocation layer for {self.model_name_or_path} by subclassing "
+            "PromptModelInvocationLayer."
         )

-    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
-        """
-        Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
-        matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
-        """
-        self.invocation_layer_classes.append(invocation_layer)

     def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
         """
         It takes in a prompt, and returns a list of responses using the underlying invocation layer.
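
Worth noting from the hunk above: an explicitly passed invocation_layer_class is now instantiated directly, without consulting its supports() method; auto-detection over PromptModelInvocationLayer.invocation_layer_providers only runs when no class is given. A condensed restatement of that control flow (hypothetical helper, not the real method):

    from typing import Optional, Sequence, Type

    def select_layer(model_name: str, explicit: Optional[Type] = None, providers: Sequence[Type] = ()):
        if explicit is not None:
            return explicit  # explicit class wins; supports() is never checked
        for layer in providers:
            if layer.supports(model_name):  # first match wins, so registration order matters
                return layer
        raise ValueError(f"No invocation layer found for {model_name}")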


@@ -35,6 +35,8 @@ class PromptModelInvocationLayer:
     could be even remote, for example, a call to a remote API endpoint.
     """

+    invocation_layer_providers: List[Type["PromptModelInvocationLayer"]] = []
+
     def __init__(self, model_name_or_path: str, **kwargs):
         """
         Creates a new PromptModelInvocationLayer instance.
@@ -47,6 +49,15 @@ class PromptModelInvocationLayer:
         """
         self.model_name_or_path = model_name_or_path

+    def __init_subclass__(cls, **kwargs):
+        """
+        Used to register user-defined invocation layers.
+        Called when a subclass of PromptModelInvocationLayer is defined (that is, when its module is imported).
+        """
+        super().__init_subclass__(**kwargs)
+        cls.invocation_layer_providers.append(cls)
+
     @abstractmethod
     def invoke(self, *args, **kwargs):
         """
@@ -76,10 +87,6 @@ class PromptModelInvocationLayer:
         pass

-def known_providers() -> List[Type[PromptModelInvocationLayer]]:
-    return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer]

 class StopWordsCriteria(StoppingCriteria):
     """
     Stops text generation if any one of the stop words is generated.
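
With the hook in place, defining (and thus importing) a subclass is all it takes to register it. A sketch of a hypothetical custom layer, modeled loosely on the CustomInvocationLayer the test below refers to (method bodies are illustrative stubs, not the actual test code):

    from haystack.nodes.prompt.providers import PromptModelInvocationLayer

    class CustomInvocationLayer(PromptModelInvocationLayer):
        def invoke(self, *args, **kwargs):
            return ["fake response"]  # hypothetical stub

        @classmethod
        def supports(cls, model_name_or_path: str, **kwargs) -> bool:
            return model_name_or_path == "fake_model"

    # __init_subclass__ has already appended the class to the registry:
    assert CustomInvocationLayer in PromptModelInvocationLayer.invocation_layer_providers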


@@ -88,8 +88,8 @@ def test_prompt_template_repr():

 @pytest.mark.unit
-def test_prompt_node_with_custom_invocation_layer_from_string():
-    model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer")
+def test_prompt_node_with_custom_invocation_layer():
+    model = PromptModel("fake_model")
     pn = PromptNode(model_name_or_path=model)
     output = pn("Some fake invocation")
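
The renamed test reflects the new calling convention: instead of a dotted-path string resolved via pydoc at runtime, callers either rely on auto-registration (as this test now does) or pass the class object itself. A sketch, assuming Haystack's usual top-level export and the CustomInvocationLayer sketched earlier:

    from haystack.nodes import PromptModel

    # Before: invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer"
    # After: pass the class directly, or omit it and let supports() pick a registered layer.
    model = PromptModel("fake_model", invocation_layer_class=CustomInvocationLayer)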