From af24ffae554eed8f6d92098f191c4e5eb03fc39e Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Thu, 18 Aug 2022 11:55:29 +0200
Subject: [PATCH] feat: take the list of models to cache instead of hardcoding one (#3060)

* take the list of models to cache as an input

* let nltk find the cache dir on its own
---
 haystack/utils/docker.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/haystack/utils/docker.py b/haystack/utils/docker.py
index 1993f5f2d..5e182a6df 100644
--- a/haystack/utils/docker.py
+++ b/haystack/utils/docker.py
@@ -1,21 +1,25 @@
 import logging
 
 
-def cache_models():
+def cache_models(models=None):
     """
     Small function that caches models and other data.
     Used only in the Dockerfile to include these caches in the images.
     """
+    # Backward compat after adding the `models` param
+    if models is None:
+        models = ["deepset/roberta-base-squad2"]
+
     # download punkt tokenizer
     logging.info("Caching punkt data")
     import nltk
 
-    nltk.download("punkt", download_dir="/root/nltk_data")
+    nltk.download("punkt")
 
-    # Cache roberta-base-squad2 model
-    logging.info("Caching deepset/roberta-base-squad2")
+    # Cache models
     import transformers
 
-    model_to_cache = "deepset/roberta-base-squad2"
-    transformers.AutoTokenizer.from_pretrained(model_to_cache)
-    transformers.AutoModel.from_pretrained(model_to_cache)
+    for model_to_cache in models:
+        logging.info(f"Caching {model_to_cache}")
+        transformers.AutoTokenizer.from_pretrained(model_to_cache)
+        transformers.AutoModel.from_pretrained(model_to_cache)
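
Usage note (not part of the patch): a minimal sketch of how the updated helper
might be invoked after this change. The second model name below is an
illustrative assumption and is not taken from the PR or the Dockerfile.

    # Hypothetical usage sketch of haystack.utils.docker.cache_models
    from haystack.utils.docker import cache_models

    # Called with no argument, the old behavior is preserved:
    # only deepset/roberta-base-squad2 is cached.
    cache_models()

    # Pass an explicit list to pre-cache several models; each entry is
    # fetched via transformers.AutoTokenizer / AutoModel and stored in
    # the local Hugging Face cache, so the Docker image ships with it.
    cache_models(models=["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2"])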