fix: Deprecate Seq2SeqGenerator and RAGenerator (#4745)

* Deprecate Seq2SeqGenerator

* changed the warning to include suggestion

* Added example and msg to API reference docs

* Added RAG deprecation

* renamed name to adapt to naming conven

* update docstrings

---------

Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2023-04-26 13:59:35 +02:00 committed by GitHub
parent 9cbe9e0949
commit 3fefc475b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 0 deletions

View File

@ -0,0 +1,28 @@
from haystack import Document
from haystack.nodes import PromptNode, PromptTemplate
p = PromptNode("vblagoje/bart_lfqa")
# Start by defining a question/query
query = "Why does water heated to room temperature feel colder than the air around it?"
# Given the question above, suppose the documents below were found in some document store
documents = [
"when the skin is completely wet. The body continuously loses water by...",
"at greater pressures. There is an ambiguity, however, as to the meaning of the terms 'heating' and 'cooling'...",
"are not in a relation of thermal equilibrium, heat will flow from the hotter to the colder, by whatever pathway...",
"air condition and moving along a line of constant enthalpy toward a state of higher humidity. A simple example ...",
"Thermal contact conductance. In physics, thermal contact conductance is the study of heat conduction between solid ...",
]
# Manually concatenate the question and support documents into BART input
# conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
# query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
# Or use the PromptTemplate as shown here
pt = PromptTemplate("lfqa", "question: {query} context: {join(documents, delimiter='<P>')}")
res = p.prompt(prompt_template=pt, query=query, documents=[Document(d) for d in documents])
print(res)

View File

@ -1,3 +1,4 @@
import warnings
from typing import Dict, List, Optional, Union
import logging
@ -82,6 +83,8 @@ class RAGenerator(BaseGenerator):
devices: Optional[List[Union[str, torch.device]]] = None,
):
"""
This component is now deprecated and will be removed in future versions. Use `PromptNode` instead of `RAGenerator`.
Load a RAG model from Transformers along with passage_embedding_model.
See https://huggingface.co/transformers/model_doc/rag.html for more details
@ -110,6 +113,12 @@ class RAGenerator(BaseGenerator):
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
"""
warnings.warn(
"`RAGenerator` component is deprecated and will be removed in future versions. Use `PromptNode` "
"instead of `RAGenerator`.",
category=DeprecationWarning,
)
super().__init__(progress_bar=progress_bar)
self.model_name_or_path = model_name_or_path
@ -345,6 +354,8 @@ class Seq2SeqGenerator(BaseGenerator):
devices: Optional[List[Union[str, torch.device]]] = None,
):
"""
This component is now deprecated and will be removed in future versions. Use `PromptNode` instead of `Seq2SeqGenerator`.
:param model_name_or_path: A Hugging Face model name for auto-regressive language model like GPT2, XLNet, XLM, Bart, T5, and so on.
:param input_converter: An optional callable to prepare model input for the underlying language model
specified in the `model_name_or_path` parameter. The required `__call__` method signature for
@ -367,6 +378,11 @@ class Seq2SeqGenerator(BaseGenerator):
[torch.device('cuda:0'), "mps", "cuda:1"]). When specifying `use_gpu=False` the devices
parameter is not used and a single cpu device is used for inference.
"""
warnings.warn(
"`Seq2SeqGenerator` component is deprecated and will be removed in future versions. Use `PromptNode` "
"instead of `Seq2SeqGenerator`.",
category=DeprecationWarning,
)
super().__init__(progress_bar=progress_bar)
self.model_name_or_path = model_name_or_path
self.max_length = max_length