mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	Fix saving tokenizers in DPR training + unify save and load dirs (#682)
This commit is contained in:
		
							parent
							
								
									4c2804e38e
								
							
						
					
					
						commit
						a9bcabc42d
					
				| @ -240,9 +240,9 @@ class DensePassageRetriever(BaseRetriever): | ||||
|               grad_acc_steps: int = 1, | ||||
|               optimizer_name: str = "TransformersAdamW", | ||||
|               optimizer_correct_bias: bool = True, | ||||
|               save_dir: str = "../saved_models/dpr-tutorial", | ||||
|               query_encoder_save_dir: str = "lm1", | ||||
|               passage_encoder_save_dir: str = "lm2" | ||||
|               save_dir: str = "../saved_models/dpr", | ||||
|               query_encoder_save_dir: str = "query_encoder", | ||||
|               passage_encoder_save_dir: str = "passage_encoder" | ||||
|               ): | ||||
|         """ | ||||
|         train a DensePassageRetrieval model | ||||
| @ -317,20 +317,24 @@ class DensePassageRetriever(BaseRetriever): | ||||
|         trainer.train() | ||||
| 
 | ||||
|         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir) | ||||
|         self.processor.save(Path(save_dir)) | ||||
|         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}") | ||||
|         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}") | ||||
| 
 | ||||
|     def save(self, save_dir: Union[Path, str]): | ||||
|     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder", | ||||
|              passage_encoder_dir: str = "passage_encoder"): | ||||
|         """ | ||||
|         Save DensePassageRetriever to the specified directory. | ||||
| 
 | ||||
|         :param save_dir: Directory to save to. | ||||
|         :param query_encoder_dir: Directory in save_dir that contains query encoder model. | ||||
|         :param passage_encoder_dir: Directory in save_dir that contains passage encoder model. | ||||
|         :return: None | ||||
|         """ | ||||
|         save_dir = Path(save_dir) | ||||
|         self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder") | ||||
|         self.model.save(save_dir, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir) | ||||
|         save_dir = str(save_dir) | ||||
|         self.query_tokenizer.save_pretrained(save_dir + "/query_encoder") | ||||
|         self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder") | ||||
|         self.query_tokenizer.save_pretrained(save_dir + f"/{query_encoder_dir}") | ||||
|         self.passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}") | ||||
| 
 | ||||
|     @classmethod | ||||
|     def load(cls, | ||||
| @ -343,6 +347,8 @@ class DensePassageRetriever(BaseRetriever): | ||||
|              embed_title: bool = True, | ||||
|              use_fast_tokenizers: bool = True, | ||||
|              similarity_function: str = "dot_product", | ||||
|              query_encoder_dir: str = "query_encoder", | ||||
|              passage_encoder_dir: str = "passage_encoder" | ||||
|              ): | ||||
|         """ | ||||
|         Load DensePassageRetriever from the specified directory. | ||||
| @ -351,8 +357,8 @@ class DensePassageRetriever(BaseRetriever): | ||||
|         load_dir = Path(load_dir) | ||||
|         dpr = cls( | ||||
|             document_store=document_store, | ||||
|             query_embedding_model=Path(load_dir) / "query_encoder", | ||||
|             passage_embedding_model=Path(load_dir) / "passage_encoder", | ||||
|             query_embedding_model=Path(load_dir) / query_encoder_dir, | ||||
|             passage_embedding_model=Path(load_dir) / passage_encoder_dir, | ||||
|             max_seq_len_query=max_seq_len_query, | ||||
|             max_seq_len_passage=max_seq_len_passage, | ||||
|             use_gpu=use_gpu, | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 bogdankostic
						bogdankostic