mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	Fix saving tokenizers in DPR training + unify save and load dirs (#682)
This commit is contained in:
		
							parent
							
								
									4c2804e38e
								
							
						
					
					
						commit
						a9bcabc42d
					
				| @ -240,9 +240,9 @@ class DensePassageRetriever(BaseRetriever): | |||||||
|               grad_acc_steps: int = 1, |               grad_acc_steps: int = 1, | ||||||
|               optimizer_name: str = "TransformersAdamW", |               optimizer_name: str = "TransformersAdamW", | ||||||
|               optimizer_correct_bias: bool = True, |               optimizer_correct_bias: bool = True, | ||||||
|               save_dir: str = "../saved_models/dpr-tutorial", |               save_dir: str = "../saved_models/dpr", | ||||||
|               query_encoder_save_dir: str = "lm1", |               query_encoder_save_dir: str = "query_encoder", | ||||||
|               passage_encoder_save_dir: str = "lm2" |               passage_encoder_save_dir: str = "passage_encoder" | ||||||
|               ): |               ): | ||||||
|         """ |         """ | ||||||
|         train a DensePassageRetrieval model |         train a DensePassageRetrieval model | ||||||
| @ -317,20 +317,24 @@ class DensePassageRetriever(BaseRetriever): | |||||||
|         trainer.train() |         trainer.train() | ||||||
| 
 | 
 | ||||||
|         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir) |         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir) | ||||||
|         self.processor.save(Path(save_dir)) |         self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}") | ||||||
|  |         self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}") | ||||||
| 
 | 
 | ||||||
|     def save(self, save_dir: Union[Path, str]): |     def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder", | ||||||
|  |              passage_encoder_dir: str = "passage_encoder"): | ||||||
|         """ |         """ | ||||||
|         Save DensePassageRetriever to the specified directory. |         Save DensePassageRetriever to the specified directory. | ||||||
| 
 | 
 | ||||||
|         :param save_dir: Directory to save to. |         :param save_dir: Directory to save to. | ||||||
|  |         :param query_encoder_dir: Directory in save_dir that contains query encoder model. | ||||||
|  |         :param passage_encoder_dir: Directory in save_dir that contains passage encoder model. | ||||||
|         :return: None |         :return: None | ||||||
|         """ |         """ | ||||||
|         save_dir = Path(save_dir) |         save_dir = Path(save_dir) | ||||||
|         self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder") |         self.model.save(save_dir, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir) | ||||||
|         save_dir = str(save_dir) |         save_dir = str(save_dir) | ||||||
|         self.query_tokenizer.save_pretrained(save_dir + "/query_encoder") |         self.query_tokenizer.save_pretrained(save_dir + f"/{query_encoder_dir}") | ||||||
|         self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder") |         self.passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}") | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def load(cls, |     def load(cls, | ||||||
| @ -343,6 +347,8 @@ class DensePassageRetriever(BaseRetriever): | |||||||
|              embed_title: bool = True, |              embed_title: bool = True, | ||||||
|              use_fast_tokenizers: bool = True, |              use_fast_tokenizers: bool = True, | ||||||
|              similarity_function: str = "dot_product", |              similarity_function: str = "dot_product", | ||||||
|  |              query_encoder_dir: str = "query_encoder", | ||||||
|  |              passage_encoder_dir: str = "passage_encoder" | ||||||
|              ): |              ): | ||||||
|         """ |         """ | ||||||
|         Load DensePassageRetriever from the specified directory. |         Load DensePassageRetriever from the specified directory. | ||||||
| @ -351,8 +357,8 @@ class DensePassageRetriever(BaseRetriever): | |||||||
|         load_dir = Path(load_dir) |         load_dir = Path(load_dir) | ||||||
|         dpr = cls( |         dpr = cls( | ||||||
|             document_store=document_store, |             document_store=document_store, | ||||||
|             query_embedding_model=Path(load_dir) / "query_encoder", |             query_embedding_model=Path(load_dir) / query_encoder_dir, | ||||||
|             passage_embedding_model=Path(load_dir) / "passage_encoder", |             passage_embedding_model=Path(load_dir) / passage_encoder_dir, | ||||||
|             max_seq_len_query=max_seq_len_query, |             max_seq_len_query=max_seq_len_query, | ||||||
|             max_seq_len_passage=max_seq_len_passage, |             max_seq_len_passage=max_seq_len_passage, | ||||||
|             use_gpu=use_gpu, |             use_gpu=use_gpu, | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 bogdankostic
						bogdankostic