Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-03 03:09:28 +00:00)
refactor: simplify registration of PromptModelInvocationLayer (#4339)
* use __init_subclass__ and remove registering functions
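
The pattern this commit adopts is Python's __init_subclass__ hook: the base class keeps a class-level registry, and every subclass appends itself to that registry at class-definition (import) time, so neither an explicit register() call nor a dotted-path lookup is needed. A minimal standalone sketch of the pattern, with illustrative names rather than Haystack's:

from typing import List, Type

class Base:
    # Class-level registry shared by the whole hierarchy.
    registry: List[Type["Base"]] = []

    def __init_subclass__(cls, **kwargs):
        # Runs once per subclass, at the moment the subclass body executes.
        super().__init_subclass__(**kwargs)
        cls.registry.append(cls)

class PluginA(Base):
    pass

class PluginB(Base):
    pass

# Both subclasses were registered as a side effect of being defined.
assert Base.registry == [PluginA, PluginB]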
parent 7d5e7c089c
commit 024332f98f
@@ -1,16 +1,15 @@
 import copy
 import logging
-import pydoc
 import re
 from abc import ABC
 from string import Template
-from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
+from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type

 import torch

 from haystack import MultiLabel
 from haystack.nodes.base import BaseComponent
-from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
+from haystack.nodes.prompt.providers import PromptModelInvocationLayer
 from haystack.schema import Document
 from haystack.telemetry_2 import send_event

@@ -186,7 +185,7 @@ class PromptModel(BaseComponent):
         use_auth_token: Optional[Union[str, bool]] = None,
         use_gpu: Optional[bool] = None,
         devices: Optional[List[Union[str, torch.device]]] = None,
-        invocation_layer_class: Optional[str] = None,
+        invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
         model_kwargs: Optional[Dict] = None,
     ):
         """
@@ -198,8 +197,7 @@ class PromptModel(BaseComponent):
         :param use_auth_token: The Hugging Face token to use.
         :param use_gpu: Whether to use GPU or not.
         :param devices: The devices to use where the model is loaded.
-        :param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
-            path from a module’s global scope to the class. If None, known invocation layers are used.
+        :param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
         :param model_kwargs: Additional keyword arguments passed to the underlying model.

         Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
@@ -216,33 +214,11 @@ class PromptModel(BaseComponent):
         self.devices = devices

         self.model_kwargs = model_kwargs if model_kwargs else {}
+        self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)

-        self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
-        if invocation_layer_class:
-            klass: Optional[Type[PromptModelInvocationLayer]] = None
-            if isinstance(invocation_layer_class, str):
-                # try to find the invocation_layer_class provider class
-                search_path: List[str] = [
-                    f"haystack.nodes.prompt.providers.{invocation_layer_class}",
-                    invocation_layer_class,
-                ]
-                klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None)  # type: ignore
-
-            if not klass:
-                raise ValueError(
-                    f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
-                    f"Make sure to pass the full path to the class."
-                )
-
-            if not issubclass(klass, PromptModelInvocationLayer):
-                raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
-
-            logger.info("Registering custom invocation layer class %s", klass)
-            self.register(klass)
-
-        self.model_invocation_layer = self.create_invocation_layer()
-
-    def create_invocation_layer(self) -> PromptModelInvocationLayer:
+    def create_invocation_layer(
+        self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
+    ) -> PromptModelInvocationLayer:
         kwargs = {
             "api_key": self.api_key,
             "use_auth_token": self.use_auth_token,
@@ -251,26 +227,24 @@ class PromptModel(BaseComponent):
         }
         all_kwargs = {**self.model_kwargs, **kwargs}

+        if invocation_layer_class:
+            return invocation_layer_class(
+                model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
+            )
         # search all invocation layer classes and find the first one that supports the model,
         # then create an instance of that invocation layer
-        for invocation_layer in self.invocation_layer_classes:
+        for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
             if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
                 return invocation_layer(
                     model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
                 )
         raise ValueError(
             f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
-            f" Currently supported invocation layers are: {self.invocation_layer_classes}"
-            f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
+            f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
+            f" You can implement and provide custom invocation layer for {self.model_name_or_path} by subclassing "
+            "PromptModelInvocationLayer."
         )

-    def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
-        """
-        Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
-        matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
-        """
-        self.invocation_layer_classes.append(invocation_layer)
-
     def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
         """
         It takes in a prompt, and returns a list of responses using the underlying invocation layer.
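With the new signature, callers hand PromptModel the invocation layer class object itself instead of a dotted import path, and create_invocation_layer instantiates it directly rather than locating it with pydoc. A hedged usage sketch; the haystack.nodes import location and the flan-t5 model name are assumptions based on Haystack v1 conventions, not part of this diff:

from haystack.nodes import PromptModel
from haystack.nodes.prompt.providers import HFLocalInvocationLayer

# Explicit class: instantiated directly, no registry scan and no pydoc lookup.
model = PromptModel("google/flan-t5-base", invocation_layer_class=HFLocalInvocationLayer)

# No class given: PromptModelInvocationLayer.invocation_layer_providers is scanned
# and the first class whose supports() accepts the model name wins.
model = PromptModel("google/flan-t5-base")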
@@ -35,6 +35,8 @@ class PromptModelInvocationLayer:
     could be even remote, for example, a call to a remote API endpoint.
     """

+    invocation_layer_providers: List[Type["PromptModelInvocationLayer"]] = []
+
     def __init__(self, model_name_or_path: str, **kwargs):
         """
         Creates a new PromptModelInvocationLayer instance.
@@ -47,6 +49,15 @@ class PromptModelInvocationLayer:

         self.model_name_or_path = model_name_or_path

+    def __init_subclass__(cls, **kwargs):
+        """
+        Used to register user-defined invocation layers.
+
+        Called when a subclass of PromptModelInvocationLayer is imported.
+        """
+        super().__init_subclass__(**kwargs)
+        cls.invocation_layer_providers.append(cls)
+
     @abstractmethod
     def invoke(self, *args, **kwargs):
         """
@@ -76,10 +87,6 @@ class PromptModelInvocationLayer:
         pass


-def known_providers() -> List[Type[PromptModelInvocationLayer]]:
-    return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer]
-
-
 class StopWordsCriteria(StoppingCriteria):
     """
     Stops text generation if any one of the stop words is generated.
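Because __init_subclass__ fires for every subclass of PromptModelInvocationLayer, a user-defined layer registers itself the moment its module is imported; there is nothing left to call. A minimal sketch; EchoInvocationLayer, its classmethod supports() signature, and the prompt keyword argument are assumptions inferred from how create_invocation_layer calls supports() in the hunk above:

from typing import List

from haystack.nodes.prompt.providers import PromptModelInvocationLayer

class EchoInvocationLayer(PromptModelInvocationLayer):
    """Illustrative layer that echoes the prompt back."""

    def invoke(self, *args, **kwargs) -> List[str]:
        # Assumes the prompt arrives as a keyword argument.
        return [kwargs.get("prompt", "")]

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path == "echo-model"

# Defining the class was enough: __init_subclass__ has already registered it.
assert EchoInvocationLayer in PromptModelInvocationLayer.invocation_layer_providers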
@@ -88,8 +88,8 @@ def test_prompt_template_repr():


 @pytest.mark.unit
-def test_prompt_node_with_custom_invocation_layer_from_string():
-    model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer")
+def test_prompt_node_with_custom_invocation_layer():
+    model = PromptModel("fake_model")
     pn = PromptNode(model_name_or_path=model)
     output = pn("Some fake invocation")

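The test drops the dotted-path indirection because importing the test module defines CustomInvocationLayer, which registers itself, so PromptModel("fake_model") finds it during the registry scan. The helper class itself is not shown in this diff; a plausible shape, offered as an assumption rather than the actual test code:

from haystack.nodes.prompt.providers import PromptModelInvocationLayer

class CustomInvocationLayer(PromptModelInvocationLayer):
    # Hypothetical reconstruction of the helper defined in test/nodes/test_prompt_node.py.
    def invoke(self, *args, **kwargs):
        return ["fake response"]

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        return model_name_or_path == "fake_model"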