feat: streaming_callback as run param from HF generators (#8406)

* feat: streaming_callback as run param from HF generators

* apply feedback

* add reno

* fix test

* fix test

* fix mypy

* fix excessive linting rule
tstadel 2024-10-29 15:32:06 +01:00 committed by GitHub
parent 811a54a3ef
commit d430833f8f
4 changed files with 51 additions and 22 deletions
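For orientation, a minimal usage sketch of what this change enables (the model name, prompt, and print callback are illustrative, not taken from this commit; the Serverless Inference API also expects a Hugging Face token, e.g. via the HF_API_TOKEN environment variable):

from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk):
    # Called once per streamed token; print tokens as they arrive.
    print(chunk.content, end="", flush=True)

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative model
)

# Before this change, streaming_callback could only be set in __init__;
# now it can also be supplied per call:
result = generator.run("What is streaming?", streaming_callback=print_chunk)
print(result["replies"][0])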

@@ -75,7 +75,7 @@ class HuggingFaceAPIGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_type: Union[HFGenerationAPIType, str],
         api_params: Dict[str, str],
@@ -179,12 +179,19 @@ class HuggingFaceAPIGenerator:
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Invoke the text generation inference for the given prompt and generation parameters.

         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
         :returns:
@@ -194,25 +201,27 @@ class HuggingFaceAPIGenerator:
         # update generation kwargs by merging with the default ones
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        if self.streaming_callback:
-            return self._run_streaming(prompt, generation_kwargs)
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback

-        return self._run_non_streaming(prompt, generation_kwargs)
+        stream = streaming_callback is not None
+        response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)

-    def _run_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamOutput] = self._client.text_generation(
-            prompt, details=True, stream=True, **generation_kwargs
-        )
+        output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response)  # type: ignore
+        return output
+
+    def _get_stream_response(
+        self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
+    ):
         chunks: List[StreamingChunk] = []
         # pylint: disable=not-an-iterable
-        for chunk in res_chunk:
+        for chunk in response:
             token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
             stream_chunk = StreamingChunk(token.text, chunk_metadata)
             chunks.append(stream_chunk)
-            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
         metadata = {
             "finish_reason": chunks[-1].meta.get("finish_reason", None),
             "model": self._client.model,
@@ -220,13 +229,12 @@ class HuggingFaceAPIGenerator:
         }
         return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

-    def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
-        tgr: TextGenerationOutput = self._client.text_generation(prompt, details=True, **generation_kwargs)
+    def _get_response(self, response: TextGenerationOutput):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason if tgr.details else None,
-                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
+                "finish_reason": response.details.finish_reason if response.details else None,
+                "usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
             }
         ]
-        return {"replies": [tgr.generated_text], "meta": meta}
+        return {"replies": [response.generated_text], "meta": meta}

@@ -54,7 +54,7 @@ class HuggingFaceLocalGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         model: str = "google/flan-t5-base",
         task: Optional[Literal["text-generation", "text2text-generation"]] = None,
@@ -204,12 +204,19 @@ class HuggingFaceLocalGenerator:
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        prompt: str,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         """
         Run the text generation model on the given prompt.

         :param prompt:
             A string representing the prompt.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
         :param generation_kwargs:
             Additional keyword arguments for text generation.
@@ -228,7 +235,10 @@ class HuggingFaceLocalGenerator:
         # merge generation kwargs from init method with those from run method
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        if self.streaming_callback:
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self.streaming_callback
+
+        if streaming_callback:
             num_responses = updated_generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
                 msg = (
@@ -241,7 +251,7 @@ class HuggingFaceLocalGenerator:
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             updated_generation_kwargs["streamer"] = HFTokenStreamingHandler(
                 self.pipeline.tokenizer,  # type: ignore
-                self.streaming_callback,
+                streaming_callback,
                 self.stop_words,  # type: ignore
             )
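
The local generator follows the same resolution pattern; a sketch under the default flan-t5 model (warm_up() loads the pipeline and is required before run()):

from haystack.components.generators import HuggingFaceLocalGenerator

llm = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation")
llm.warm_up()

# The run-time callback is forwarded to HFTokenStreamingHandler through the
# "streamer" generation kwarg, per the hunk above.
result = llm.run(
    "Translate to German: hello",
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)
print(result["replies"])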

@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add `streaming_callback` run parameter to `HuggingFaceAPIGenerator` and `HuggingFaceLocalGenerator` to allow users to pass a callback function that will be called after each chunk of the response is generated.

@@ -186,6 +186,7 @@ class TestHuggingFaceAPIGenerator:
             "details": True,
             "temperature": 0.6,
             "stop_sequences": ["stop", "words"],
+            "stream": False,
             "max_new_tokens": 512,
         }
@@ -208,7 +209,13 @@ class TestHuggingFaceAPIGenerator:
         # check kwargs passed to text_generation
         _, kwargs = mock_text_generation.call_args
-        assert kwargs == {"details": True, "max_new_tokens": 100, "stop_sequences": [], "temperature": 0.8}
+        assert kwargs == {
+            "details": True,
+            "max_new_tokens": 100,
+            "stop_sequences": [],
+            "stream": False,
+            "temperature": 0.8,
+        }

         # Assert that the response contains the generated replies and the right response
         assert "replies" in response