mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
Enable GPU usage for QuestionGenerator (#1571)
* enable GPU usage for question generator
* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
54947cb840
commit
38652dd4dd
@ -766,7 +766,7 @@ class SQLDocumentStore(BaseDocumentStore)
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite")
|
||||
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
|
||||
```
|
||||
|
||||
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
|
||||
@ -783,6 +783,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
|
||||
|
||||
<a name="sql.SQLDocumentStore.get_document_by_id"></a>
|
||||
#### get\_document\_by\_id
|
||||
|
||||
@ -20,7 +20,7 @@ come from earlier in the document.
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, prompt="generate questions:")
|
||||
| __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:")
|
||||
```
|
||||
|
||||
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
|
||||
|
||||
@ -2,6 +2,8 @@ from transformers import AutoModelForSeq2SeqLM
|
||||
from transformers import AutoTokenizer
|
||||
from haystack import BaseComponent, Document
|
||||
from haystack.preprocessor import PreProcessor
|
||||
from haystack.modeling.utils import initialize_device_settings
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
@ -26,6 +28,7 @@ class QuestionGenerator(BaseComponent):
|
||||
early_stopping=True,
|
||||
split_length=50,
|
||||
split_overlap=10,
|
||||
use_gpu=True,
|
||||
prompt="generate questions:"):
|
||||
"""
|
||||
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
|
||||
@ -34,6 +37,8 @@ class QuestionGenerator(BaseComponent):
|
||||
generation is not currently supported.
|
||||
"""
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
|
||||
self.device, _ = initialize_device_settings(use_cuda=use_gpu)
|
||||
self.model.to(self.device)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
self.set_config(
|
||||
model_name_or_path=model_name_or_path, model_version=model_version,
|
||||
@ -78,8 +83,8 @@ class QuestionGenerator(BaseComponent):
|
||||
if self.prompt not in split_text:
|
||||
split_text = self.prompt + " " + split_text
|
||||
tokenized = self.tokenizer([split_text], return_tensors="pt")
|
||||
input_ids = tokenized["input_ids"]
|
||||
attention_mask = tokenized["attention_mask"] # necessary if padding is enabled so the model won't attend pad tokens
|
||||
input_ids = tokenized["input_ids"].to(self.device)
|
||||
attention_mask = tokenized["attention_mask"].to(self.device) # necessary if padding is enabled so the model won't attend pad tokens
|
||||
tokens_output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user