diff --git a/haystack/components/agents/state/state.py b/haystack/components/agents/state/state.py index 1224f0b6a..18f43562f 100644 --- a/haystack/components/agents/state/state.py +++ b/haystack/components/agents/state/state.py @@ -69,7 +69,7 @@ def _validate_schema(schema: Dict[str, Any]) -> None: raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}") if definition.get("handler") is not None and not callable(definition["handler"]): raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None") - if param == "messages" and definition["type"] is not List[ChatMessage]: + if param == "messages" and definition["type"] != List[ChatMessage]: raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}") diff --git a/haystack/components/agents/state/state_utils.py b/haystack/components/agents/state/state_utils.py index 2b392d812..8b8caec7d 100644 --- a/haystack/components/agents/state/state_utils.py +++ b/haystack/components/agents/state/state_utils.py @@ -31,7 +31,7 @@ def _is_valid_type(obj: Any) -> bool: False """ # Handle Union types (including Optional) - if hasattr(obj, "__origin__") and obj.__origin__ is Union: + if hasattr(obj, "__origin__") and obj.__origin__ == Union: return True # Handle normal classes and generic types @@ -45,7 +45,7 @@ def _is_list_type(type_hint: Any) -> bool: :param type_hint: The type hint to check :return: True if the type hint represents a list, False otherwise """ - return type_hint is list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) is list) + return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list) def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]: diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 52a990d79..a174bfb93 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -422,7 +422,7 @@ class HuggingFaceLocalChatGenerator: replies = [o.get("generated_text", "") for o in output] # Remove stop words from replies if present - for stop_word in stop_words: + for stop_word in stop_words or []: replies = [reply.replace(stop_word, "").rstrip() for reply in replies] chat_messages = [ diff --git a/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml b/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml new file mode 100644 index 000000000..bfb93f8d8 --- /dev/null +++ b/releasenotes/notes/fix-state-validate-schema-5ae41ce9c82de61a.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Fix type comparison in schema validation by replacing `is not` with `!=` when checking the type `List[ChatMessage]`. + This prevents false mismatches due to Python's `is` operator comparing object identity instead of equality.