mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-03 11:19:57 +00:00
first version of save_to_remote for HF from FarmReader (#2618)
* first version of save_to_remote for HF from FarmReader * Update Documentation & Code Style * Changes based on comments * Update Documentation & Code Style * imports order * making small changes to pydoc * indent fix * Update Documentation & Code Style * keyword arguments instead of positional * Changing to repo_id huggingface-hub package would have to be v0.5 or higher - checking how to handle with Thomas * Update Documentation & Code Style * adding huggingface-hub dependency 0.5 or above Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sara Zan <sarazanzo94@gmail.com>
This commit is contained in:
parent
f7d00476f9
commit
2a8b129bae
@ -333,6 +333,25 @@ Saves the Reader model so that it can be reused at a later point in time.
|
||||
|
||||
- `directory`: Directory where the Reader model should be saved
|
||||
|
||||
<a id="farm.FARMReader.save_to_remote"></a>
|
||||
|
||||
#### FARMReader.save\_to\_remote
|
||||
|
||||
```python
|
||||
def save_to_remote(repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face.")
|
||||
```
|
||||
|
||||
Saves the Reader model to the Hugging Face Model Hub under the given repo_id. For this to work:
|
||||
|
||||
- Be logged in to Hugging Face on your machine via transformers-cli
|
||||
- Have git lfs installed (https://packagecloud.io/github/git-lfs/install), you can test it by git lfs --version
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `repo_id`: A namespace (user or an organization) and a repo name, separated by a '/', identifying the model you want to save to Hugging Face
|
||||
- `private`: Set to true to make the model repository private
|
||||
- `commit_message`: Commit message while saving to Hugging Face
|
||||
|
||||
<a id="farm.FARMReader.predict_batch"></a>
|
||||
|
||||
#### FARMReader.predict\_batch
|
||||
|
||||
@ -4,8 +4,12 @@ import logging
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import tempfile
|
||||
from time import perf_counter
|
||||
|
||||
import torch
|
||||
from huggingface_hub import create_repo, HfFolder, Repository
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.modeling.data_handler.data_silo import DataSilo, DistillationDataSilo
|
||||
@ -688,6 +692,58 @@ class FARMReader(BaseReader):
|
||||
self.inferencer.model.save(directory)
|
||||
self.inferencer.processor.save(directory)
|
||||
|
||||
def save_to_remote(
    self, repo_id: str, private: Optional[bool] = None, commit_message: str = "Add new model to Hugging Face."
):
    """
    Saves the Reader model to the Hugging Face Model Hub under the given repo_id. For this to work:
    - Be logged in to Hugging Face on your machine via transformers-cli
    - Have git lfs installed (https://packagecloud.io/github/git-lfs/install), you can test it by git lfs --version

    :param repo_id: A namespace (user or an organization) and a repo name separated by a '/' of the model you want to save to Hugging Face
    :param private: Set to true to make the model repository private
    :param commit_message: Commit message while saving to Hugging Face
    :return: URL of the commit created on the Hub.
    :raises ValueError: If no Hugging Face auth token is found on this machine.
    """
    # Note: This function was inspired by the save_to_hub function in the sentence-transformers repo (https://github.com/UKPLab/sentence-transformers/)
    # Especially for git-lfs tracking.

    token = HfFolder.get_token()
    if token is None:
        raise ValueError(
            "To save this reader model to Hugging Face, make sure you login to the hub on this computer by typing `transformers-cli login`."
        )

    # exist_ok=True makes re-pushing to an existing repo a no-op instead of an error.
    repo_url = create_repo(token=token, repo_id=repo_id, private=private, repo_type=None, exist_ok=True)

    transformer_models = self.inferencer.model.convert_to_transformers()

    with tempfile.TemporaryDirectory() as tmp_dir:
        repo = Repository(tmp_dir, clone_from=repo_url)

        self.inferencer.processor.tokenizer.save_pretrained(tmp_dir)

        # convert_to_transformers (above) creates one model per prediction head.
        # As the FarmReader models only have one head (QA) we go with this.
        transformer_models[0].save_pretrained(tmp_dir)

        # Files above this size must be tracked via git-lfs, otherwise the push fails.
        lfs_threshold_bytes = 5 * 1024 * 1024
        large_files = []
        for root, _, files in os.walk(tmp_dir):
            for filename in files:
                file_path = os.path.join(root, filename)
                # git-lfs expects paths relative to the repository root.
                rel_path = os.path.relpath(file_path, tmp_dir)

                if os.path.getsize(file_path) > lfs_threshold_bytes:
                    large_files.append(rel_path)

        if large_files:
            # Lazy %-style args avoid building the message when INFO is disabled.
            logger.info("Track files with git lfs: %s", ", ".join(large_files))
            repo.lfs_track(large_files)

        logger.info("Push model to the hub. This might take a while")
        commit_url = repo.push_to_hub(commit_message=commit_message)

    return commit_url
|
||||
|
||||
def predict_batch(
|
||||
self,
|
||||
queries: List[str],
|
||||
|
||||
@ -75,6 +75,8 @@ install_requires =
|
||||
# azure-core>=1.23 needs typing-extensions>=4.0.1
|
||||
# pip unfortunately backtracks into the databind direction ultimately getting lost.
|
||||
azure-core<1.23
|
||||
# audio's espnet-model-zoo requires huggingface-hub version <0.8 while we need >=0.5 to be able to use create_repo in FARMReader
|
||||
huggingface-hub<0.8.0,>=0.5.0
|
||||
|
||||
# Preprocessing
|
||||
more_itertools # for windowing
|
||||
@ -157,7 +159,6 @@ audio =
|
||||
espnet
|
||||
espnet-model-zoo
|
||||
pydub
|
||||
huggingface-hub<0.8.0
|
||||
beir =
|
||||
beir; platform_system != 'Windows'
|
||||
crawler =
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user