diff --git a/haystack/generator/transformers.py b/haystack/generator/transformers.py
index c5d26f08f..fd1595457 100644
--- a/haystack/generator/transformers.py
+++ b/haystack/generator/transformers.py
@@ -109,7 +109,6 @@ class RAGenerator(BaseGenerator):
 
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
-            raise AttributeError("Currently RAGenerator does not support GPU, try with use_gpu=False")
         else:
             self.device = torch.device("cpu")
 
@@ -121,7 +120,7 @@ class RAGenerator(BaseGenerator):
             # Also refer refer https://github.com/huggingface/transformers/issues/7829
             # self.model = RagSequenceForGeneration.from_pretrained(model_name_or_path)
         else:
-            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path)
+            self.model = RagTokenForGeneration.from_pretrained(model_name_or_path).to(self.device)
 
     # Copied cat_input_and_doc method from transformers.RagRetriever
     # Refer section 2.3 of https://arxiv.org/abs/2005.11401
@@ -179,9 +178,9 @@ class RAGenerator(BaseGenerator):
         embeddings_in_tensor = torch.cat(
             [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
             dim=0
-        ).to(self.device)
+        )
 
-        return embeddings_in_tensor
+        return embeddings_in_tensor.to(self.device)
 
     def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
         """
@@ -206,6 +205,7 @@ class RAGenerator(BaseGenerator):
         |     }}]}
         ```
         """
+        torch.set_grad_enabled(False)
         if len(documents) == 0:
             raise AttributeError("generator need documents to predict the answer")
 
@@ -213,7 +213,7 @@ class RAGenerator(BaseGenerator):
 
         if top_k_answers > self.num_beams:
             top_k_answers = self.num_beams
-            logger.warning(f'top_k_answers value should not be greater than num_beams, '
+            logger.warning(f'top_k value should not be greater than num_beams, '
                            f'hence setting it to {top_k_answers}')
 
         # Flatten the documents so easy to reference
@@ -235,9 +235,9 @@ class RAGenerator(BaseGenerator):
             src_texts=[query],
             return_tensors="pt"
         )
-
+        input_ids = input_dict['input_ids'].to(self.device)
         # Query embedding
-        query_embedding = self.model.question_encoder(input_dict["input_ids"])[0]
+        query_embedding = self.model.question_encoder(input_ids)[0]
 
         # Prepare contextualized input_ids of documents
         # (will be transformed into contextualized inputs inside generator)
@@ -262,7 +262,7 @@ class RAGenerator(BaseGenerator):
             # TODO: Need transformers 3.4.0
             # Refer https://github.com/huggingface/transformers/issues/7871
             # Remove input_ids parameter once upgraded to 3.4.0
-            input_ids=input_dict["input_ids"],
+            input_ids=input_ids,
             context_input_ids=context_input_ids,
             context_attention_mask=context_attention_mask,
             doc_scores=doc_scores,
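
Taken together, these hunks enable GPU inference for `RAGenerator`: the model is moved to `self.device` at load time, the query `input_ids` and document embeddings are moved onto the same device before the forward pass, and `torch.set_grad_enabled(False)` avoids building autograd state during `predict`. Below is a minimal sketch of the path this patch enables; the import locations, the `facebook/rag-token-nq` checkpoint, and the random placeholder embedding (real usage needs embeddings from the matching DPR context encoder) are illustrative assumptions, not taken from the diff.

```python
# Minimal usage sketch of the GPU path enabled by this patch -- not part of
# the diff. Import paths and the placeholder embedding are assumptions.
import numpy as np

from haystack import Document
from haystack.generator.transformers import RAGenerator

# Before this patch, use_gpu=True raised AttributeError; now it selects CUDA
# when torch.cuda.is_available() and falls back to CPU otherwise.
generator = RAGenerator(
    model_name_or_path="facebook/rag-token-nq",
    use_gpu=True,
)

docs = [
    Document(
        text="Python is a programming language created by Guido van Rossum.",
        # Placeholder only: real documents need embeddings produced by the
        # DPR context encoder that matches the RAG checkpoint.
        embedding=np.random.rand(768).astype(np.float32),
    )
]

result = generator.predict(query="Who created Python?", documents=docs, top_k=1)
print(result["answers"])
```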