mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 19:47:45 +00:00
Type all parameter constructors, add model_version optional parameter where applicable (#3152)
This commit is contained in:
parent
20c2320434
commit
84acb6584f
@ -21,6 +21,7 @@ The entities extracted by this Node will populate Document.entities
|
||||
**Arguments**:
|
||||
|
||||
- `model_name_or_path`: The name of the model to use for entity extraction.
|
||||
- `model_version`: The version of the model to use for entity extraction.
|
||||
- `use_gpu`: Whether to use the GPU or not.
|
||||
- `batch_size`: The batch size to use for entity extraction.
|
||||
- `progress_bar`: Whether to show a progress bar or not.
|
||||
|
||||
@ -23,18 +23,18 @@ come from earlier in the document.
|
||||
#### QuestionGenerator.\_\_init\_\_
|
||||
|
||||
```python
|
||||
def __init__(model_name_or_path="valhalla/t5-base-e2e-qg",
|
||||
model_version=None,
|
||||
num_beams=4,
|
||||
max_length=256,
|
||||
no_repeat_ngram_size=3,
|
||||
length_penalty=1.5,
|
||||
early_stopping=True,
|
||||
split_length=50,
|
||||
split_overlap=10,
|
||||
use_gpu=True,
|
||||
prompt="generate questions:",
|
||||
num_queries_per_doc=1,
|
||||
def __init__(model_name_or_path: str = "valhalla/t5-base-e2e-qg",
|
||||
model_version: Optional[str] = None,
|
||||
num_beams: int = 4,
|
||||
max_length: int = 256,
|
||||
no_repeat_ngram_size: int = 3,
|
||||
length_penalty: float = 1.5,
|
||||
early_stopping: bool = True,
|
||||
split_length: int = 50,
|
||||
split_overlap: int = 10,
|
||||
use_gpu: bool = True,
|
||||
prompt: str = "generate questions:",
|
||||
num_queries_per_doc: int = 1,
|
||||
sep_token: str = "<sep>",
|
||||
batch_size: int = 16,
|
||||
progress_bar: bool = True,
|
||||
|
||||
@ -2010,8 +2010,9 @@ The generated SPARQL query is executed on a knowledge graph.
|
||||
#### Text2SparqlRetriever.\_\_init\_\_
|
||||
|
||||
```python
|
||||
def __init__(knowledge_graph,
|
||||
model_name_or_path,
|
||||
def __init__(knowledge_graph: BaseKnowledgeGraph,
|
||||
model_name_or_path: str = None,
|
||||
model_version: Optional[str] = None,
|
||||
top_k: int = 1,
|
||||
use_auth_token: Optional[Union[str, bool]] = None)
|
||||
```
|
||||
@ -2022,6 +2023,7 @@ Init the Retriever by providing a knowledge graph and a pre-trained BART model
|
||||
|
||||
- `knowledge_graph`: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
|
||||
- `model_name_or_path`: Name of or path to a pre-trained BartForConditionalGeneration model.
|
||||
- `model_version`: The version of the model to use for translating text queries to SPARQL.
|
||||
- `top_k`: How many SPARQL queries to generate per text query.
|
||||
- `use_auth_token`: The API token used to download private models from Huggingface.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
|
||||
@ -3014,6 +3014,17 @@
|
||||
"default": "dslim/bert-base-NER",
|
||||
"type": "string"
|
||||
},
|
||||
"model_version": {
|
||||
"title": "Model Version",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"use_gpu": {
|
||||
"title": "Use Gpu",
|
||||
"default": true,
|
||||
@ -4478,50 +4489,69 @@
|
||||
"properties": {
|
||||
"model_name_or_path": {
|
||||
"title": "Model Name Or Path",
|
||||
"default": "valhalla/t5-base-e2e-qg"
|
||||
"default": "valhalla/t5-base-e2e-qg",
|
||||
"type": "string"
|
||||
},
|
||||
"model_version": {
|
||||
"title": "Model Version"
|
||||
"title": "Model Version",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"num_beams": {
|
||||
"title": "Num Beams",
|
||||
"default": 4
|
||||
"default": 4,
|
||||
"type": "integer"
|
||||
},
|
||||
"max_length": {
|
||||
"title": "Max Length",
|
||||
"default": 256
|
||||
"default": 256,
|
||||
"type": "integer"
|
||||
},
|
||||
"no_repeat_ngram_size": {
|
||||
"title": "No Repeat Ngram Size",
|
||||
"default": 3
|
||||
"default": 3,
|
||||
"type": "integer"
|
||||
},
|
||||
"length_penalty": {
|
||||
"title": "Length Penalty",
|
||||
"default": 1.5
|
||||
"default": 1.5,
|
||||
"type": "number"
|
||||
},
|
||||
"early_stopping": {
|
||||
"title": "Early Stopping",
|
||||
"default": true
|
||||
"default": true,
|
||||
"type": "boolean"
|
||||
},
|
||||
"split_length": {
|
||||
"title": "Split Length",
|
||||
"default": 50
|
||||
"default": 50,
|
||||
"type": "integer"
|
||||
},
|
||||
"split_overlap": {
|
||||
"title": "Split Overlap",
|
||||
"default": 10
|
||||
"default": 10,
|
||||
"type": "integer"
|
||||
},
|
||||
"use_gpu": {
|
||||
"title": "Use Gpu",
|
||||
"default": true
|
||||
"default": true,
|
||||
"type": "boolean"
|
||||
},
|
||||
"prompt": {
|
||||
"title": "Prompt",
|
||||
"default": "generate questions:"
|
||||
"default": "generate questions:",
|
||||
"type": "string"
|
||||
},
|
||||
"num_queries_per_doc": {
|
||||
"title": "Num Queries Per Doc",
|
||||
"default": 1
|
||||
"default": 1,
|
||||
"type": "integer"
|
||||
},
|
||||
"sep_token": {
|
||||
"title": "Sep Token",
|
||||
@ -5508,10 +5538,23 @@
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"knowledge_graph": {
|
||||
"title": "Knowledge Graph"
|
||||
"title": "Knowledge Graph",
|
||||
"type": "string"
|
||||
},
|
||||
"model_name_or_path": {
|
||||
"title": "Model Name Or Path"
|
||||
"title": "Model Name Or Path",
|
||||
"type": "string"
|
||||
},
|
||||
"model_version": {
|
||||
"title": "Model Version",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
]
|
||||
},
|
||||
"top_k": {
|
||||
"title": "Top K",
|
||||
@ -5534,8 +5577,7 @@
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"knowledge_graph",
|
||||
"model_name_or_path"
|
||||
"knowledge_graph"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"description": "Each parameter can reference other components defined in the same YAML file."
|
||||
|
||||
@ -25,6 +25,7 @@ class EntityExtractor(BaseComponent):
|
||||
The entities extracted by this Node will populate Document.entities
|
||||
|
||||
:param model_name_or_path: The name of the model to use for entity extraction.
|
||||
:param model_version: The version of the model to use for entity extraction.
|
||||
:param use_gpu: Whether to use the GPU or not.
|
||||
:param batch_size: The batch size to use for entity extraction.
|
||||
:param progress_bar: Whether to show a progress bar or not.
|
||||
@ -44,6 +45,7 @@ class EntityExtractor(BaseComponent):
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str = "dslim/bert-base-NER",
|
||||
model_version: Optional[str] = None,
|
||||
use_gpu: bool = True,
|
||||
batch_size: int = 16,
|
||||
progress_bar: bool = True,
|
||||
@ -58,7 +60,7 @@ class EntityExtractor(BaseComponent):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
|
||||
token_classifier = AutoModelForTokenClassification.from_pretrained(
|
||||
model_name_or_path, use_auth_token=use_auth_token
|
||||
model_name_or_path, use_auth_token=use_auth_token, revision=model_version
|
||||
)
|
||||
token_classifier.to(str(self.devices[0]))
|
||||
self.model = pipeline(
|
||||
|
||||
@ -31,18 +31,18 @@ class QuestionGenerator(BaseComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path="valhalla/t5-base-e2e-qg",
|
||||
model_version=None,
|
||||
num_beams=4,
|
||||
max_length=256,
|
||||
no_repeat_ngram_size=3,
|
||||
length_penalty=1.5,
|
||||
early_stopping=True,
|
||||
split_length=50,
|
||||
split_overlap=10,
|
||||
use_gpu=True,
|
||||
prompt="generate questions:",
|
||||
num_queries_per_doc=1,
|
||||
model_name_or_path: str = "valhalla/t5-base-e2e-qg",
|
||||
model_version: Optional[str] = None,
|
||||
num_beams: int = 4,
|
||||
max_length: int = 256,
|
||||
no_repeat_ngram_size: int = 3,
|
||||
length_penalty: float = 1.5,
|
||||
early_stopping: bool = True,
|
||||
split_length: int = 50,
|
||||
split_overlap: int = 10,
|
||||
use_gpu: bool = True,
|
||||
prompt: str = "generate questions:",
|
||||
num_queries_per_doc: int = 1,
|
||||
sep_token: str = "<sep>",
|
||||
batch_size: int = 16,
|
||||
progress_bar: bool = True,
|
||||
@ -79,7 +79,9 @@ class QuestionGenerator(BaseComponent):
|
||||
f"Multiple devices are not supported in {self.__class__.__name__} inference, "
|
||||
f"using the first device {self.devices[0]}."
|
||||
)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_name_or_path, revision=model_version, use_auth_token=use_auth_token
|
||||
)
|
||||
self.model.to(str(self.devices[0]))
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
|
||||
self.num_beams = num_beams
|
||||
|
||||
@ -2,6 +2,8 @@ from typing import Optional, List, Union
|
||||
|
||||
import logging
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
from haystack.document_stores import BaseKnowledgeGraph
|
||||
from haystack.nodes.retriever.base import BaseGraphRetriever
|
||||
|
||||
|
||||
@ -16,13 +18,19 @@ class Text2SparqlRetriever(BaseGraphRetriever):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, knowledge_graph, model_name_or_path, top_k: int = 1, use_auth_token: Optional[Union[str, bool]] = None
|
||||
self,
|
||||
knowledge_graph: BaseKnowledgeGraph,
|
||||
model_name_or_path: str = None,
|
||||
model_version: Optional[str] = None,
|
||||
top_k: int = 1,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
):
|
||||
"""
|
||||
Init the Retriever by providing a knowledge graph and a pre-trained BART model
|
||||
|
||||
:param knowledge_graph: An instance of BaseKnowledgeGraph on which to execute SPARQL queries.
|
||||
:param model_name_or_path: Name of or path to a pre-trained BartForConditionalGeneration model.
|
||||
:param model_version: The version of the model to use for translating text queries to SPARQL.
|
||||
:param top_k: How many SPARQL queries to generate per text query.
|
||||
:param use_auth_token: The API token used to download private models from Huggingface.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
@ -35,7 +43,7 @@ class Text2SparqlRetriever(BaseGraphRetriever):
|
||||
self.knowledge_graph = knowledge_graph
|
||||
# TODO We should extend this to any seq2seq models and use the AutoModel class
|
||||
self.model = BartForConditionalGeneration.from_pretrained(
|
||||
model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token
|
||||
model_name_or_path, forced_bos_token_id=0, use_auth_token=use_auth_token, revision=model_version
|
||||
)
|
||||
self.tok = BartTokenizer.from_pretrained(model_name_or_path, use_auth_token=use_auth_token)
|
||||
self.top_k = top_k
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user