fix: Fixed local reader model loading (#3663)

* Fixed local loading issue
This commit is contained in:
Mayank Jobanputra 2022-12-23 23:16:36 +01:00 committed by GitHub
parent 450c3d4484
commit 76a16807d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 3 deletions

View File

@ -191,15 +191,17 @@ class Inferencer:
name = os.path.basename(model_name_or_path)
# a) either from local dir
if os.path.exists(model_name_or_path):
# a) non-hf models (i.e. FARM, ONNX) from local dir
farm_model_bin = os.path.join(model_name_or_path, "language_model.bin")
onnx_model = os.path.join(model_name_or_path, "model.onnx")
if os.path.isfile(farm_model_bin) or os.path.isfile(onnx_model):
model = BaseAdaptiveModel.load(load_dir=model_name_or_path, device=devices[0], strict=strict)
if task_type == "embeddings":
processor = InferenceProcessor.load_from_dir(model_name_or_path)
else:
processor = Processor.load_from_dir(model_name_or_path)
# b) or from remote transformers model hub
# b) transformers models from hub or from local
else:
if not task_type:
raise ValueError(

View File

@ -1,8 +1,11 @@
import math
import os
from pathlib import Path
from shutil import rmtree
import pytest
from huggingface_hub import snapshot_download
from haystack.modeling.data_handler.inputs import QAInput, Question
from haystack.schema import Document, Answer
@ -250,6 +253,44 @@ def test_farm_reader_update_params(docs):
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
# There are 5 different ways to load a FARMReader model.
# 1. HuggingFace Hub (online load)
# 2. HuggingFace downloaded (local load)
# 3. HF model saved as a FARM model (the same works for a trained FARM model) (local load)
# 4. FARM model converted to transformers (same as a local HF model) (local load)
# 5. ONNX Model load (covered by test_farm_reader_onnx_conversion_and_inference)
@pytest.mark.integration
def test_farm_reader_load_hf_online():
    # Test Case: 1. HuggingFace Hub (online load)
    # Loading by hub id triggers a direct download from the Hub.
    model_id = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
    _ = FARMReader(model_name_or_path=model_id, use_gpu=False, no_ans_boost=0, num_processes=0)
@pytest.mark.integration
def test_farm_reader_load_hf_local(tmp_path):
    """Test Case: 2. HuggingFace downloaded (local load).

    Downloads the Hub snapshot into pytest's per-test ``tmp_path`` fixture,
    so each test run gets an isolated directory that pytest cleans up —
    no hard-coded ``/tmp`` path and no manual ``rmtree`` needed.
    """
    hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
    local_model_path = str(tmp_path / "locally_saved_hf")
    model_path = snapshot_download(repo_id=hf_model, revision="main", cache_dir=local_model_path)
    _ = FARMReader(model_name_or_path=model_path, use_gpu=False, no_ans_boost=0, num_processes=0)
@pytest.mark.integration
def test_farm_reader_load_farm_local(tmp_path):
    # Test Case: 3. HF Model saved as FARM Model (same works for trained FARM model) (local load)
    # Save a hub model to disk in FARM format, then reload it from that directory.
    hub_model_id = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
    save_dir = f"{tmp_path}/locally_saved_farm"
    farm_reader = FARMReader(model_name_or_path=hub_model_id, use_gpu=False, no_ans_boost=0, num_processes=0)
    farm_reader.save(Path(save_dir))
    _ = FARMReader(model_name_or_path=save_dir, use_gpu=False, no_ans_boost=0, num_processes=0)
@pytest.mark.parametrize("use_confidence_scores", [True, False])
def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
reader = FARMReader(