mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-03 19:29:32 +00:00 
			
		
		
		
	first version of save_to_remote for HF from FarmReader (#2618)
* first version of save_to_remote for HF from FarmReader * Update Documentation & Code Style * Changes based on comments * Update Documentation & Code Style * imports order * making small changes to pydoc * indent fix * Update Documentation & Code Style * keyword arguments instead of positional * Changing to repo_id huggingface-hub package would have to be v0.5 or higher - checking how to handle with Thomas * Update Documentation & Code Style * adding huggingface-hub dependency 0.5 or above Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sara Zan <sarazanzo94@gmail.com>
This commit is contained in:
		
							parent
							
								
									f7d00476f9
								
							
						
					
					
						commit
						2a8b129bae
					
				@ -333,6 +333,25 @@ Saves the Reader model so that it can be reused at a later point in time.
 | 
			
		||||
 | 
			
		||||
- `directory`: Directory where the Reader model should be saved
 | 
			
		||||
 | 
			
		||||
<a id="farm.FARMReader.save_to_remote"></a>
 | 
			
		||||
 | 
			
		||||
#### FARMReader.save\_to\_remote
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
def save_to_remote(repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face.")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Saves the Reader model to Hugging Face Model Hub with the given model_name. For this to work:
 | 
			
		||||
 | 
			
		||||
- Be logged in to Hugging Face on your machine via transformers-cli
 | 
			
		||||
- Have git lfs installed (https://packagecloud.io/github/git-lfs/install), you can test it by git lfs --version
 | 
			
		||||
 | 
			
		||||
**Arguments**:
 | 
			
		||||
 | 
			
		||||
- `repo_id`: A namespace (user or an organization) and a repo name separated by a '/' of the model you want to save to Hugging Face
 | 
			
		||||
- `private`: Set to true to make the model repository private
 | 
			
		||||
- `commit_message`: Commit message while saving to Hugging Face
 | 
			
		||||
 | 
			
		||||
<a id="farm.FARMReader.predict_batch"></a>
 | 
			
		||||
 | 
			
		||||
#### FARMReader.predict\_batch
 | 
			
		||||
 | 
			
		||||
@ -4,8 +4,12 @@ import logging
 | 
			
		||||
import multiprocessing
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
import os
 | 
			
		||||
import tempfile
 | 
			
		||||
from time import perf_counter
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import create_repo, HfFolder, Repository
 | 
			
		||||
 | 
			
		||||
from haystack.errors import HaystackError
 | 
			
		||||
from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
 | 
			
		||||
@ -688,6 +692,58 @@ class FARMReader(BaseReader):
 | 
			
		||||
        self.inferencer.model.save(directory)
 | 
			
		||||
        self.inferencer.processor.save(directory)
 | 
			
		||||
 | 
			
		||||
    def save_to_remote(
 | 
			
		||||
        self, repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face."
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Saves the Reader model to Hugging Face Model Hub with the given model_name. For this to work:
 | 
			
		||||
        - Be logged in to Hugging Face on your machine via transformers-cli
 | 
			
		||||
        - Have git lfs installed (https://packagecloud.io/github/git-lfs/install), you can test it by git lfs --version
 | 
			
		||||
 | 
			
		||||
        :param repo_id: A namespace (user or an organization) and a repo name separated by a '/' of the model you want to save to Hugging Face
 | 
			
		||||
        :param private: Set to true to make the model repository private
 | 
			
		||||
        :param commit_message: Commit message while saving to Hugging Face
 | 
			
		||||
        """
 | 
			
		||||
        # Note: This function was inspired by the save_to_hub function in the sentence-transformers repo (https://github.com/UKPLab/sentence-transformers/)
 | 
			
		||||
        # Especially for git-lfs tracking.
 | 
			
		||||
 | 
			
		||||
        token = HfFolder.get_token()
 | 
			
		||||
        if token is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "To save this reader model to Hugging Face, make sure you login to the hub on this computer by typing `transformers-cli login`."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        repo_url = create_repo(token=token, repo_id=repo_id, private=private, repo_type=None, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
        transformer_models = self.inferencer.model.convert_to_transformers()
 | 
			
		||||
 | 
			
		||||
        with tempfile.TemporaryDirectory() as tmp_dir:
 | 
			
		||||
            repo = Repository(tmp_dir, clone_from=repo_url)
 | 
			
		||||
 | 
			
		||||
            self.inferencer.processor.tokenizer.save_pretrained(tmp_dir)
 | 
			
		||||
 | 
			
		||||
            # convert_to_transformers (above) creates one model per prediction head.
 | 
			
		||||
            # As the FarmReader models only have one head (QA) we go with this.
 | 
			
		||||
            transformer_models[0].save_pretrained(tmp_dir)
 | 
			
		||||
 | 
			
		||||
            large_files = []
 | 
			
		||||
            for root, dirs, files in os.walk(tmp_dir):
 | 
			
		||||
                for filename in files:
 | 
			
		||||
                    file_path = os.path.join(root, filename)
 | 
			
		||||
                    rel_path = os.path.relpath(file_path, tmp_dir)
 | 
			
		||||
 | 
			
		||||
                    if os.path.getsize(file_path) > (5 * 1024 * 1024):
 | 
			
		||||
                        large_files.append(rel_path)
 | 
			
		||||
 | 
			
		||||
            if len(large_files) > 0:
 | 
			
		||||
                logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
 | 
			
		||||
                repo.lfs_track(large_files)
 | 
			
		||||
 | 
			
		||||
            logger.info("Push model to the hub. This might take a while")
 | 
			
		||||
            commit_url = repo.push_to_hub(commit_message=commit_message)
 | 
			
		||||
 | 
			
		||||
        return commit_url
 | 
			
		||||
 | 
			
		||||
    def predict_batch(
 | 
			
		||||
        self,
 | 
			
		||||
        queries: List[str],
 | 
			
		||||
 | 
			
		||||
@ -75,6 +75,8 @@ install_requires =
 | 
			
		||||
    # azure-core>=1.23 needs typing-extensions>=4.0.1
 | 
			
		||||
    # pip unfortunately backtracks into the databind direction ultimately getting lost.
 | 
			
		||||
    azure-core<1.23
 | 
			
		||||
    # audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
 | 
			
		||||
    huggingface-hub<0.8.0,>=0.5.0
 | 
			
		||||
 | 
			
		||||
    # Preprocessing
 | 
			
		||||
    more_itertools  # for windowing
 | 
			
		||||
@ -157,7 +159,6 @@ audio =
 | 
			
		||||
    espnet
 | 
			
		||||
    espnet-model-zoo
 | 
			
		||||
    pydub
 | 
			
		||||
    huggingface-hub<0.8.0
 | 
			
		||||
beir =
 | 
			
		||||
    beir; platform_system != 'Windows'
 | 
			
		||||
crawler =
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user