Mirror of https://github.com/deepset-ai/haystack.git
refactor!: make TGI generators compatible with huggingface_hub>=0.22.0 (#7425)
* progress
* progress
* better lazy imports
* fixes
* reno
Parent: fcd48d662c
Commit: e26ee0f1db
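The diff below tracks the renamed text-generation dataclasses in `huggingface_hub` 0.22: `TextGenerationResponse` becomes `TextGenerationOutput`, `TextGenerationStreamResponse` becomes `TextGenerationStreamOutput`, `Token` becomes `TextGenerationOutputToken`, the classes are importable from the package root instead of `huggingface_hub.inference._text_generation`, and `details.finish_reason` is now a plain string rather than an enum. A minimal sketch of the new surface, with names taken from the diff itself; the endpoint URL and prompts are illustrative assumptions, not part of this change:

```python
# Sketch of the huggingface_hub>=0.22.0 surface this refactor targets.
# The endpoint URL and prompts are assumptions for illustration only.
from huggingface_hub import (
    InferenceClient,
    TextGenerationOutput,
    TextGenerationOutputToken,
    TextGenerationStreamOutput,
)

client = InferenceClient(model="http://localhost:8080")  # hypothetical local TGI server

# Non-streaming: details=True returns a TextGenerationOutput whose
# details.finish_reason is already a string, so the old `.value` access is gone.
out: TextGenerationOutput = client.text_generation("What is Haystack?", details=True)
print(out.generated_text, out.details.finish_reason)

# Streaming: each chunk is a TextGenerationStreamOutput carrying a
# TextGenerationOutputToken instead of the old Token dataclass.
for chunk in client.text_generation("What is Haystack?", details=True, stream=True):
    token: TextGenerationOutputToken = chunk.token
    if not token.special:
        print(token.text, end="")
```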
@@ -9,7 +9,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model
 
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 
 logger = logging.getLogger(__name__)

@@ -6,7 +6,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model
 
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 
 logger = logging.getLogger(__name__)

@@ -8,9 +8,13 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
 
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import InferenceClient
-    from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Token
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+    from huggingface_hub import (
+        InferenceClient,
+        TextGenerationOutput,
+        TextGenerationOutputToken,
+        TextGenerationStreamOutput,
+    )
     from transformers import AutoTokenizer
 
 logger = logging.getLogger(__name__)

@@ -245,13 +249,13 @@ class HuggingFaceTGIChatGenerator:
     def _run_streaming(
         self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]
     ) -> Dict[str, List[ChatMessage]]:
-        res: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
+        res: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
             prepared_prompt, stream=True, details=True, **generation_kwargs
         )
         chunk = None
         # pylint: disable=not-an-iterable
         for chunk in res:
-            token: Token = chunk.token
+            token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}

@@ -261,7 +265,7 @@ class HuggingFaceTGIChatGenerator:
         message = ChatMessage.from_assistant(chunk.generated_text)
         message.meta.update(
             {
-                "finish_reason": chunk.details.finish_reason.value,
+                "finish_reason": chunk.details.finish_reason,
                 "index": 0,
                 "model": self.client.model,
                 "usage": {

@@ -278,13 +282,11 @@ class HuggingFaceTGIChatGenerator:
     ) -> Dict[str, List[ChatMessage]]:
         chat_messages: List[ChatMessage] = []
         for _i in range(num_responses):
-            tgr: TextGenerationResponse = self.client.text_generation(
-                prepared_prompt, details=True, **generation_kwargs
-            )
+            tgr: TextGenerationOutput = self.client.text_generation(prepared_prompt, details=True, **generation_kwargs)
             message = ChatMessage.from_assistant(tgr.generated_text)
             message.meta.update(
                 {
-                    "finish_reason": tgr.details.finish_reason.value,
+                    "finish_reason": tgr.details.finish_reason,
                     "index": _i,
                     "model": self.client.model,
                     "usage": {

@@ -8,9 +8,13 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
 
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import InferenceClient
-    from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Token
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+    from huggingface_hub import (
+        InferenceClient,
+        TextGenerationOutput,
+        TextGenerationOutputToken,
+        TextGenerationStreamOutput,
+    )
     from transformers import AutoTokenizer
 
 

@@ -213,13 +217,13 @@ class HuggingFaceTGIGenerator:
         return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)
 
     def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
+        res_chunk: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
             prompt, details=True, stream=True, **generation_kwargs
         )
         chunks: List[StreamingChunk] = []
         # pylint: disable=not-an-iterable
         for chunk in res_chunk:
-            token: Token = chunk.token
+            token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}

@@ -243,12 +247,12 @@ class HuggingFaceTGIGenerator:
         responses: List[str] = []
         all_metadata: List[Dict[str, Any]] = []
         for _i in range(num_responses):
-            tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs)
+            tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
             all_metadata.append(
                 {
                     "model": self.client.model,
                     "index": _i,
-                    "finish_reason": tgr.details.finish_reason.value,
+                    "finish_reason": tgr.details.finish_reason,
                     "usage": {
                         "completion_tokens": len(tgr.details.tokens),
                         "prompt_tokens": prompt_token_count,

@@ -14,7 +14,7 @@ from haystack.utils.device import ComponentDevice
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
     import torch
 
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError
 

@@ -102,7 +102,7 @@ format-check = "black --check ."
 [tool.hatch.envs.test]
 extra-dependencies = [
   "transformers[torch,sentencepiece]==4.38.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub<0.22.0", # TGI Generators and TEI Embedders
+  "huggingface_hub>=0.22.0", # TGI Generators and TEI Embedders
   "spacy>=3.7,<3.8", # NamedEntityExtractor
   "spacy-curated-transformers>=0.2,<=0.3", # NamedEntityExtractor
   "en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl", # NamedEntityExtractor

@@ -0,0 +1,11 @@
+---
+upgrade:
+  - |
+    The `HuggingFaceTGIGenerator` and `HuggingFaceTGIChatGenerator` components have been modified to be compatible with
+    `huggingface_hub>=0.22.0`.
+
+    If you use these components, you may need to upgrade the `huggingface_hub` library.
+    To do this, run the following command in your environment:
+    ```bash
+    pip install "huggingface_hub>=0.22.0"
+    ```

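For reference, a minimal usage sketch of the upgraded generator after running the command above; the model choice, prompt, and the `warm_up()`/`run()` flow follow the component's documented usage pattern and are illustrative assumptions rather than part of this diff:

```python
# Hedged usage sketch after upgrading huggingface_hub; the model and prompt
# below are assumptions for illustration only.
from haystack.components.generators import HuggingFaceTGIGenerator

generator = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1")
generator.warm_up()  # loads the tokenizer used for prompt token counting

result = generator.run("What is the capital of France?")
print(result["replies"][0])
# With huggingface_hub>=0.22.0, finish_reason is a plain string such as "length".
print(result["meta"][0]["finish_reason"])
```
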
@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
+from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
 from huggingface_hub.utils import RepositoryNotFoundError
 
 from haystack.components.generators.chat import HuggingFaceTGIChatGenerator

@@ -328,13 +328,14 @@ class TestHuggingFaceTGIChatGenerator:
         # Create a fake streamed response
         # self needed here, don't remove
         def mock_iter(self):
-            yield TextGenerationStreamResponse(
-                generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
-            )
-            yield TextGenerationStreamResponse(
-                generated_text=None,
-                token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
-                details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
-            )
+            yield TextGenerationStreamOutput(
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
+            )
+            yield TextGenerationStreamOutput(
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
+                details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+            )
 
         mock_response = Mock(**{"__iter__": mock_iter})

@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
+from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
 from huggingface_hub.utils import RepositoryNotFoundError
 
 from haystack.components.generators import HuggingFaceTGIGenerator

@@ -270,13 +270,14 @@ class TestHuggingFaceTGIGenerator:
         # Create a fake streamed response
         # Don't remove self
        def mock_iter(self):
-            yield TextGenerationStreamResponse(
-                generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
-            )
-            yield TextGenerationStreamResponse(
-                generated_text=None,
-                token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
-                details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
-            )
+            yield TextGenerationStreamOutput(
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
+            )
+            yield TextGenerationStreamOutput(
+                generated_text=None,
+                token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
+                details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+            )
 
         mock_response = Mock(**{"__iter__": mock_iter})