add gpu support for rag (#669)

* add gpu support for rag

* Update transformers.py
Author: demSd (committed via GitHub)
Date:   2020-12-11 12:08:01 +01:00
Parent: 149d98a0fd
Commit: a0e146dde6

transformers.py

@@ -109,7 +109,6 @@ class RAGenerator(BaseGenerator):
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
-            raise AttributeError("Currently RAGenerator does not support GPU, try with use_gpu=False")
         else:
             self.device = torch.device("cpu")
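With the hard error removed, device selection follows the standard torch idiom: CUDA when requested and available, CPU otherwise. A minimal usage sketch; the import path and default model name are assumed from this file and its era of Haystack, not stated by the diff:

```
from haystack.generator.transformers import RAGenerator

# Uses the GPU when torch can see one; otherwise falls back to CPU without raising
generator = RAGenerator(model_name_or_path="facebook/rag-token-nq", use_gpu=True)
```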
@@ -121,7 +120,7 @@ class RAGenerator(BaseGenerator):
             # Also refer https://github.com/huggingface/transformers/issues/7829
             # self.model = RagSequenceForGeneration.from_pretrained(model_name_or_path)
         else:
-            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path)
+            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path).to(self.device)

     # Copied cat_input_and_doc method from transformers.RagRetriever
     # Refer to section 2.3 of https://arxiv.org/abs/2005.11401
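Chaining `.to(self.device)` onto `from_pretrained` moves every parameter and buffer of the freshly loaded model to the chosen device once, at construction time. The idiom in isolation (model name illustrative):

```
import torch
from transformers import RagTokenForGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weights are loaded onto CPU first, then transferred to the target device in one call
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq").to(device)
print(next(model.parameters()).device)
```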
@@ -179,9 +178,9 @@ class RAGenerator(BaseGenerator):
         embeddings_in_tensor = torch.cat(
             [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
             dim=0
-        ).to(self.device)
+        )

-        return embeddings_in_tensor
+        return embeddings_in_tensor.to(self.device)

     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
@@ -206,6 +205,7 @@ class RAGenerator(BaseGenerator):
         | }}]}
         ```
         """
+        torch.set_grad_enabled(False)
         if len(documents) == 0:
             raise AttributeError("generator need documents to predict the answer")
@@ -213,7 +213,7 @@ class RAGenerator(BaseGenerator):
         if top_k_answers > self.num_beams:
             top_k_answers = self.num_beams
-            logger.warning(f'top_k_answers value should not be greater than num_beams, '
+            logger.warning(f'top_k value should not be greater than num_beams, '
                            f'hence setting it to {top_k_answers}')

         # Flatten the documents so they are easy to reference
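The clamp exists because beam search can return at most `num_beams` distinct sequences; transformers' `generate()` rejects a `num_return_sequences` greater than `num_beams`. The logic in isolation (values illustrative):

```
# Beam search cannot return more sequences than it keeps beams
top_k_answers, num_beams = 7, 4
if top_k_answers > num_beams:
    top_k_answers = num_beams  # clamp, mirroring the warning above
```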
@@ -235,9 +235,9 @@ class RAGenerator(BaseGenerator):
             src_texts=[query],
             return_tensors="pt"
         )
+        input_ids = input_dict['input_ids'].to(self.device)

         # Query embedding
-        query_embedding = self.model.question_encoder(input_dict["input_ids"])[0]
+        query_embedding = self.model.question_encoder(input_ids)[0]

         # Prepare contextualized input_ids of documents
         # (will be transformed into contextualized inputs inside generator)
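The tokenizer hands back CPU tensors regardless of where the model lives, so the ids are moved once and reused for both the question encoder and the later `generate()` call; feeding CPU tensors to CUDA weights raises a device-mismatch RuntimeError. A minimal illustration of the pattern with a stand-in model:

```
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 2).to(device)   # stand-in for the question encoder

inputs = torch.randn(1, 8)                 # tokenizers likewise return CPU tensors
outputs = model(inputs.to(device))         # move inputs to the weights' device first
print(outputs.device)
```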
@@ -262,7 +262,7 @@ class RAGenerator(BaseGenerator):
             # TODO: Need transformers 3.4.0
             # Refer https://github.com/huggingface/transformers/issues/7871
             # Remove input_ids parameter once upgraded to 3.4.0
-            input_ids=input_dict["input_ids"],
+            input_ids=input_ids,
             context_input_ids=context_input_ids,
             context_attention_mask=context_attention_mask,
             doc_scores=doc_scores,