autogen/test/agentchat/contrib/capabilities/test_transforms_util.py
Wael Karkoub 8564bd4c48
[Refactor] Transforms Utils (#2863)
* wip

* tests + docstrings

* improves tests

* fix import
2024-06-06 21:49:22 +00:00

73 lines
3.1 KiB
Python

import itertools
import tempfile
from typing import Dict, Tuple
import pytest
from autogen.agentchat.contrib.capabilities import transforms_util
from autogen.cache.cache import Cache
from autogen.types import MessageContentType
# Fixture messages shared by every test in this module. Each entry pairs a
# ``content`` payload with ``text_tokens``, the text-token count the tests
# below expect ``transforms_util`` to report for that payload.
MESSAGES = {
    # Multimodal list: one text part followed by one image part.
    "message1": {
        "content": [
            {"text": "Hello"},
            {"image_url": {"url": "https://example.com/image.jpg"}},
        ],
        "text_tokens": 1,
    },
    # Multimodal list holding only an image part (no text at all).
    "message2": {
        "content": [
            {"image_url": {"url": "https://example.com/image.jpg"}},
        ],
        "text_tokens": 0,
    },
    # Multimodal list with two text parts.
    "message3": {
        "content": [
            {"text": "Hello"},
            {"text": "World"},
        ],
        "text_tokens": 2,
    },
    # Missing content.
    "message4": {
        "content": None,
        "text_tokens": 0,
    },
    # Plain string content.
    "message5": {
        "content": "Hello there!",
        "text_tokens": 3,
    },
    # List of plain strings.
    "message6": {
        "content": [
            "Hello there!",
            "Hello there!",
        ],
        "text_tokens": 6,
    },
}
@pytest.mark.parametrize("message", MESSAGES.values())
def test_cache_content(message: Dict[str, MessageContentType]) -> None:
    """Round-trip message content through ``cache_content_set``/``cache_content_get``.

    For each fixture message this checks three scenarios:

    * content stored on its own is read back as a 1-tuple;
    * content stored with extra positional values is read back as a tuple in
      the same order, with the extras keeping their original types;
    * when ``None`` is passed as the cache, nothing is stored and reads
      return ``None``.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Pass the directory by keyword: per the AutoGen API the first
        # positional parameter of ``Cache.disk`` is the cache seed, not the
        # path, so a positional temp dir only works by accident of path joining.
        cache = Cache.disk(cache_path_root=tmpdirname)
        try:
            # Scenario 1: a single content value.
            cache_key_1 = "test_string"
            transforms_util.cache_content_set(cache, cache_key_1, message["content"])
            assert transforms_util.cache_content_get(cache, cache_key_1) == (message["content"],)

            # Scenario 2: content plus heterogeneous extra values.
            cache_key_2 = "test_list"
            cache_value_2 = [message["content"], 1, "some_string", {"new_key": "new_value"}]
            transforms_util.cache_content_set(cache, cache_key_2, *cache_value_2)
            cached_value_2 = transforms_util.cache_content_get(cache, cache_key_2)
            assert cached_value_2 == tuple(cache_value_2)
            # Check the types of what came *out* of the cache; asserting on the
            # local input list would pass regardless of what was stored.
            assert isinstance(cached_value_2[1], int)
            assert isinstance(cached_value_2[2], str)
            assert isinstance(cached_value_2[3], dict)

            # Scenario 3: a ``None`` cache must be a no-op for both set and get.
            cache_key_3 = "test_None"
            transforms_util.cache_content_set(None, cache_key_3, message["content"])
            assert transforms_util.cache_content_get(cache, cache_key_3) is None
            assert transforms_util.cache_content_get(None, cache_key_3) is None
        finally:
            # Release the on-disk cache before the temporary directory is
            # removed so no open handle outlives the files it points at.
            cache.close()
@pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), repeat=2))
def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None:
    """Check that cache keys agree exactly when the two fixture messages are equal.

    Runs over every ordered pair of fixture messages (including a message
    paired with itself) and derives a key for each message's content with the
    same second argument (``10``).
    """
    first, second = messages
    first_key = transforms_util.cache_key(first["content"], 10)
    second_key = transforms_util.cache_key(second["content"], 10)

    # Equal messages must share a key; different messages must not.
    keys_match = first_key == second_key
    messages_match = first == second
    assert keys_match == messages_match
@pytest.mark.parametrize("message", MESSAGES.values())
def test_min_tokens_reached(message: Dict[str, MessageContentType]):
    """Check ``min_tokens_reached`` against a single-message conversation.

    A threshold of ``None`` or ``0`` must be reported as reached, while a
    threshold one above the fixture's expected text-token count must not.
    """
    conversation = [message]
    just_out_of_reach = message["text_tokens"] + 1

    assert transforms_util.min_tokens_reached(conversation, None)
    assert transforms_util.min_tokens_reached(conversation, 0)
    assert not transforms_util.min_tokens_reached(conversation, just_out_of_reach)
@pytest.mark.parametrize("message", MESSAGES.values())
def test_count_text_tokens(message: Dict[str, MessageContentType]):
    """Check that ``count_text_tokens`` matches the fixture's expected count."""
    expected_tokens = message["text_tokens"]
    counted_tokens = transforms_util.count_text_tokens(message["content"])
    assert counted_tokens == expected_tokens
@pytest.mark.parametrize("message", MESSAGES.values())
def test_is_content_text_empty(message: Dict[str, MessageContentType]):
    """Check that content is reported text-empty exactly when no text tokens are expected."""
    expect_empty = message["text_tokens"] == 0
    reported_empty = transforms_util.is_content_text_empty(message["content"])
    assert reported_empty == expect_empty