mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 08:33:57 +00:00
refactor: simplify registration of PromptModelInvocationLayer (#4339)
* use __init_subclass__ and remove registering functions
This commit is contained in:
parent
7d5e7c089c
commit
024332f98f
@ -1,16 +1,15 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import pydoc
|
|
||||||
import re
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from string import Template
|
from string import Template
|
||||||
from typing import Dict, List, Optional, Tuple, Union, Any, Type, Iterator
|
from typing import Dict, List, Optional, Tuple, Union, Any, Iterator, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from haystack import MultiLabel
|
from haystack import MultiLabel
|
||||||
from haystack.nodes.base import BaseComponent
|
from haystack.nodes.base import BaseComponent
|
||||||
from haystack.nodes.prompt.providers import PromptModelInvocationLayer, known_providers
|
from haystack.nodes.prompt.providers import PromptModelInvocationLayer
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
from haystack.telemetry_2 import send_event
|
from haystack.telemetry_2 import send_event
|
||||||
|
|
||||||
@ -186,7 +185,7 @@ class PromptModel(BaseComponent):
|
|||||||
use_auth_token: Optional[Union[str, bool]] = None,
|
use_auth_token: Optional[Union[str, bool]] = None,
|
||||||
use_gpu: Optional[bool] = None,
|
use_gpu: Optional[bool] = None,
|
||||||
devices: Optional[List[Union[str, torch.device]]] = None,
|
devices: Optional[List[Union[str, torch.device]]] = None,
|
||||||
invocation_layer_class: Optional[str] = None,
|
invocation_layer_class: Optional[Type[PromptModelInvocationLayer]] = None,
|
||||||
model_kwargs: Optional[Dict] = None,
|
model_kwargs: Optional[Dict] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -198,8 +197,7 @@ class PromptModel(BaseComponent):
|
|||||||
:param use_auth_token: The Hugging Face token to use.
|
:param use_auth_token: The Hugging Face token to use.
|
||||||
:param use_gpu: Whether to use GPU or not.
|
:param use_gpu: Whether to use GPU or not.
|
||||||
:param devices: The devices to use where the model is loaded.
|
:param devices: The devices to use where the model is loaded.
|
||||||
:param invocation_layer_class: The custom invocation layer class to use. Use a dotted notation indicating the
|
:param invocation_layer_class: The custom invocation layer class to use. If None, known invocation layers are used.
|
||||||
path from a module’s global scope to the class. If None, known invocation layers are used.
|
|
||||||
:param model_kwargs: Additional keyword arguments passed to the underlying model.
|
:param model_kwargs: Additional keyword arguments passed to the underlying model.
|
||||||
|
|
||||||
Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
|
Note that Azure OpenAI InstructGPT models require two additional parameters: azure_base_url (The URL for the
|
||||||
@ -216,33 +214,11 @@ class PromptModel(BaseComponent):
|
|||||||
self.devices = devices
|
self.devices = devices
|
||||||
|
|
||||||
self.model_kwargs = model_kwargs if model_kwargs else {}
|
self.model_kwargs = model_kwargs if model_kwargs else {}
|
||||||
|
self.model_invocation_layer = self.create_invocation_layer(invocation_layer_class=invocation_layer_class)
|
||||||
|
|
||||||
self.invocation_layer_classes: List[Type[PromptModelInvocationLayer]] = known_providers()
|
def create_invocation_layer(
|
||||||
if invocation_layer_class:
|
self, invocation_layer_class: Optional[Type[PromptModelInvocationLayer]]
|
||||||
klass: Optional[Type[PromptModelInvocationLayer]] = None
|
) -> PromptModelInvocationLayer:
|
||||||
if isinstance(invocation_layer_class, str):
|
|
||||||
# try to find the invocation_layer_class provider class
|
|
||||||
search_path: List[str] = [
|
|
||||||
f"haystack.nodes.prompt.providers.{invocation_layer_class}",
|
|
||||||
invocation_layer_class,
|
|
||||||
]
|
|
||||||
klass = next((pydoc.locate(path) for path in search_path if pydoc.locate(path)), None) # type: ignore
|
|
||||||
|
|
||||||
if not klass:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not locate PromptModelInvocationLayer class with name {invocation_layer_class}. "
|
|
||||||
f"Make sure to pass the full path to the class."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not issubclass(klass, PromptModelInvocationLayer):
|
|
||||||
raise ValueError(f"Class {invocation_layer_class} is not a subclass of PromptModelInvocationLayer.")
|
|
||||||
|
|
||||||
logger.info("Registering custom invocation layer class %s", klass)
|
|
||||||
self.register(klass)
|
|
||||||
|
|
||||||
self.model_invocation_layer = self.create_invocation_layer()
|
|
||||||
|
|
||||||
def create_invocation_layer(self) -> PromptModelInvocationLayer:
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
"use_auth_token": self.use_auth_token,
|
"use_auth_token": self.use_auth_token,
|
||||||
@ -251,26 +227,24 @@ class PromptModel(BaseComponent):
|
|||||||
}
|
}
|
||||||
all_kwargs = {**self.model_kwargs, **kwargs}
|
all_kwargs = {**self.model_kwargs, **kwargs}
|
||||||
|
|
||||||
|
if invocation_layer_class:
|
||||||
|
return invocation_layer_class(
|
||||||
|
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
|
||||||
|
)
|
||||||
# search all invocation layer classes and find the first one that supports the model,
|
# search all invocation layer classes and find the first one that supports the model,
|
||||||
# then create an instance of that invocation layer
|
# then create an instance of that invocation layer
|
||||||
for invocation_layer in self.invocation_layer_classes:
|
for invocation_layer in PromptModelInvocationLayer.invocation_layer_providers:
|
||||||
if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
|
if invocation_layer.supports(self.model_name_or_path, **all_kwargs):
|
||||||
return invocation_layer(
|
return invocation_layer(
|
||||||
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
|
model_name_or_path=self.model_name_or_path, max_length=self.max_length, **all_kwargs
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
|
f"Model {self.model_name_or_path} is not supported - no matching invocation layer found."
|
||||||
f" Currently supported invocation layers are: {self.invocation_layer_classes}"
|
f" Currently supported invocation layers are: {PromptModelInvocationLayer.invocation_layer_providers}"
|
||||||
f" You can implement and provide custom invocation layer for {self.model_name_or_path} via PromptModel init."
|
f" You can implement and provide custom invocation layer for {self.model_name_or_path} by subclassing "
|
||||||
|
"PromptModelInvocationLayer."
|
||||||
)
|
)
|
||||||
|
|
||||||
def register(self, invocation_layer: Type[PromptModelInvocationLayer]):
|
|
||||||
"""
|
|
||||||
Registers additional prompt model invocation layer. It takes a function that returns a boolean as a
|
|
||||||
matching condition on `model_name_or_path` and a class that implements `PromptModelInvocationLayer` interface.
|
|
||||||
"""
|
|
||||||
self.invocation_layer_classes.append(invocation_layer)
|
|
||||||
|
|
||||||
def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
|
def invoke(self, prompt: Union[str, List[str]], **kwargs) -> List[str]:
|
||||||
"""
|
"""
|
||||||
It takes in a prompt, and returns a list of responses using the underlying invocation layer.
|
It takes in a prompt, and returns a list of responses using the underlying invocation layer.
|
||||||
|
|||||||
@ -35,6 +35,8 @@ class PromptModelInvocationLayer:
|
|||||||
could be even remote, for example, a call to a remote API endpoint.
|
could be even remote, for example, a call to a remote API endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
invocation_layer_providers: List[Type["PromptModelInvocationLayer"]] = []
|
||||||
|
|
||||||
def __init__(self, model_name_or_path: str, **kwargs):
|
def __init__(self, model_name_or_path: str, **kwargs):
|
||||||
"""
|
"""
|
||||||
Creates a new PromptModelInvocationLayer instance.
|
Creates a new PromptModelInvocationLayer instance.
|
||||||
@ -47,6 +49,15 @@ class PromptModelInvocationLayer:
|
|||||||
|
|
||||||
self.model_name_or_path = model_name_or_path
|
self.model_name_or_path = model_name_or_path
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
"""
|
||||||
|
Used to register user-defined invocation layers.
|
||||||
|
|
||||||
|
Called when a subclass of PromptModelInvocationLayer is imported.
|
||||||
|
"""
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls.invocation_layer_providers.append(cls)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, *args, **kwargs):
|
def invoke(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -76,10 +87,6 @@ class PromptModelInvocationLayer:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def known_providers() -> List[Type[PromptModelInvocationLayer]]:
|
|
||||||
return [HFLocalInvocationLayer, OpenAIInvocationLayer, AzureOpenAIInvocationLayer]
|
|
||||||
|
|
||||||
|
|
||||||
class StopWordsCriteria(StoppingCriteria):
|
class StopWordsCriteria(StoppingCriteria):
|
||||||
"""
|
"""
|
||||||
Stops text generation if any one of the stop words is generated.
|
Stops text generation if any one of the stop words is generated.
|
||||||
|
|||||||
@ -88,8 +88,8 @@ def test_prompt_template_repr():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_prompt_node_with_custom_invocation_layer_from_string():
|
def test_prompt_node_with_custom_invocation_layer():
|
||||||
model = PromptModel("fake_model", invocation_layer_class="test.nodes.test_prompt_node.CustomInvocationLayer")
|
model = PromptModel("fake_model")
|
||||||
pn = PromptNode(model_name_or_path=model)
|
pn = PromptNode(model_name_or_path=model)
|
||||||
output = pn("Some fake invocation")
|
output = pn("Some fake invocation")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user