feat: take the list of models to cache instead of hardcoding one (#3060)

* take the list of models to cache as an input

* let nltk find the cache dir on its own
This commit was authored by:
Massimiliano Pippi, 2022-08-18 11:55:29 +02:00, committed by GitHub
parent 1027ab3624
commit af24ffae55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,21 +1,25 @@
import logging
def cache_models(models=None):
    """
    Small function that caches models and other data.
    Used only in the Dockerfile to include these caches in the images.

    :param models: list of Hugging Face model identifiers to download and
        cache. Defaults to ``["deepset/roberta-base-squad2"]`` so existing
        callers of the zero-argument form keep their previous behavior.
    """
    # Backward compat after adding the `models` param: the old version
    # always cached exactly this one model.
    if models is None:
        models = ["deepset/roberta-base-squad2"]

    # Download the punkt tokenizer data. No explicit download_dir: let
    # nltk resolve its cache directory on its own.
    logging.info("Caching punkt data")
    import nltk

    nltk.download("punkt")

    # Cache each requested model (tokenizer + weights) so the Docker
    # image ships with them pre-downloaded.
    import transformers

    for model_to_cache in models:
        logging.info(f"Caching {model_to_cache}")
        transformers.AutoTokenizer.from_pretrained(model_to_cache)
        transformers.AutoModel.from_pretrained(model_to_cache)