mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 06:58:35 +00:00
fix: HuggingFaceAPIGenerator - use forward references (#8502)
* hf API generator: forward references + refactor
* release note
This commit is contained in:
parent
8a35e792b9
commit
700684a31c
@ -204,24 +204,27 @@ class HuggingFaceAPIGenerator:
|
||||
# check if streaming_callback is passed
|
||||
streaming_callback = streaming_callback or self.streaming_callback
|
||||
|
||||
stream = streaming_callback is not None
|
||||
response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)
|
||||
hf_output = self._client.text_generation(
|
||||
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
|
||||
)
|
||||
|
||||
output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response) # type: ignore
|
||||
return output
|
||||
if streaming_callback is not None:
|
||||
return self._stream_and_build_response(hf_output, streaming_callback)
|
||||
|
||||
def _get_stream_response(
|
||||
self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
|
||||
return self._build_non_streaming_response(hf_output)
|
||||
|
||||
def _stream_and_build_response(
|
||||
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
|
||||
):
|
||||
chunks: List[StreamingChunk] = []
|
||||
for chunk in response:
|
||||
for chunk in hf_output:
|
||||
token: TextGenerationOutputToken = chunk.token
|
||||
if token.special:
|
||||
continue
|
||||
chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
|
||||
stream_chunk = StreamingChunk(token.text, chunk_metadata)
|
||||
chunks.append(stream_chunk)
|
||||
streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)
|
||||
streaming_callback(stream_chunk)
|
||||
metadata = {
|
||||
"finish_reason": chunks[-1].meta.get("finish_reason", None),
|
||||
"model": self._client.model,
|
||||
@ -229,12 +232,12 @@ class HuggingFaceAPIGenerator:
|
||||
}
|
||||
return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}
|
||||
|
||||
def _get_response(self, response: TextGenerationOutput):
|
||||
def _build_non_streaming_response(self, hf_output: "TextGenerationOutput"):
|
||||
meta = [
|
||||
{
|
||||
"model": self._client.model,
|
||||
"finish_reason": response.details.finish_reason if response.details else None,
|
||||
"usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
|
||||
"finish_reason": hf_output.details.finish_reason if hf_output.details else None,
|
||||
"usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0},
|
||||
}
|
||||
]
|
||||
return {"replies": [response.generated_text], "meta": meta}
|
||||
return {"replies": [hf_output.generated_text], "meta": meta}
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Use forward references for Hugging Face Hub types in the `HuggingFaceAPIGenerator` component
|
||||
to prevent import errors.
|
||||
Loading…
x
Reference in New Issue
Block a user