diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index aad273977..f60d677d2 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -766,7 +766,7 @@ class SQLDocumentStore(BaseDocumentStore) #### \_\_init\_\_ ```python - | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite") + | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False) ``` An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends. @@ -783,6 +783,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac overwrite: Update any existing documents with the same ID when adding documents. fail: an error is raised if the document ID of the document being added already exists. +- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior) #### get\_document\_by\_id diff --git a/docs/_src/api/api/question_generator.md b/docs/_src/api/api/question_generator.md index 92bebd0cc..b9c15f6b9 100644 --- a/docs/_src/api/api/question_generator.md +++ b/docs/_src/api/api/question_generator.md @@ -20,7 +20,7 @@ come from earlier in the document. #### \_\_init\_\_ ```python - | __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, prompt="generate questions:") + | __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:") ``` Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is diff --git a/haystack/question_generator/question_generator.py b/haystack/question_generator/question_generator.py index 4cb0cb9a1..7b567ec22 100644 --- a/haystack/question_generator/question_generator.py +++ b/haystack/question_generator/question_generator.py @@ -2,6 +2,8 @@ from transformers import AutoModelForSeq2SeqLM from transformers import AutoTokenizer from haystack import BaseComponent, Document from haystack.preprocessor import PreProcessor +from haystack.modeling.utils import initialize_device_settings + from typing import List @@ -26,6 +28,7 @@ class QuestionGenerator(BaseComponent): early_stopping=True, split_length=50, split_overlap=10, + use_gpu=True, prompt="generate questions:"): """ Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is @@ -34,6 +37,8 @@ class QuestionGenerator(BaseComponent): generation is not currently supported. """ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) + self.device, _ = initialize_device_settings(use_cuda=use_gpu) + self.model.to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.set_config( model_name_or_path=model_name_or_path, model_version=model_version, @@ -78,8 +83,8 @@ class QuestionGenerator(BaseComponent): if self.prompt not in split_text: split_text = self.prompt + " " + split_text tokenized = self.tokenizer([split_text], return_tensors="pt") - input_ids = tokenized["input_ids"] - attention_mask = tokenized["attention_mask"] # necessary if padding is enabled so the model won't attend pad tokens + input_ids = tokenized["input_ids"].to(self.device) + attention_mask = tokenized["attention_mask"].to(self.device) # necessary if padding is enabled so the model won't attend pad tokens tokens_output = self.model.generate( input_ids=input_ids, attention_mask=attention_mask,