diff --git a/docs/_src/api/api/extractor.md b/docs/_src/api/api/extractor.md
index 8eb8101de..09160782e 100644
--- a/docs/_src/api/api/extractor.md
+++ b/docs/_src/api/api/extractor.md
@@ -21,6 +21,7 @@ The entities extracted by this Node will populate Document.entities
 **Arguments**:
 
 - `model_name_or_path`: The name of the model to use for entity extraction.
+- `model_version`: The version of the model to use for entity extraction.
 - `use_gpu`: Whether to use the GPU or not.
 - `batch_size`: The batch size to use for entity extraction.
 - `progress_bar`: Whether to show a progress bar or not.
diff --git a/docs/_src/api/api/question_generator.md b/docs/_src/api/api/question_generator.md
index becdfa740..faaa14fe0 100644
--- a/docs/_src/api/api/question_generator.md
+++ b/docs/_src/api/api/question_generator.md
@@ -23,18 +23,18 @@ come from earlier in the document.
 #### QuestionGenerator.\_\_init\_\_
 
 ```python
-def __init__(model_name_or_path="valhalla/t5-base-e2e-qg",
-             model_version=None,
-             num_beams=4,
-             max_length=256,
-             no_repeat_ngram_size=3,
-             length_penalty=1.5,
-             early_stopping=True,
-             split_length=50,
-             split_overlap=10,
-             use_gpu=True,
-             prompt="generate questions:",
-             num_queries_per_doc=1,
+def __init__(model_name_or_path: str = "valhalla/t5-base-e2e-qg",
+             model_version: Optional[str] = None,
+             num_beams: int = 4,
+             max_length: int = 256,
+             no_repeat_ngram_size: int = 3,
+             length_penalty: float = 1.5,
+             early_stopping: bool = True,
+             split_length: int = 50,
+             split_overlap: int = 10,
+             use_gpu: bool = True,
+             prompt: str = "generate questions:",
+             num_queries_per_doc: int = 1,
              sep_token: str = "<sep>",
              batch_size: int = 16,
              progress_bar: bool = True,
diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index ad6078caa..2cbee8048 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -2010,8 +2010,9 @@ The generated SPARQL query is executed on a knowledge graph.
 
 #### Text2SparqlRetriever.\_\_init\_\_
 
 ```python
-def __init__(knowledge_graph,
-             model_name_or_path,
+def __init__(knowledge_graph: BaseKnowledgeGraph,
+             model_name_or_path: Optional[str] = None,
+             model_version: Optional[str] = None,
              top_k: int = 1,
              use_auth_token: Optional[Union[str, bool]] = None)
@@ -2022,6 +2023,7 @@ Init the Retriever by providing a knowledge graph and a pre-trained BART model
 **Arguments**:
 
 - `knowledge_graph`: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
 - `model_name_or_path`: Name of or path to a pre-trained BartForConditionalGeneration model.
+- `model_version`: The version of the model to use.
 - `top_k`: How many SPARQL queries to generate per text query.
 - `use_auth_token`: The API token used to download private models from Huggingface.
 If this parameter is set to `True`, then the token generated when running
diff --git a/haystack/json-schemas/haystack-pipeline-main.schema.json b/haystack/json-schemas/haystack-pipeline-main.schema.json
index a6d96d5f7..90b14f50c 100644
--- a/haystack/json-schemas/haystack-pipeline-main.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-main.schema.json
@@ -3014,6 +3014,17 @@
           "default": "dslim/bert-base-NER",
           "type": "string"
         },
+        "model_version": {
+          "title": "Model Version",
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ]
+        },
         "use_gpu": {
           "title": "Use Gpu",
           "default": true,
@@ -4478,50 +4489,69 @@
       "properties": {
         "model_name_or_path": {
           "title": "Model Name Or Path",
-          "default": "valhalla/t5-base-e2e-qg"
+          "default": "valhalla/t5-base-e2e-qg",
+          "type": "string"
         },
         "model_version": {
-          "title": "Model Version"
+          "title": "Model Version",
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ]
         },
         "num_beams": {
           "title": "Num Beams",
-          "default": 4
+          "default": 4,
+          "type": "integer"
         },
         "max_length": {
           "title": "Max Length",
-          "default": 256
+          "default": 256,
+          "type": "integer"
         },
         "no_repeat_ngram_size": {
           "title": "No Repeat Ngram Size",
-          "default": 3
+          "default": 3,
+          "type": "integer"
         },
         "length_penalty": {
           "title": "Length Penalty",
-          "default": 1.5
+          "default": 1.5,
+          "type": "number"
         },
         "early_stopping": {
           "title": "Early Stopping",
-          "default": true
+          "default": true,
+          "type": "boolean"
         },
         "split_length": {
           "title": "Split Length",
-          "default": 50
+          "default": 50,
+          "type": "integer"
         },
         "split_overlap": {
           "title": "Split Overlap",
-          "default": 10
+          "default": 10,
+          "type": "integer"
         },
         "use_gpu": {
           "title": "Use Gpu",
-          "default": true
+          "default": true,
+          "type": "boolean"
         },
         "prompt": {
           "title": "Prompt",
-          "default": "generate questions:"
+          "default": "generate questions:",
+          "type": "string"
         },
         "num_queries_per_doc": {
           "title": "Num Queries Per Doc",
-          "default": 1
+          "default": 1,
+          "type": "integer"
         },
         "sep_token": {
           "title": "Sep Token",
@@ -5508,10 +5538,23 @@
       "type": "object",
       "properties": {
         "knowledge_graph": {
-          "title": "Knowledge Graph"
+          "title": "Knowledge Graph",
+          "type": "string"
         },
         "model_name_or_path": {
-          "title": "Model Name Or Path"
+          "title": "Model Name Or Path",
+          "type": "string"
+        },
+        "model_version": {
+          "title": "Model Version",
+          "anyOf": [
+            {
+              "type": "string"
+            },
+            {
+              "type": "null"
+            }
+          ]
         },
         "top_k": {
           "title": "Top K",
@@ -5534,8 +5577,7 @@
         }
       },
       "required": [
-        "knowledge_graph",
-        "model_name_or_path"
+        "knowledge_graph"
       ],
       "additionalProperties": false,
       "description": "Each parameter can reference other components defined in the same YAML file."
diff --git a/haystack/nodes/extractor/entity.py b/haystack/nodes/extractor/entity.py
index 1eb0033e2..a01a46627 100644
--- a/haystack/nodes/extractor/entity.py
+++ b/haystack/nodes/extractor/entity.py
@@ -25,6 +25,7 @@ class EntityExtractor(BaseComponent):
     The entities extracted by this Node will populate Document.entities
 
     :param model_name_or_path: The name of the model to use for entity extraction.
+    :param model_version: The version of the model to use for entity extraction.
     :param use_gpu: Whether to use the GPU or not.
     :param batch_size: The batch size to use for entity extraction.
     :param progress_bar: Whether to show a progress bar or not.
@@ -44,6 +45,7 @@ class EntityExtractor(BaseComponent):
     def __init__(
         self,
         model_name_or_path: str = "dslim/bert-base-NER",
+        model_version: Optional[str] = None,
         use_gpu: bool = True,
         batch_size: int = 16,
         progress_bar: bool = True,
@@ -58,7 +60,7 @@ class EntityExtractor(BaseComponent):
 
         tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
         token_classifier = AutoModelForTokenClassification.from_pretrained(
-            model_name_or_path, use_auth_token=use_auth_token
+            model_name_or_path, use_auth_token=use_auth_token, revision=model_version
         )
         token_classifier.to(str(self.devices[0]))
         self.model = pipeline(
diff --git a/haystack/nodes/question_generator/question_generator.py b/haystack/nodes/question_generator/question_generator.py
index 1704eca71..2ea61c058 100644
--- a/haystack/nodes/question_generator/question_generator.py
+++ b/haystack/nodes/question_generator/question_generator.py
@@ -31,18 +31,18 @@ class QuestionGenerator(BaseComponent):
 
     def __init__(
         self,
-        model_name_or_path="valhalla/t5-base-e2e-qg",
-        model_version=None,
-        num_beams=4,
-        max_length=256,
-        no_repeat_ngram_size=3,
-        length_penalty=1.5,
-        early_stopping=True,
-        split_length=50,
-        split_overlap=10,
-        use_gpu=True,
-        prompt="generate questions:",
-        num_queries_per_doc=1,
+        model_name_or_path: str = "valhalla/t5-base-e2e-qg",
+        model_version: Optional[str] = None,
+        num_beams: int = 4,
+        max_length: int = 256,
+        no_repeat_ngram_size: int = 3,
+        length_penalty: float = 1.5,
+        early_stopping: bool = True,
+        split_length: int = 50,
+        split_overlap: int = 10,
+        use_gpu: bool = True,
+        prompt: str = "generate questions:",
+        num_queries_per_doc: int = 1,
         sep_token: str = "<sep>",
         batch_size: int = 16,
         progress_bar: bool = True,
@@ -79,7 +79,9 @@ class QuestionGenerator(BaseComponent):
                 f"Multiple devices are not supported in {self.__class__.__name__} inference, "
                 f"using the first device {self.devices[0]}."
             )
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(
+            model_name_or_path, revision=model_version, use_auth_token=use_auth_token
+        )
         self.model.to(str(self.devices[0]))
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
         self.num_beams = num_beams
diff --git a/haystack/nodes/retriever/text2sparql.py b/haystack/nodes/retriever/text2sparql.py
index 6fdbdf47d..b81f9a353 100644
--- a/haystack/nodes/retriever/text2sparql.py
+++ b/haystack/nodes/retriever/text2sparql.py
@@ -2,6 +2,8 @@ from typing import Optional, List, Union
 import logging
 
 from transformers import BartForConditionalGeneration, BartTokenizer
+
+from haystack.document_stores import BaseKnowledgeGraph
 from haystack.nodes.retriever.base import BaseGraphRetriever
 
@@ -16,13 +18,19 @@ class Text2SparqlRetriever(BaseGraphRetriever):
     """
 
     def __init__(
-        self, knowledge_graph, model_name_or_path, top_k: int = 1, use_auth_token: Optional[Union[str, bool]] = None
+        self,
+        knowledge_graph: BaseKnowledgeGraph,
+        model_name_or_path: Optional[str] = None,
+        model_version: Optional[str] = None,
+        top_k: int = 1,
+        use_auth_token: Optional[Union[str, bool]] = None,
     ):
         """
         Init the Retriever by providing a knowledge graph and a pre-trained BART model
 
         :param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
         :param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
+        :param model_version: The version of the model to use.
         :param top_k: How many SPARQL queries to generate per text query.
         :param use_auth_token: The API token used to download private models from Huggingface.
                                If this parameter is set to `True`, then the token generated when running
@@ -35,7 +43,7 @@ class Text2SparqlRetriever(BaseGraphRetriever):
         self.knowledge_graph = knowledge_graph
         # TODO We should extend this to any seq2seq models and use the AutoModel class
         self.model = BartForConditionalGeneration.from_pretrained(
-            model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token
+            model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version
         )
         self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
         self.top_k = top_k
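
For reviewers, a minimal usage sketch of the `model_version` parameter this patch threads through the three nodes. It assumes this branch is installed; the revision value is illustrative (any branch name, tag, or commit hash on the Hugging Face Hub should work), and the `haystack.nodes` import path is taken from the package layout touched above.

```python
# Minimal sketch, assuming the changes in this diff are installed.
from haystack.nodes import EntityExtractor, QuestionGenerator

# Pin the NER model to a specific Hub revision (a branch name, tag,
# or commit hash). "main" is used here only as an illustrative value.
extractor = EntityExtractor(
    model_name_or_path="dslim/bert-base-NER",
    model_version="main",
)

# model_version defaults to None, which preserves the previous behaviour
# of loading the latest revision of the model.
generator = QuestionGenerator(model_name_or_path="valhalla/t5-base-e2e-qg")
```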