diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index 081254eea..315727197 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -191,15 +191,17 @@ class Inferencer:
 
         name = os.path.basename(model_name_or_path)
 
-        # a) either from local dir
-        if os.path.exists(model_name_or_path):
+        # a) non-hf models (i.e. FARM, ONNX) from local dir
+        farm_model_bin = os.path.join(model_name_or_path, "language_model.bin")
+        onnx_model = os.path.join(model_name_or_path, "model.onnx")
+        if os.path.isfile(farm_model_bin) or os.path.isfile(onnx_model):
             model = BaseAdaptiveModel.load(load_dir=model_name_or_path, device=devices[0], strict=strict)
             if task_type == "embeddings":
                 processor = InferenceProcessor.load_from_dir(model_name_or_path)
             else:
                 processor = Processor.load_from_dir(model_name_or_path)
 
-        # b) or from remote transformers model hub
+        # b) transformers models from hub or from local
         else:
             if not task_type:
                 raise ValueError(
diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py
index a8248be70..dfd7182ff 100644
--- a/test/nodes/test_reader.py
+++ b/test/nodes/test_reader.py
@@ -1,8 +1,10 @@
 import math
 import os
 from pathlib import Path
 
 import pytest
+
+from huggingface_hub import snapshot_download
 
 from haystack.modeling.data_handler.inputs import QAInput, Question
 from haystack.schema import Document, Answer
@@ -250,6 +252,43 @@ def test_farm_reader_update_params(docs):
     reader.predict(query="Who lives in Berlin?", documents=docs, top_k=3)
 
 
+# There are 5 different ways to load a FARMReader model.
+# 1. HuggingFace Hub (online load)
+# 2. HuggingFace downloaded (local load)
+# 3. HF Model saved as FARM Model (same works for trained FARM model) (local load)
+# 4. FARM Model converted to transformers (same as hf local model) (local load)
+# 5. ONNX Model load (covered by test_farm_reader_onnx_conversion_and_inference)
+@pytest.mark.integration
+def test_farm_reader_load_hf_online():
+    # Test Case: 1. HuggingFace Hub (online load)
+
+    hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
+    _ = FARMReader(model_name_or_path=hf_model, use_gpu=False, no_ans_boost=0, num_processes=0)
+
+
+@pytest.mark.integration
+def test_farm_reader_load_hf_local(tmp_path):
+    # Test Case: 2. HuggingFace downloaded (local load)
+
+    hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
+    # tmp_path is pytest's per-test temp dir: unique per test run and removed
+    # automatically, so no hard-coded /tmp path and no manual rmtree cleanup.
+    local_model_path = str(tmp_path / "locally_saved_hf")
+    model_path = snapshot_download(repo_id=hf_model, revision="main", cache_dir=local_model_path)
+    _ = FARMReader(model_name_or_path=model_path, use_gpu=False, no_ans_boost=0, num_processes=0)
+
+
+@pytest.mark.integration
+def test_farm_reader_load_farm_local(tmp_path):
+    # Test Case: 3. HF Model saved as FARM Model (same works for trained FARM model) (local load)
+
+    hf_model = "hf-internal-testing/tiny-random-RobertaForQuestionAnswering"
+    local_model_path = f"{tmp_path}/locally_saved_farm"
+    reader = FARMReader(model_name_or_path=hf_model, use_gpu=False, no_ans_boost=0, num_processes=0)
+    reader.save(Path(local_model_path))
+    _ = FARMReader(model_name_or_path=local_model_path, use_gpu=False, no_ans_boost=0, num_processes=0)
+
+
 @pytest.mark.parametrize("use_confidence_scores", [True, False])
 def test_farm_reader_uses_same_sorting_as_QAPredictionHead(use_confidence_scores):
     reader = FARMReader(