diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py
index 8fe38c013..a9d987354 100644
--- a/haystack/components/generators/hugging_face_api.py
+++ b/haystack/components/generators/hugging_face_api.py
@@ -75,7 +75,7 @@ class HuggingFaceAPIGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_type: Union[HFGenerationAPIType, str],
         api_params: Dict[str, str],
@@ -179,12 +179,19 @@ class HuggingFaceAPIGenerator:
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference for the given prompt and generation parameters.
 
         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :returns:
@@ -194,25 +201,27 @@ class HuggingFaceAPIGenerator:
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
-        if self.streaming_callback:
-            return self._run_streaming(prompt, generation_kwargs)
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
 
-        return self._run_non_streaming(prompt, generation_kwargs)
+        stream = streaming_callback is not None
+        response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)
 
-    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
-            prompt, details=True, stream=True, **generation_kwargs
-        )
+        output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response)  # type: ignore
+        return output
+
+    def _get_stream_response(
+        self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
+    ):
         chunks: List[StreamingChunk] = []
-        # pylint: disable=not-an-iterable
-        for chunk in res_chunk:
+        for chunk in response:
             token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
             stream_chunk = StreamingChunk(token.text, chunk_metadata)
             chunks.append(stream_chunk)
-            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
         metadata = {
             "finish_reason": chunks[-1].meta.get("finish_reason", None),
             "model": self._client.model,
@@ -220,13 +229,12 @@ class HuggingFaceAPIGenerator:
         }
         return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}
 
-    def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
+    def _get_response(self, response: TextGenerationOutput):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason if tgr.details else None,
-                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
+                "finish_reason": response.details.finish_reason if response.details else None,
+                "usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
             }
         ]
-        return {"replies": [tgr.generated_text], "meta": meta}
+        return {"replies": [response.generated_text], "meta": meta}
diff --git a/haystack/components/generators/hugging_face_local.py b/haystack/components/generators/hugging_face_local.py
index 178548c61..0e2c6ae5f 100644
--- a/haystack/components/generators/hugging_face_local.py
+++ b/haystack/components/generators/hugging_face_local.py
@@ -54,7 +54,7 @@ class HuggingFaceLocalGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         model: str = "google/flan-t5-base",
         task: Optional[Literal["text-generation", "text2text-generation"]] = None,
@@ -204,12 +204,19 @@ class HuggingFaceLocalGenerator:
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Run the text generation model on the given prompt.
 
         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
 
@@ -228,7 +235,10 @@ class HuggingFaceLocalGenerator:
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
-        if self.streaming_callback:
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
+        if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
                 msg = (
@@ -241,7 +251,7 @@ class HuggingFaceLocalGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
                 self.pipeline.tokenizer,  # type: ignore
-                self.streaming_callback,
+                streaming_callback,
                 self.stop_words,  # type: ignore
             )
 
diff --git a/releasenotes/notes/add-streaming-callback-run-param-to-hf-generators-5ebde8fad75cb49f.yaml b/releasenotes/notes/add-streaming-callback-run-param-to-hf-generators-5ebde8fad75cb49f.yaml
new file mode 100644
index 000000000..792a6c327
--- /dev/null
+++ b/releasenotes/notes/add-streaming-callback-run-param-to-hf-generators-5ebde8fad75cb49f.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add `streaming_callback` run parameter to `HuggingFaceAPIGenerator` and `HuggingFaceLocalGenerator` to allow users to pass a callback function that will be called after each chunk of the response is generated.
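The release note above describes the new run-time parameter. A minimal usage sketch for `HuggingFaceAPIGenerator` follows; it is not part of the patch, the model name, prompt, and `print_chunk` callback are illustrative, and a valid `HF_API_TOKEN` is assumed in the environment:

```python
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk


# Illustrative callback: print each streamed chunk as it arrives
def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
)

# Per the diff (streaming_callback = streaming_callback or self.streaming_callback),
# the run-time callback takes precedence over one set in __init__;
# omitting it preserves the previous behavior.
result = generator.run("What is streaming?", streaming_callback=print_chunk)
print(result["replies"][0])
```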
diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py
index ce6c7ecd2..0f4be2f9c 100644
--- a/test/components/generators/test_hugging_face_api.py
+++ b/test/components/generators/test_hugging_face_api.py
@@ -186,6 +186,7 @@ class TestHuggingFaceAPIGenerator:
             "details": True,
             "temperature": 0.6,
             "stop_sequences": ["stop", "words"],
+            "stream": False,
             "max_new_tokens": 512,
         }
 
@@ -208,7 +209,13 @@ class TestHuggingFaceAPIGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
+        assert kwargs == {
+            "details": True,
+            "max_new_tokens": 100,
+            "stop_sequences": [],
+            "stream": False,
+            "temperature": 0.8,
+        }
 
         # Assert that the response contains the generated replies and the right response
         assert "replies" in response
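For completeness, here is the equivalent run-time streaming sketch for `HuggingFaceLocalGenerator`; again illustrative rather than part of the patch, with the prompt and callback as assumed placeholders:

```python
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk


# Illustrative callback: print each streamed chunk as it arrives
def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


generator = HuggingFaceLocalGenerator(
    model="google/flan-t5-base",  # the component's default model
    task="text2text-generation",
    generation_kwargs={"max_new_tokens": 100},
)
generator.warm_up()  # load the local pipeline before the first run

# As in the API generator, the run-time callback overrides self.streaming_callback.
# Note that per the diff, streaming is rejected when num_return_sequences > 1.
result = generator.run("Summarize the following topic: climate change", streaming_callback=print_chunk)
print(result["replies"][0])
```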