Add models to demo docker image (#1978)

* Add utility to cache models and nltk data & modify Dockerfiles to use it

* Fix punkt data not being cached
Sara Zan 2022-01-11 16:37:45 +01:00 committed by GitHub
parent 192e03be33
commit 9c3d9b4885
3 changed files with 24 additions and 6 deletions

Dockerfile

@@ -18,11 +18,10 @@ COPY haystack /home/user/haystack

 # install as a package
 COPY setup.py requirements.txt README.md /home/user/
 RUN pip install --upgrade pip
 RUN pip install -r requirements.txt
 RUN pip install -e .
-# download punkt tokenizer to be included in image
-RUN python3 -c "import nltk;nltk.download('punkt', download_dir='/usr/nltk_data')"
+RUN python3 -c "from haystack.utils.docker import cache_models;cache_models()"

 # create folder for /file-upload API endpoint with write permissions, this might be adjusted depending on FILE_UPLOAD_PATH
 RUN mkdir -p /home/user/file-upload
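A quick way to sanity-check the resulting image (a hypothetical smoke test, not part of this commit) is to run a lookup inside the container; `nltk.data.find` raises a `LookupError` when the resource is not on any of NLTK's data search paths:

# Hypothetical smoke test: run inside the built image to confirm the punkt
# tokenizer resolves from the baked-in cache instead of triggering a download.
import nltk

# find() walks NLTK's search paths (including ~/nltk_data, i.e. /root/nltk_data
# when running as root) and raises LookupError if the resource is missing.
print(nltk.data.find("tokenizers/punkt"))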

Dockerfile-GPU

@@ -37,15 +37,13 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1

 # Copy package setup files
 COPY setup.py requirements.txt README.md /home/user/
 RUN pip install --upgrade pip
 RUN echo "Install required packages" && \
     # Install PyTorch for CUDA 11
     pip3 install torch==1.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html && \
     # Install from requirements.txt
     pip3 install -r requirements.txt
-# download punkt tokenizer to be included in image
-RUN python3 -c "import nltk;nltk.download('punkt', download_dir='/usr/nltk_data')"

 # copy saved models
 COPY README.md models* /home/user/models/

@@ -58,6 +56,9 @@ COPY haystack /home/user/haystack

 # Install package
 RUN pip3 install -e .
+# Cache Roberta and NLTK data
+RUN python3 -c "from haystack.utils.docker import cache_models;cache_models()"

 # optional : copy sqlite db if needed for testing
 #COPY qa.db /home/user/

haystack/utils/docker.py Normal file

@@ -0,0 +1,18 @@
import logging


def cache_models():
    """
    Small function that caches models and other data.
    Used only in the Dockerfile to include these caches in the images.
    """
    # download punkt tokenizer
    logging.info("Caching punkt data")
    import nltk
    nltk.download('punkt', download_dir='/root/nltk_data')
    # Cache roberta-base-squad2 model
    logging.info("Caching deepset/roberta-base-squad2")
    import transformers
    model_to_cache = 'deepset/roberta-base-squad2'
    transformers.AutoTokenizer.from_pretrained(model_to_cache)
    transformers.AutoModel.from_pretrained(model_to_cache)
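Since `from_pretrained` downloads into the default Hugging Face cache (under `~/.cache/huggingface` for root), the weights end up baked into the image layer and can be loaded offline afterwards. A hypothetical check, not part of this commit:

# Hypothetical offline check: inside the built image, loading with
# local_files_only=True must succeed purely from the cache that
# cache_models() populated at build time, with no network access.
import transformers

model_to_check = 'deepset/roberta-base-squad2'
transformers.AutoTokenizer.from_pretrained(model_to_check, local_files_only=True)
transformers.AutoModel.from_pretrained(model_to_check, local_files_only=True)
print("tokenizer and model loaded from the build-time cache")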