feat: generators (2.0) (#5690)

* add generators module

* add tests for module helper

* reno

* add another test

* move into openai

* improve tests
ZanSara 2023-08-31 17:33:12 +02:00 committed by GitHub
parent 6787ad2435
commit 5f1256ac7e
6 changed files with 68 additions and 0 deletions

haystack/preview/components/generators/openai/_helpers.py

@@ -0,0 +1,33 @@
import logging

from haystack.preview.lazy_imports import LazyImport

with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
    import tiktoken

logger = logging.getLogger(__name__)


def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
    """
    Ensure that the length of the prompt is within the max tokens limit of the model.
    If needed, truncate the prompt text so that it fits within the limit.

    :param prompt: Prompt text to be sent to the generative model.
    :param tokenizer: The tokenizer used to encode the prompt.
    :param max_tokens_limit: The max tokens limit of the model.
    :return: The prompt text that fits within the max tokens limit of the model.
    """
    tiktoken_import.check()
    tokens = tokenizer.encode(prompt)
    tokens_count = len(tokens)
    if tokens_count > max_tokens_limit:
        logger.warning(
            "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
            "Reduce the length of the prompt to prevent it from being cut off.",
            tokens_count,
            max_tokens_limit,
        )
        prompt = tokenizer.decode(tokens[:max_tokens_limit])
    return prompt
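
For context, a minimal usage sketch with a real tiktoken encoding (illustrative only, not part of this commit; the encoding name and the 2048-token limit are arbitrary example values):

import tiktoken

from haystack.preview.components.generators.openai._helpers import enforce_token_limit

encoding = tiktoken.get_encoding("cl100k_base")  # encoding used by recent OpenAI chat models
long_prompt = "word " * 5000  # encodes to far more than 2048 tokens
safe_prompt = enforce_token_limit(long_prompt, tokenizer=encoding, max_tokens_limit=2048)
# safe_prompt now encodes to at most 2048 tokens; a truncation warning was logged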

@@ -0,0 +1,2 @@
preview:
  - Add generators module for LLM generator components.

@@ -0,0 +1,20 @@
import pytest

from haystack.preview.components.generators.openai._helpers import enforce_token_limit


@pytest.mark.unit
def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
    assert prompt == "This is a"
    assert caplog.records[0].message == (
        "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
        "limit. Reduce the length of the prompt to prevent it from being cut off."
    )


@pytest.mark.unit
def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
    assert prompt == "This is a test prompt."
    assert not caplog.records
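
Both tests rely on the mock_tokenizer fixture defined in test/preview/conftest.py (shown below), which fakes tokenization by splitting the prompt on whitespace.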

test/preview/conftest.py

@@ -0,0 +1,13 @@
from unittest.mock import Mock

import pytest


@pytest.fixture()
def mock_tokenizer():
    """
    Tokenizes the string by splitting on spaces.
    """
    tokenizer = Mock()
    tokenizer.encode = lambda text: text.split()
    tokenizer.decode = lambda tokens: " ".join(tokens)
    return tokenizer
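
To make the truncation test above concrete (an illustration, not part of the commit): with this fixture, encoding splits the prompt into five whitespace-delimited tokens, so truncating to three and decoding returns the first three words.

tokens = "This is a test prompt.".split()  # ['This', 'is', 'a', 'test', 'prompt.']
print(" ".join(tokens[:3]))                # This is a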