fix broken serialization of HFAPI components (#7661)

This commit is contained in:
Stefano Fiorucci 2024-05-08 17:14:37 +02:00 committed by GitHub
parent 94467149c1
commit 7c9532b200
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 13 additions and 8 deletions

View File

@@ -175,7 +175,7 @@ class HuggingFaceAPIDocumentEmbedder:
         """
         return default_to_dict(
             self,
-            api_type=self.api_type,
+            api_type=str(self.api_type),
             api_params=self.api_params,
             prefix=self.prefix,
             suffix=self.suffix,

View File

@@ -142,7 +142,7 @@ class HuggingFaceAPITextEmbedder:
         """
         return default_to_dict(
             self,
-            api_type=self.api_type,
+            api_type=str(self.api_type),
             api_params=self.api_params,
             prefix=self.prefix,
             suffix=self.suffix,

View File

@@ -158,7 +158,7 @@ class HuggingFaceAPIChatGenerator:
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         return default_to_dict(
             self,
-            api_type=self.api_type,
+            api_type=str(self.api_type),
             api_params=self.api_params,
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,

View File

@@ -142,7 +142,7 @@ class HuggingFaceAPIGenerator:
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
         return default_to_dict(
             self,
-            api_type=self.api_type,
+            api_type=str(self.api_type),
             api_params=self.api_params,
             token=self.token.to_dict() if self.token else None,
             generation_kwargs=self.generation_kwargs,

View File

@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Fix the broken serialization of HuggingFaceAPITextEmbedder, HuggingFaceAPIDocumentEmbedder,
+    HuggingFaceAPIGenerator, and HuggingFaceAPIChatGenerator.

View File

@@ -109,7 +109,7 @@ class TestHuggingFaceAPIDocumentEmbedder:
         assert data == {
             "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
             "init_parameters": {
-                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                "api_type": "serverless_inference_api",
                 "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                 "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "prefix": "prefix",

View File

@@ -95,7 +95,7 @@ class TestHuggingFaceAPITextEmbedder:
         assert data == {
             "type": "haystack.components.embedders.hugging_face_api_text_embedder.HuggingFaceAPITextEmbedder",
             "init_parameters": {
-                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                "api_type": "serverless_inference_api",
                 "api_params": {"model": "BAAI/bge-small-en-v1.5"},
                 "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "prefix": "prefix",

View File

@@ -138,7 +138,7 @@ class TestHuggingFaceAPIGenerator:
         result = generator.to_dict()
         init_params = result["init_parameters"]
-        assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
+        assert init_params["api_type"] == "serverless_inference_api"
         assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

View File

@@ -131,7 +131,7 @@ class TestHuggingFaceAPIGenerator:
         result = generator.to_dict()
         init_params = result["init_parameters"]
-        assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
+        assert init_params["api_type"] == "serverless_inference_api"
         assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {