Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-06 03:57:19 +00:00)
add gpu support for rag (#669)

* add gpu support for rag
* Update transformers.py

This commit is contained in:
parent 149d98a0fd
commit a0e146dde6
```diff
@@ -109,7 +109,6 @@ class RAGenerator(BaseGenerator):
 
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
-            raise AttributeError("Currently RAGenerator does not support GPU, try with use_gpu=False")
         else:
             self.device = torch.device("cpu")
 
```
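With the AttributeError removed, `use_gpu=True` now selects CUDA when it is available and silently falls back to CPU otherwise. A minimal sketch of the same selection pattern, written as a hypothetical standalone helper rather than the actual `__init__` code:

```python
import torch

def resolve_device(use_gpu: bool) -> torch.device:
    # Mirror RAGenerator: prefer CUDA only when requested AND available,
    # otherwise fall back to CPU instead of raising.
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
```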
```diff
@@ -121,7 +120,7 @@ class RAGenerator(BaseGenerator):
             # Also refer https://github.com/huggingface/transformers/issues/7829
             # self.model = RagSequenceForGeneration.from_pretrained(model_name_or_path)
         else:
-            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path)
+            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path).to(self.device)
 
     # Copied cat_input_and_doc method from transformers.RagRetriever
     # Refer section 2.3 of https://arxiv.org/abs/2005.11401
```
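Moving the model immediately after `from_pretrained` keeps all of its weights on `self.device`. A minimal sketch, assuming the `facebook/rag-token-nq` checkpoint that Haystack uses as the default for `RAGenerator`:

```python
import torch
from transformers import RagTokenForGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# from_pretrained loads on CPU; .to(device) then moves every parameter in place.
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq").to(device)
```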
```diff
@@ -179,9 +178,9 @@ class RAGenerator(BaseGenerator):
         embeddings_in_tensor = torch.cat(
             [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
             dim=0
-        ).to(self.device)
+        )
 
-        return embeddings_in_tensor
+        return embeddings_in_tensor.to(self.device)
 
     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
```
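Document embeddings arrive from the retriever as NumPy arrays; stacking them first and transferring the result once avoids a per-document host-to-device copy. A sketch of the pattern with made-up embedding data:

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embeddings = [np.random.rand(768).astype(np.float32) for _ in range(4)]  # dummy data

# Stack the per-document vectors into one (num_docs, dim) tensor,
# then move it to the device in a single transfer.
embeddings_in_tensor = torch.cat(
    [torch.from_numpy(e).unsqueeze(0) for e in embeddings],
    dim=0,
).to(device)
print(embeddings_in_tensor.shape)  # torch.Size([4, 768])
```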
````diff
@@ -206,6 +205,7 @@ class RAGenerator(BaseGenerator):
             }}]}
         ```
         """
+        torch.set_grad_enabled(False)
         if len(documents) == 0:
             raise AttributeError("generator needs documents to predict the answer")
 
````
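`predict()` only runs forward passes, so the commit disables autograd globally before any tensor work, saving memory on both CPU and GPU. A short sketch of the effect, alongside the more scoped `torch.no_grad()` alternative:

```python
import torch

torch.set_grad_enabled(False)   # global switch: ops after this build no graph
x = torch.ones(2, requires_grad=True)
print((x * 2).requires_grad)    # False: no gradient graph is recorded

# Scoped alternative that restores the previous setting on exit:
with torch.no_grad():
    y = x * 3
```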
```diff
@@ -213,7 +213,7 @@ class RAGenerator(BaseGenerator):
 
         if top_k_answers > self.num_beams:
             top_k_answers = self.num_beams
-            logger.warning(f'top_k_answers value should not be greater than num_beams, '
+            logger.warning(f'top_k value should not be greater than num_beams, '
                            f'hence setting it to {top_k_answers}')
 
         # Flatten the documents so they are easy to reference
```
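Beam search can return at most `num_beams` sequences, so an over-large `top_k` is clamped with a warning instead of failing later in `generate()`. A minimal sketch of that guard as a hypothetical helper:

```python
import logging

logger = logging.getLogger(__name__)

def clamp_top_k(top_k: int, num_beams: int) -> int:
    # generate() cannot return more sequences than it keeps beams.
    if top_k > num_beams:
        logger.warning("top_k value should not be greater than num_beams, "
                       "hence setting it to %s", num_beams)
        return num_beams
    return top_k
```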
```diff
@@ -235,9 +235,9 @@ class RAGenerator(BaseGenerator):
             src_texts=[query],
             return_tensors="pt"
         )
 
+        input_ids = input_dict['input_ids'].to(self.device)
         # Query embedding
-        query_embedding = self.model.question_encoder(input_dict["input_ids"])[0]
+        query_embedding = self.model.question_encoder(input_ids)[0]
 
         # Prepare contextualized input_ids of documents
         # (will be transformed into contextualized inputs inside generator)
```
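The tokenizer returns CPU tensors, so `input_ids` is moved to the device once and then reused both for the question encoder and for `generate()`. A sketch assuming the `RagTokenizer` and the transformers-3.x `prepare_seq2seq_batch` API that this diff targets:

```python
import torch
from transformers import RagTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# prepare_seq2seq_batch returns CPU tensors; transfer before encoding.
input_dict = tokenizer.prepare_seq2seq_batch(src_texts=["who wrote hamlet"],
                                             return_tensors="pt")
input_ids = input_dict["input_ids"].to(device)
```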
```diff
@@ -262,7 +262,7 @@ class RAGenerator(BaseGenerator):
             # TODO: Need transformers 3.4.0
             # Refer https://github.com/huggingface/transformers/issues/7871
             # Remove input_ids parameter once upgraded to 3.4.0
-            input_ids=input_dict["input_ids"],
+            input_ids=input_ids,
             context_input_ids=context_input_ids,
             context_attention_mask=context_attention_mask,
             doc_scores=doc_scores,
```
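These keyword arguments belong to the `model.generate()` call, which consumes pre-computed contextualized inputs instead of invoking a built-in retriever. A hedged sketch of such a call; `context_input_ids`, `context_attention_mask`, `doc_scores`, `top_k`, and `num_beams` are assumed to have been prepared as in the surrounding method:

```python
# Assumes model, tokenizer, input_ids and the context tensors from above.
generated_ids = model.generate(
    input_ids=input_ids,                    # dropped once transformers >= 3.4.0
    context_input_ids=context_input_ids,    # query + doc pairs, (n_docs, seq_len)
    context_attention_mask=context_attention_mask,
    doc_scores=doc_scores,                  # retrieval scores, (1, n_docs)
    num_return_sequences=top_k,
    num_beams=num_beams,
)
answers = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```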