mirror of https://github.com/deepset-ai/haystack.git

Commit 13510aa753
* Files moved, imports all broken
* Fix most imports and docstrings into
* Fix the paths to the modules in the API docs
* Add latest docstring and tutorial changes
* Add a few pipelines that were lost in the imports
* Fix a bunch of mypy warnings
* Add latest docstring and tutorial changes
* Create a file_classifier module
* Add docs for file_classifier
* Fixed most circular imports, now the REST API can start
* Add latest docstring and tutorial changes
* Tackling more mypy issues
* Reintroduce from FARM and fix last mypy issues hopefully
* Re-enable old-style imports
* Fix some more imports from the top-level package in an attempt to sort out circular imports
* Fix some imports in tests to new-style to prevent failed class equalities from breaking tests
* Change document_store into document_stores
* Update imports in tutorials
* Add latest docstring and tutorial changes
* Probably fixes summarizer tests
* Improve the old-style import allowing module imports (should work)
* Try to fix the docs
* Remove dedicated KnowledgeGraph page from autodocs
* Remove dedicated GraphRetriever page from autodocs
* Fix generate_docstrings.sh with an updated list of yaml files to look for
* Fix some more modules in the docs
* Fix the document stores docs too
* Fix a small issue on Tutorial14
* Add latest docstring and tutorial changes
* Add deprecation warning to old-style imports
* Remove stray folder and import Dict into dense.py
* Change import path for MLFlowLogger
* Add old loggers path to the import path aliases
* Fix debug output of convert_ipynb.py
* Fix circular import on BaseRetriever
* Missed one merge block
* Re-run tutorial 5
* Fix imports in tutorial 5
* Re-enable squad_to_dpr CLI from the root package and move get_batches_from_generator into document_stores.base
* Add latest docstring and tutorial changes
* Fix typo in utils __init__
* Fix a few more imports
* Fix benchmarks too
* New-style imports in test_knowledge_graph
* Rollback setup.py
* Rollback squad_to_dpr too

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
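For users, the visible effect of this refactor is the change of import paths. The module names below are taken from the commit message above, and the shim is a minimal sketch of the standard deprecation pattern, not the actual Haystack implementation:

# New-style imports after this commit:
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore

# Old-style imports are re-enabled as aliases and now emit a
# DeprecationWarning, e.g. (note the missing trailing "s"):
# from haystack.document_store import InMemoryDocumentStore

# One standard way to keep such a path alive is a shim __init__.py in the
# old package that warns and re-exports (a sketch of the pattern only):
#
# import warnings
# warnings.warn(
#     "haystack.document_store is deprecated, use haystack.document_stores",
#     DeprecationWarning,
# )
# from haystack.document_stores import *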
		
			
				
	
	
		
86 lines · 2.7 KiB · Python
def tutorial9_dpr_training():

    # Training Your Own "Dense Passage Retrieval" Model

    # Here are some imports that we'll need

    from haystack.nodes import DensePassageRetriever
    from haystack.utils import fetch_archive_from_http
    from haystack.document_stores import InMemoryDocumentStore

    # Download original DPR data
    # WARNING: the train set is 7.4GB and the dev set is 800MB

    doc_dir = "data/dpr_training/"

    s3_url_train = "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz"
    s3_url_dev = "https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz"

    fetch_archive_from_http(s3_url_train, output_dir=doc_dir + "train/")
    fetch_archive_from_http(s3_url_dev, output_dir=doc_dir + "dev/")
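
    # Each downloaded file is a JSON list of DPR training examples. For
    # orientation only (this is the upstream DPR data format, not something
    # this tutorial defines), a record looks roughly like:
    #
    # {
    #     "question": "...",
    #     "answers": ["..."],
    #     "positive_ctxs": [{"title": "...", "text": "..."}],
    #     "negative_ctxs": [],
    #     "hard_negative_ctxs": [{"title": "...", "text": "..."}],
    # }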

    ## Option 1: Training DPR from Scratch

    # Here are the variables to specify our training data, the models that we use to initialize DPR
    # and the directory where we'll be saving the model

    doc_dir = "data/dpr_training/"

    train_filename = "train/biencoder-nq-train.json"
    dev_filename = "dev/biencoder-nq-dev.json"

    query_model = "bert-base-uncased"
    passage_model = "bert-base-uncased"

    save_dir = "../saved_models/dpr"

    # ## Option 2: Finetuning DPR
    #
    # # Here are the variables you might want to use instead of the set above
    # # in order to fine-tune a pretrained DPR model
    #
    # doc_dir = "PATH_TO_YOUR_DATA_DIR"
    # train_filename = "TRAIN_FILENAME"
    # dev_filename = "DEV_FILENAME"
    #
    # query_model = "facebook/dpr-question_encoder-single-nq-base"
    # passage_model = "facebook/dpr-ctx_encoder-single-nq-base"
    #
    # save_dir = "../saved_models/dpr"

    ## Initialize DPR model

    retriever = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model=query_model,
        passage_embedding_model=passage_model,
        max_seq_len_query=64,
        max_seq_len_passage=256,
    )

    # Start training our model and save it when it is finished

    retriever.train(
        data_dir=doc_dir,
        train_filename=train_filename,
        dev_filename=dev_filename,
        test_filename=dev_filename,
        n_epochs=1,
        batch_size=16,
        grad_acc_steps=8,  # effective batch size: 16 * 8 = 128
        save_dir=save_dir,
        evaluate_every=3000,
        embed_title=True,
        num_positives=1,
        num_hard_negatives=1,
    )

    ## Loading

    reloaded_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=None)
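
    # The retriever above is reloaded with document_store=None, so it cannot
    # retrieve yet. A minimal usage sketch, assuming this version's API
    # (illustrative only, not part of the original tutorial):
    #
    # document_store = InMemoryDocumentStore()
    # document_store.write_documents([{"content": "Berlin is the capital of Germany."}])
    # inference_retriever = DensePassageRetriever.load(load_dir=save_dir, document_store=document_store)
    # document_store.update_embeddings(retriever=inference_retriever)
    # top_docs = inference_retriever.retrieve(query="What is the capital of Germany?", top_k=1)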


if __name__ == "__main__":
    tutorial9_dpr_training()


# This Haystack script was made with love by deepset in Berlin, Germany
# Haystack: https://github.com/deepset-ai/haystack
# deepset: https://deepset.ai/