diff --git a/haystack/utils/docker.py b/haystack/utils/docker.py index 1993f5f2d..5e182a6df 100644 --- a/haystack/utils/docker.py +++ b/haystack/utils/docker.py @@ -1,21 +1,25 @@ import logging -def cache_models(): +def cache_models(models=None): """ Small function that caches models and other data. Used only in the Dockerfile to include these caches in the images. """ + # Backward compat after adding the `model` param + if models is None: + models = ["deepset/roberta-base-squad2"] + # download punkt tokenizer logging.info("Caching punkt data") import nltk - nltk.download("punkt", download_dir="/root/nltk_data") + nltk.download("punkt") - # Cache roberta-base-squad2 model - logging.info("Caching deepset/roberta-base-squad2") + # Cache models import transformers - model_to_cache = "deepset/roberta-base-squad2" - transformers.AutoTokenizer.from_pretrained(model_to_cache) - transformers.AutoModel.from_pretrained(model_to_cache) + for model_to_cache in models: + logging.info(f"Caching {model_to_cache}") + transformers.AutoTokenizer.from_pretrained(model_to_cache) + transformers.AutoModel.from_pretrained(model_to_cache)