mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 20:17:14 +00:00
fix: Fixed local reader model loading (#3663)
* Fixed local loading issue
This commit is contained in:
parent
450c3d4484
commit
76a16807d5
@ -191,15 +191,17 @@ class Inferencer:
|
||||
|
||||
name = os.path.basename(model_name_or_path)
|
||||
|
||||
# a) either from local dir
|
||||
if os.path.exists(model_name_or_path):
|
||||
# a) non-hf models (i.e. FARM, ONNX) from local dir
|
||||
farm_model_bin = os.path.join(model_name_or_path, "language_model.bin")
|
||||
onnx_model = os.path.join(model_name_or_path, "model.onnx")
|
||||
if os.path.isfile(farm_model_bin) or os.path.isfile(onnx_model):
|
||||
model = BaseAdaptiveModel.load(load_dir=model_name_or_path, device=devices[0], strict=strict)
|
||||
if task_type == "embeddings":
|
||||
processor = InferenceProcessor.load_from_dir(model_name_or_path)
|
||||
else:
|
||||
processor = Processor.load_from_dir(model_name_or_path)
|
||||
|
||||
# b) or from remote transformers model hub
|
||||
# b) transformers models from hub or from local
|
||||
else:
|
||||
if not task_type:
|
||||
raise ValueError(
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import pytest
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from haystack.modeling.data_handler.inputs import QAInput, Question
|
||||
|
||||
from haystack.schema import Document, Answer
|
||||
@ -250,6 +253,44 @@ def test_farm_reader_update_params(docs):
|
||||
reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
|
||||
|
||||
|
||||
# There are 5 different ways to load a FARMReader model.
|
||||
# 1. HuggingFace Hub (online load)
|
||||
# 2. HuggingFace downloaded (local load)
|
||||
# 3. HF Model saved as FARM Model (same works for trained FARM model) (local load)
|
||||
# 4. FARM Model converted to transformers (same as hf local model) (local load)
|
||||
# 5. ONNX Model load (covered by test_farm_reader_onnx_conversion_and_inference)
|
||||
@pytest.mark.integration
|
||||
def test_farm_reader_load_hf_online():
|
||||
# Test Case: 1. HuggingFace Hub (online load)
|
||||
|
||||
hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
|
||||
_ = FARMReader(model_name_or_path=hf_model, use_gpu=False, no_ans_boost=0, num_processes=0)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_farm_reader_load_hf_local(tmp_path):
|
||||
# Test Case: 2. HuggingFace downloaded (local load)
|
||||
|
||||
hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
|
||||
# TODO: change the /tmp to proper tmp_path and get rid of rmtree
|
||||
# local_model_path = str(Path.joinpath(tmp_path, "locally_saved_hf"))
|
||||
local_model_path = "/tmp/locally_saved_hf"
|
||||
model_path = snapshot_download(repo_id=hf_model, revision="main", cache_dir=local_model_path)
|
||||
_ = FARMReader(model_name_or_path=model_path, use_gpu=False, no_ans_boost=0, num_processes=0)
|
||||
rmtree(local_model_path)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_farm_reader_load_farm_local(tmp_path):
|
||||
# Test Case: 3. HF Model saved as FARM Model (same works for trained FARM model) (local load)
|
||||
|
||||
hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
|
||||
local_model_path = f"{tmp_path}/locally_saved_farm"
|
||||
reader = FARMReader(model_name_or_path=hf_model, use_gpu=False, no_ans_boost=0, num_processes=0)
|
||||
reader.save(Path(local_model_path))
|
||||
_ = FARMReader(model_name_or_path=local_model_path, use_gpu=False, no_ans_boost=0, num_processes=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_confidence_scores", [True, False])
|
||||
def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
|
||||
reader = FARMReader(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user