From c88bc19791981b238b779867b1c6dda6ff2d07c3 Mon Sep 17 00:00:00 2001 From: yuanwu2017 Date: Tue, 2 May 2023 23:04:42 +0800 Subject: [PATCH] fix: load the local finetuning model from pipeline YAML (#4729) (#4760) When using a local model in the pipeline YAML, the PromptModel cannot select the HFLocalInvocationLayer, because get_task does not support offline models. *Local model usage: add the task_name parameter in model_kwargs for the local model, for example text-generation or text2text-generation. - name: PModel type: PromptModel params: model_name_or_path: /local_model_path model_kwargs: task_name: text-generation - name: Prompter params: model_name_or_path: PModel default_prompt_template: question-answering type: PromptNode Signed-off-by: yuanwu --- haystack/nodes/prompt/invocation_layer/hugging_face.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py index b68137067..d0a795dc0 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py @@ -1,5 +1,6 @@ from typing import Optional, Union, List, Dict import logging +import os import torch @@ -266,6 +267,9 @@ class HFLocalInvocationLayer(PromptModelInvocationLayer): @classmethod def supports(cls, model_name_or_path: str, **kwargs) -> bool: task_name: Optional[str] = None + if os.path.exists(model_name_or_path): + return True + try: task_name = get_task(model_name_or_path, use_auth_token=kwargs.get("use_auth_token", None)) except RuntimeError: