feat: streaming_callback as run param from HF generators (#8406)
* feat: streaming_callback as run param from HF generators
* apply feedback
* add reno
* fix test
* fix test
* fix mypy
* fix excessive linting rule
parent 811a54a3ef
commit d430833f8f
@@ -75,7 +75,7 @@ class HuggingFaceAPIGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_type: Union[HFGenerationAPIType, str],
         api_params: Dict[str, str],
@@ -179,12 +179,19 @@ class HuggingFaceAPIGenerator:
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference for the given prompt and generation parameters.
 
         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :returns:
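With this change, the callback can be supplied per call instead of only at construction time. A minimal sketch of how the new run parameter might be used (the model name and the printing callback are illustrative assumptions, not part of this diff):

```python
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # print each streamed token as soon as it arrives
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # assumed model; any TGI-compatible model works
)

# streaming_callback is now a run parameter; it no longer has to be set in __init__
result = generator.run("What is streaming inference?", streaming_callback=print_chunk)
print(result["replies"][0])
```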
@@ -194,25 +201,27 @@ class HuggingFaceAPIGenerator:
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
-        if self.streaming_callback:
-            return self._run_streaming(prompt, generation_kwargs)
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
 
-        return self._run_non_streaming(prompt, generation_kwargs)
+        stream = streaming_callback is not None
+        response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)
 
-    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
-            prompt, details=True, stream=True, **generation_kwargs
-        )
+        output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response)  # type: ignore
+        return output
+
+    def _get_stream_response(
+        self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
+    ):
         chunks: List[StreamingChunk] = []
-        # pylint: disable=not-an-iterable
-        for chunk in res_chunk:
+        for chunk in response:
             token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
             stream_chunk = StreamingChunk(token.text, chunk_metadata)
             chunks.append(stream_chunk)
-            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
         metadata = {
             "finish_reason": chunks[-1].meta.get("finish_reason", None),
             "model": self._client.model,
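Any callable that accepts a `StreamingChunk` works as the run-time callback handed to `_get_stream_response`. As an illustration only (not part of the diff), a callback object that collects chunks instead of printing them:

```python
from typing import List

from haystack.dataclasses import StreamingChunk


class ChunkCollector:
    """Stores every StreamingChunk passed to it during streaming."""

    def __init__(self) -> None:
        self.chunks: List[StreamingChunk] = []

    def __call__(self, chunk: StreamingChunk) -> None:
        self.chunks.append(chunk)


collector = ChunkCollector()
# result = generator.run("Tell me a joke", streaming_callback=collector)
# "".join(c.content for c in collector.chunks) would then equal result["replies"][0]
```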
@@ -220,13 +229,12 @@ class HuggingFaceAPIGenerator:
         }
         return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}
 
-    def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
+    def _get_response(self, response: TextGenerationOutput):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason if tgr.details else None,
-                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
+                "finish_reason": response.details.finish_reason if response.details else None,
+                "usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
             }
         ]
-        return {"replies": [tgr.generated_text], "meta": meta}
+        return {"replies": [response.generated_text], "meta": meta}
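Both the streaming and the non-streaming path return the same `{"replies": [...], "meta": [...]}` shape. A sketch of inspecting the metadata keys visible in this hunk, reusing the `generator` from the earlier sketch (the prompt is illustrative):

```python
result = generator.run("What is the capital of France?")

reply = result["replies"][0]
meta = result["meta"][0]
# keys set by _get_response / _get_stream_response
print(meta["model"])
print(meta["finish_reason"])
print(meta["usage"]["completion_tokens"])
```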
@@ -54,7 +54,7 @@ class HuggingFaceLocalGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         model: str = "google/flan-t5-base",
         task: Optional[Literal["text-generation", "text2text-generation"]] = None,
@@ -204,12 +204,19 @@ class HuggingFaceLocalGenerator:
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[str])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Run the text generation model on the given prompt.
 
         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
 
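The local generator gains the same run-time parameter. A minimal sketch under the same assumptions as above (the prompt and callback are illustrative; `warm_up()` loads the local pipeline before the first call):

```python
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation")
generator.warm_up()

# the callback applies to this call only; other runs stay non-streaming
result = generator.run("Translate to German: good morning", streaming_callback=print_chunk)
```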
@@ -228,7 +235,10 @@ class HuggingFaceLocalGenerator:
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
-        if self.streaming_callback:
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
+        if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
                 msg = (
@@ -241,7 +251,7 @@ class HuggingFaceLocalGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
                 self.pipeline.tokenizer,  # type: ignore
-                self.streaming_callback,
+                streaming_callback,
                 self.stop_words,  # type: ignore
             )
 
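`HFTokenStreamingHandler` (defined in haystack's Hugging Face utilities, not shown in this diff) adapts the transformers streamer protocol to the Haystack callback. A rough sketch of how such an adapter can be built on `transformers.TextStreamer`; this is an assumption for illustration, not the actual Haystack implementation:

```python
from transformers import AutoTokenizer, TextStreamer

from haystack.dataclasses import StreamingChunk


class CallbackTextStreamer(TextStreamer):
    """Hypothetical adapter: forwards decoded text from transformers to a StreamingChunk callback."""

    def __init__(self, tokenizer, callback):
        super().__init__(tokenizer, skip_prompt=True)
        self.callback = callback

    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
        # transformers calls this whenever a decoded piece of text is ready
        self.callback(StreamingChunk(content=text))


# usage sketch: pass an instance as the `streamer` generation kwarg of a transformers pipeline
# streamer = CallbackTextStreamer(AutoTokenizer.from_pretrained("google/flan-t5-base"), print)
```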
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add `streaming_callback` run parameter to `HuggingFaceAPIGenerator` and `HuggingFaceLocalGenerator` to allow users to pass a callback function that will be called after each chunk of the response is generated.
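To illustrate what the release note describes: a callback set in `__init__` remains the default, while a callback passed to `run` takes precedence for that call only. A sketch under the same assumed model as above:

```python
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk


def quiet(chunk: StreamingChunk) -> None:
    pass  # default: discard streamed tokens


def loud(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    streaming_callback=quiet,  # init-time default, used when run() gets no callback
)

generator.run("Hi")                            # streams through quiet
generator.run("Hi", streaming_callback=loud)   # run-time callback wins for this call only
```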
@@ -186,6 +186,7 @@ class TestHuggingFaceAPIGenerator:
             "details": True,
             "temperature": 0.6,
             "stop_sequences": ["stop", "words"],
+            "stream": False,
             "max_new_tokens": 512,
         }
 
@@ -208,7 +209,13 @@ class TestHuggingFaceAPIGenerator:
 
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
+        assert kwargs == {
+            "details": True,
+            "max_new_tokens": 100,
+            "stop_sequences": [],
+            "stream": False,
+            "temperature": 0.8,
+        }
 
         # Assert that the response contains the generated replies and the right response
         assert "replies" in response