mirror of https://github.com/deepset-ai/haystack.git
feat: take the list of models to cache instead of hardcoding one (#3060)
* take the list of models to cache as an input
* let nltk find the cache dir on its own
parent 1027ab3624
commit af24ffae55
@@ -1,21 +1,25 @@
 import logging
 
 
-def cache_models():
+def cache_models(models=None):
     """
     Small function that caches models and other data.
     Used only in the Dockerfile to include these caches in the images.
     """
+    # Backward compat after adding the `model` param
+    if models is None:
+        models = ["deepset/roberta-base-squad2"]
+
     # download punkt tokenizer
     logging.info("Caching punkt data")
     import nltk
 
-    nltk.download("punkt", download_dir="/root/nltk_data")
+    nltk.download("punkt")
 
-    # Cache roberta-base-squad2 model
-    logging.info("Caching deepset/roberta-base-squad2")
+    # Cache models
     import transformers
 
-    model_to_cache = "deepset/roberta-base-squad2"
-    transformers.AutoTokenizer.from_pretrained(model_to_cache)
-    transformers.AutoModel.from_pretrained(model_to_cache)
+    for model_to_cache in models:
+        logging.info(f"Caching {model_to_cache}")
+        transformers.AutoTokenizer.from_pretrained(model_to_cache)
+        transformers.AutoModel.from_pretrained(model_to_cache)
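For reference, a minimal sketch of how the updated helper might be exercised. The function body mirrors the diff above; the standalone-script wrapper and the second model name are illustrative assumptions, not part of the commit.

import logging

def cache_models(models=None):
    """Cache models and punkt data so Docker images ship with them."""
    # Backward compat: default to the single model that used to be hardcoded
    if models is None:
        models = ["deepset/roberta-base-squad2"]

    # Without download_dir, nltk picks its default cache location on its own
    # (for example /root/nltk_data when running as root during an image build)
    logging.info("Caching punkt data")
    import nltk

    nltk.download("punkt")

    # Download tokenizer and weights for every requested model into the
    # local Hugging Face cache
    import transformers

    for model_to_cache in models:
        logging.info(f"Caching {model_to_cache}")
        transformers.AutoTokenizer.from_pretrained(model_to_cache)
        transformers.AutoModel.from_pretrained(model_to_cache)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cache_models()  # old behaviour: caches only deepset/roberta-base-squad2
    # new behaviour: pass any list of model names (second name is illustrative)
    cache_models(["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2"])

Taking the list as a parameter lets each Dockerfile variant decide which models to bake into its image instead of editing this script for every new image.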