Enable GPU usage for QuestionGenerator (#1571)

* enable GPU usage for question generator

* Add latest docstring and tutorial changes

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Malte Pietsch 2021-10-08 12:17:48 +02:00 committed by GitHub
parent 54947cb840
commit 38652dd4dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 4 deletions

View File

@@ -766,7 +766,7 @@ class SQLDocumentStore(BaseDocumentStore)
#### \_\_init\_\_
```python
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite")
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
```
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -783,6 +783,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
<a name="sql.SQLDocumentStore.get_document_by_id"></a>
#### get\_document\_by\_id

View File

@@ -20,7 +20,7 @@ come from earlier in the document.
#### \_\_init\_\_
```python
| __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, prompt="generate questions:")
| __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:")
```
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is

View File

@@ -2,6 +2,8 @@ from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from haystack import BaseComponent, Document
from haystack.preprocessor import PreProcessor
from haystack.modeling.utils import initialize_device_settings
from typing import List
@@ -26,6 +28,7 @@ class QuestionGenerator(BaseComponent):
early_stopping=True,
split_length=50,
split_overlap=10,
use_gpu=True,
prompt="generate questions:"):
"""
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
@@ -34,6 +37,8 @@ class QuestionGenerator(BaseComponent):
generation is not currently supported.
"""
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
self.device, _ = initialize_device_settings(use_cuda=use_gpu)
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.set_config(
model_name_or_path=model_name_or_path, model_version=model_version,
@@ -78,8 +83,8 @@ class QuestionGenerator(BaseComponent):
if self.prompt not in split_text:
split_text = self.prompt + " " + split_text
tokenized = self.tokenizer([split_text], return_tensors="pt")
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"] # necessary if padding is enabled so the model won't attend pad tokens
input_ids = tokenized["input_ids"].to(self.device)
attention_mask = tokenized["attention_mask"].to(self.device) # necessary if padding is enabled so the model won't attend pad tokens
tokens_output = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,