fix: HuggingFaceAPIGenerator - use forward references (#8502)

* hf API generator: forward references + refactor

* release note
This commit is contained in:
Stefano Fiorucci 2024-10-30 11:51:07 +01:00 committed by GitHub
parent 8a35e792b9
commit 700684a31c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 12 deletions

View File

@ -204,24 +204,27 @@ class HuggingFaceAPIGenerator:
# check if streaming_callback is passed
streaming_callback = streaming_callback or self.streaming_callback
stream = streaming_callback is not None
response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)
hf_output = self._client.text_generation(
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
)
output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response) # type: ignore
return output
if streaming_callback is not None:
return self._stream_and_build_response(hf_output, streaming_callback)
def _get_stream_response(
self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
return self._build_non_streaming_response(hf_output)
def _stream_and_build_response(
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
):
chunks: List[StreamingChunk] = []
for chunk in response:
for chunk in hf_output:
token: TextGenerationOutputToken = chunk.token
if token.special:
continue
chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
stream_chunk = StreamingChunk(token.text, chunk_metadata)
chunks.append(stream_chunk)
streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
streaming_callback(stream_chunk)
metadata = {
"finish_reason": chunks[-1].meta.get("finish_reason", None),
"model": self._client.model,
@ -229,12 +232,12 @@ class HuggingFaceAPIGenerator:
}
return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}
def _get_response(self, response: TextGenerationOutput):
def _build_non_streaming_response(self, hf_output: "TextGenerationOutput"):
meta = [
{
"model": self._client.model,
"finish_reason": response.details.finish_reason if response.details else None,
"usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
"finish_reason": hf_output.details.finish_reason if hf_output.details else None,
"usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0},
}
]
return {"replies": [response.generated_text], "meta": meta}
return {"replies": [hf_output.generated_text], "meta": meta}

View File

@ -0,0 +1,5 @@
---
fixes:
- |
Use forward references for Hugging Face Hub types in the `HuggingFaceAPIGenerator` component
to prevent import errors.