feat: streaming_callback as run param from HF generators (#8406)

* feat: streaming_callback as run param from HF generators
* apply feedback
* add reno
* fix test
* fix test
* fix mypy
* fix excessive linting rule

parent: 811a54a3ef
commit: d430833f8f
HuggingFaceAPIGenerator:

```diff
@@ -75,7 +75,7 @@ class HuggingFaceAPIGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_type: Union[HFGenerationAPIType, str],
         api_params: Dict[str, str],
@@ -179,12 +179,19 @@ class HuggingFaceAPIGenerator:
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference for the given prompt and generation parameters.

         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :returns:
@@ -194,25 +201,27 @@ class HuggingFaceAPIGenerator:
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        if self.streaming_callback:
-            return self._run_streaming(prompt, generation_kwargs)
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback

-        return self._run_non_streaming(prompt, generation_kwargs)
+        stream = streaming_callback is not None
+        response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)

-    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
-            prompt, details=True, stream=True, **generation_kwargs
-        )
+        output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response)  # type: ignore
+        return output
+
+    def _get_stream_response(
+        self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
+    ):
         chunks: List[StreamingChunk] = []
-        # pylint: disable=not-an-iterable
-        for chunk in res_chunk:
+        for chunk in response:
             token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
             stream_chunk = StreamingChunk(token.text, chunk_metadata)
             chunks.append(stream_chunk)
-            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
         metadata = {
             "finish_reason": chunks[-1].meta.get("finish_reason", None),
             "model": self._client.model,
@@ -220,13 +229,12 @@ class HuggingFaceAPIGenerator:
         }
         return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

-    def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
+    def _get_response(self, response: TextGenerationOutput):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason if tgr.details else None,
-                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
+                "finish_reason": response.details.finish_reason if response.details else None,
+                "usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
             }
         ]
-        return {"replies": [tgr.generated_text], "meta": meta}
+        return {"replies": [response.generated_text], "meta": meta}
```
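The net effect for `HuggingFaceAPIGenerator` is that streaming can now be enabled per call: `run()` resolves `streaming_callback or self.streaming_callback` and forwards the resulting `stream` flag to `text_generation`. A minimal usage sketch, not part of this commit (model name and prompt are illustrative; a Hugging Face token may be required depending on the model):

```python
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # called once per streamed token; print it as soon as it arrives
    print(chunk.content, end="", flush=True)


generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
)

# streaming_callback is now a run() parameter, so no callback is needed at init
result = generator.run("What is the capital of France?", streaming_callback=print_chunk)
print()
print(result["meta"])
```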
HuggingFaceLocalGenerator:

```diff
@@ -54,7 +54,7 @@ class HuggingFaceLocalGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         model: str = "google/flan-t5-base",
         task: Optional[Literal["text-generation", "text2text-generation"]] = None,
@@ -204,12 +204,19 @@ class HuggingFaceLocalGenerator:
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Run the text generation model on the given prompt.

         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.

@@ -228,7 +235,10 @@ class HuggingFaceLocalGenerator:
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        if self.streaming_callback:
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
+        if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
                 msg = (
@@ -241,7 +251,7 @@ class HuggingFaceLocalGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
                 self.pipeline.tokenizer,  # type: ignore
-                self.streaming_callback,
+                streaming_callback,
                 self.stop_words,  # type: ignore
             )
```
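`HuggingFaceLocalGenerator` follows the same precedence rule: `streaming_callback = streaming_callback or self.streaming_callback`, so a callback passed to `run()` overrides the one given at init, which remains the fallback. A sketch under the same caveats (model, task, and prompt are illustrative; `warm_up()` downloads the model on first use):

```python
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.dataclasses import StreamingChunk

collected: list = []


def collect_chunk(chunk: StreamingChunk) -> None:
    # gather each streamed piece of text instead of printing it
    collected.append(chunk.content)


# no callback at construction time; streaming is opted into per call
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation")
generator.warm_up()

result = generator.run("Translate to German: hello", streaming_callback=collect_chunk)
print("".join(collected))
print(result["replies"])
```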
New release note (reno):

```diff
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add `streaming_callback` run parameter to `HuggingFaceAPIGenerator` and `HuggingFaceLocalGenerator` to allow users to pass a callback function that will be called after each chunk of the response is generated.
```
TestHuggingFaceAPIGenerator:

```diff
@@ -186,6 +186,7 @@ class TestHuggingFaceAPIGenerator:
             "details": True,
             "temperature": 0.6,
             "stop_sequences": ["stop", "words"],
+            "stream": False,
             "max_new_tokens": 512,
         }

@@ -208,7 +209,13 @@ class TestHuggingFaceAPIGenerator:

         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
+        assert kwargs == {
+            "details": True,
+            "max_new_tokens": 100,
+            "stop_sequences": [],
+            "stream": False,
+            "temperature": 0.8,
+        }

         # Assert that the response contains the generated replies and the right response
         assert "replies" in response
```
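Since `text_generation` is now always called with an explicit `stream=` argument, the existing non-streaming assertions gain `"stream": False`. Below is a hedged sketch of a possible companion test for the run-time callback; it assumes the module's existing `mock_check_valid_model` and `mock_text_generation` fixtures and builds stream chunks with `huggingface_hub`'s `TextGenerationStreamOutput` types, so treat the names and chunk construction as assumptions rather than part of this commit. The method would live inside `TestHuggingFaceAPIGenerator`:

```python
from huggingface_hub import TextGenerationStreamOutput, TextGenerationStreamOutputToken

from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.utils.hf import HFGenerationAPIType


def test_run_with_streaming_callback_at_run_time(self, mock_check_valid_model, mock_text_generation):
    collected: list = []

    generator = HuggingFaceAPIGenerator(
        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    )

    # simulate a two-token stream coming back from the mocked text_generation
    mock_text_generation.return_value = iter(
        TextGenerationStreamOutput(
            index=i,
            generated_text=None,
            token=TextGenerationStreamOutputToken(id=i, logprob=0.0, text=text, special=False),
        )
        for i, text in enumerate(["Hello", " world"])
    )

    response = generator.run("irrelevant prompt", streaming_callback=collected.append)

    # the run-time callback must switch streaming on and receive every chunk
    _, kwargs = mock_text_generation.call_args
    assert kwargs["stream"] is True
    assert [chunk.content for chunk in collected] == ["Hello", " world"]
    assert response["replies"] == ["Hello world"]
```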