Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-07 20:46:31 +00:00)
feat: generators (2.0) (#5690)
* add generators module
* add tests for module helper
* reno
* add another test
* move into openai
* improve tests
This commit is contained in:
parent 6787ad2435
commit 5f1256ac7e
haystack/preview/components/generators/__init__.py (Normal file, 0 lines)
haystack/preview/components/generators/openai/_helpers.py (Normal file, 33 lines)
@@ -0,0 +1,33 @@
import logging

from haystack.preview.lazy_imports import LazyImport

with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
    import tiktoken


logger = logging.getLogger(__name__)


def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
    """
    Ensure that the length of the prompt is within the max tokens limit of the model.
    If needed, truncate the prompt text so that it fits within the limit.

    :param prompt: Prompt text to be sent to the generative model.
    :param tokenizer: The tokenizer used to encode the prompt.
    :param max_tokens_limit: The max tokens limit of the model.
    :return: The prompt text that fits within the max tokens limit of the model.
    """
    tiktoken_import.check()
    tokens = tokenizer.encode(prompt)
    tokens_count = len(tokens)
    if tokens_count > max_tokens_limit:
        logger.warning(
            "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
            "Reduce the length of the prompt to prevent it from being cut off.",
            tokens_count,
            max_tokens_limit,
        )
        prompt = tokenizer.decode(tokens[:max_tokens_limit])
    return prompt
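For context, a minimal sketch of how this helper would be called with a real tiktoken tokenizer. The encoding name, prompt, and token limit below are illustrative assumptions, not part of this commit:

import tiktoken

from haystack.preview.components.generators.openai._helpers import enforce_token_limit

# Assumption: tiktoken is installed; "cl100k_base" is one of its built-in encodings.
tokenizer = tiktoken.get_encoding("cl100k_base")

# Truncates the prompt to at most 16 tokens and logs a warning if truncation happened.
safe_prompt = enforce_token_limit("some very long prompt " * 200, tokenizer=tokenizer, max_tokens_limit=16)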
@@ -0,0 +1,2 @@
preview:
  - Add generators module for LLM generator components.
@@ -0,0 +1,20 @@
import pytest

from haystack.preview.components.generators.openai._helpers import enforce_token_limit


@pytest.mark.unit
def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
    assert prompt == "This is a"
    assert caplog.records[0].message == (
        "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
        "limit. Reduce the length of the prompt to prevent it from being cut off."
    )


@pytest.mark.unit
def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
    assert prompt == "This is a test prompt."
    assert not caplog.records
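Both tests carry the @pytest.mark.unit marker, so they can be selected on their own with pytest's marker filter (for example, pytest -m unit run from the repository root); the mock_tokenizer fixture they rely on is provided by the conftest.py added below.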
test/preview/conftest.py (Normal file, 13 lines)
@@ -0,0 +1,13 @@
from unittest.mock import Mock
import pytest


@pytest.fixture()
def mock_tokenizer():
    """
    Tokenizes the string by splitting on spaces.
    """
    tokenizer = Mock()
    tokenizer.encode = lambda text: text.split()
    tokenizer.decode = lambda tokens: " ".join(tokens)
    return tokenizer
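To see why the truncation test above expects "This is a", here is a small sketch (not part of the commit) that replays the fixture's behaviour: the mock splits on whitespace, so the five-word prompt encodes to five tokens, and decoding the first three joins them back with spaces.

from unittest.mock import Mock

# Same whitespace-splitting stand-in as the mock_tokenizer fixture above.
tokenizer = Mock()
tokenizer.encode = lambda text: text.split()
tokenizer.decode = lambda tokens: " ".join(tokens)

tokens = tokenizer.encode("This is a test prompt.")
print(len(tokens))                   # 5
print(tokenizer.decode(tokens[:3]))  # This is a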