refactor!: make TGI generators compatible with huggingface_hub>=0.22.0 (#7425)

* progress

* progress

* better lazy imports

* fixes

* reno
Stefano Fiorucci 2024-03-26 16:10:06 +01:00 committed by GitHub
parent fcd48d662c
commit e26ee0f1db
9 changed files with 54 additions and 35 deletions

View File

@@ -9,7 +9,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient
 logger = logging.getLogger(__name__)

View File

@@ -8,9 +8,13 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import InferenceClient
-    from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Token
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+    from huggingface_hub import (
+        InferenceClient,
+        TextGenerationOutput,
+        TextGenerationOutputToken,
+        TextGenerationStreamOutput,
+    )
     from transformers import AutoTokenizer
 logger = logging.getLogger(__name__)
@@ -245,13 +249,13 @@ class HuggingFaceTGIChatGenerator:
     def _run_streaming(
         self, prepared_prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]
     ) -> Dict[str, List[ChatMessage]]:
-        res: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
+        res: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
             prepared_prompt, stream=True, details=True, **generation_kwargs
         )
         chunk = None
         # pylint: disable=not-an-iterable
         for chunk in res:
-            token: Token = chunk.token
+            token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
@@ -261,7 +265,7 @@
         message = ChatMessage.from_assistant(chunk.generated_text)
         message.meta.update(
             {
-                "finish_reason": chunk.details.finish_reason.value,
+                "finish_reason": chunk.details.finish_reason,
                 "index": 0,
                 "model": self.client.model,
                 "usage": {
@@ -278,13 +282,11 @@
     ) -> Dict[str, List[ChatMessage]]:
         chat_messages: List[ChatMessage] = []
         for _i in range(num_responses):
-            tgr: TextGenerationResponse = self.client.text_generation(
-                prepared_prompt, details=True, **generation_kwargs
-            )
+            tgr: TextGenerationOutput = self.client.text_generation(prepared_prompt, details=True, **generation_kwargs)
             message = ChatMessage.from_assistant(tgr.generated_text)
             message.meta.update(
                 {
-                    "finish_reason": tgr.details.finish_reason.value,
+                    "finish_reason": tgr.details.finish_reason,
                     "index": _i,
                     "model": self.client.model,
                     "usage": {

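The `finish_reason` change above reflects that in `huggingface_hub>=0.22.0` the details object carries `finish_reason` as a plain string (for example `"length"`) rather than a `FinishReason` enum, so the trailing `.value` is dropped. A minimal sketch of the new behaviour, assuming `huggingface_hub>=0.22.0` is installed and mirroring how the tests further below construct the details object:

```python
from huggingface_hub import TextGenerationStreamDetails

# finish_reason is now a plain string such as "length", so no .value is needed.
details = TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None)
assert details.finish_reason == "length"
```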
View File

@@ -8,9 +8,13 @@ from haystack.lazy_imports import LazyImport
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models
-with LazyImport(message="Run 'pip install transformers'") as transformers_import:
-    from huggingface_hub import InferenceClient
-    from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Token
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+    from huggingface_hub import (
+        InferenceClient,
+        TextGenerationOutput,
+        TextGenerationOutputToken,
+        TextGenerationStreamOutput,
+    )
     from transformers import AutoTokenizer
@@ -213,13 +217,13 @@ class HuggingFaceTGIGenerator:
         return self._run_non_streaming(prompt, prompt_token_count, num_responses, generation_kwargs)
     def _run_streaming(self, prompt: str, prompt_token_count: int, generation_kwargs: Dict[str, Any]):
-        res_chunk: Iterable[TextGenerationStreamResponse] = self.client.text_generation(
+        res_chunk: Iterable[TextGenerationStreamOutput] = self.client.text_generation(
             prompt, details=True, stream=True, **generation_kwargs
         )
         chunks: List[StreamingChunk] = []
         # pylint: disable=not-an-iterable
         for chunk in res_chunk:
-            token: Token = chunk.token
+            token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
@@ -243,12 +247,12 @@
         responses: List[str] = []
         all_metadata: List[Dict[str, Any]] = []
         for _i in range(num_responses):
-            tgr: TextGenerationResponse = self.client.text_generation(prompt, details=True, **generation_kwargs)
+            tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
             all_metadata.append(
                 {
                     "model": self.client.model,
                     "index": _i,
-                    "finish_reason": tgr.details.finish_reason.value,
+                    "finish_reason": tgr.details.finish_reason,
                     "usage": {
                         "completion_tokens": len(tgr.details.tokens),
                         "prompt_tokens": prompt_token_count,

View File

@@ -14,7 +14,7 @@ from haystack.utils.device import ComponentDevice
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
     import torch
-with LazyImport(message="Run 'pip install huggingface_hub'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError

View File

@@ -102,7 +102,7 @@ format-check = "black --check ."
 [tool.hatch.envs.test]
 extra-dependencies = [
   "transformers[torch,sentencepiece]==4.38.2", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub<0.22.0", # TGI Generators and TEI Embedders
+  "huggingface_hub>=0.22.0", # TGI Generators and TEI Embedders
   "spacy>=3.7,<3.8", # NamedEntityExtractor
   "spacy-curated-transformers>=0.2,<=0.3", # NamedEntityExtractor
   "en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl", # NamedEntityExtractor

View File

@@ -0,0 +1,11 @@
+---
+upgrade:
+  - |
+    The `HuggingFaceTGIGenerator` and `HuggingFaceTGIChatGenerator` components have been modified to be compatible with
+    `huggingface_hub>=0.22.0`.
+    If you use these components, you may need to upgrade the `huggingface_hub` library.
+    To do this, run the following command in your environment:
+    ```bash
+    pip install "huggingface_hub>=0.22.0"
+    ```
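For downstream code that imported these classes from the old private module, the renames used in this commit map as follows; a sketch based only on the import changes shown in the hunks above, assuming `huggingface_hub>=0.22.0`:

```python
# Old (huggingface_hub < 0.22.0): private module, old class names
# from huggingface_hub.inference._text_generation import (
#     FinishReason, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token,
# )

# New (huggingface_hub >= 0.22.0): renamed classes exposed at the package top level
from huggingface_hub import (
    TextGenerationOutput,         # was TextGenerationResponse
    TextGenerationOutputToken,    # was Token
    TextGenerationStreamDetails,  # was StreamDetails
    TextGenerationStreamOutput,   # was TextGenerationStreamResponse
)
# FinishReason has no direct counterpart here: finish_reason is now a plain string.
```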

View File

@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, Mock, patch
 import pytest
-from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
+from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
 from huggingface_hub.utils import RepositoryNotFoundError
 from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
@@ -328,13 +328,14 @@ class TestHuggingFaceTGIChatGenerator:
             # Create a fake streamed response
             # self needed here, don't remove
             def mock_iter(self):
-                yield TextGenerationStreamResponse(
-                    generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
-                )
-                yield TextGenerationStreamResponse(
+                yield TextGenerationStreamOutput(
                     generated_text=None,
-                    token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
-                    details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
+                    token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
                 )
+                yield TextGenerationStreamOutput(
+                    generated_text=None,
+                    token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
+                    details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+                )
             mock_response = Mock(**{"__iter__": mock_iter})

View File

@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, Mock, patch
 import pytest
-from huggingface_hub.inference._text_generation import FinishReason, StreamDetails, TextGenerationStreamResponse, Token
+from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
 from huggingface_hub.utils import RepositoryNotFoundError
 from haystack.components.generators import HuggingFaceTGIGenerator
@@ -270,13 +270,14 @@ class TestHuggingFaceTGIGenerator:
             # Create a fake streamed response
             # Don't remove self
             def mock_iter(self):
-                yield TextGenerationStreamResponse(
-                    generated_text=None, token=Token(id=1, text="I'm fine, thanks.", logprob=0.0, special=False)
-                )
-                yield TextGenerationStreamResponse(
+                yield TextGenerationStreamOutput(
                     generated_text=None,
-                    token=Token(id=1, text="Ok bye", logprob=0.0, special=False),
-                    details=StreamDetails(finish_reason=FinishReason.Length, generated_tokens=5),
+                    token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
                 )
+                yield TextGenerationStreamOutput(
+                    generated_text=None,
+                    token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
+                    details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+                )
             mock_response = Mock(**{"__iter__": mock_iter})
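The mocked chunks above are consumed by the `_run_streaming` loops shown earlier. A self-contained sketch of that consumption pattern, assuming `huggingface_hub>=0.22.0` is installed (the `stream` list, the `"Hello"` token, and the `print` call are illustrative only, not part of the Haystack code):

```python
from dataclasses import asdict

from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamOutput

# A fake stream chunk, built the same way as the test fixtures above.
stream = [
    TextGenerationStreamOutput(
        generated_text=None,
        token=TextGenerationOutputToken(id=1, text="Hello", logprob=0.0, special=False),
    )
]

for chunk in stream:
    token = chunk.token
    if token.special:
        continue  # skip special tokens, as the generators do
    # Merge token and details metadata, mirroring the chunk_metadata line in the generators.
    chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
    print(token.text, chunk_metadata)
```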