Birr tokenization test

This commit is contained in:
Jake Poznanski 2024-10-18 23:02:37 +00:00
parent 77f0b9fa84
commit 9d35d3ca8f

View File

@ -128,4 +128,47 @@ class TestBirrTokenization(unittest.TestCase):
print(end_token_len - start_token_len + 1, " max image tokens")
print(compute_number_of_image_tokens(1024, 1024))
def testBirrChatTemplate(self):
import yaml
import json
import os
from birr.tokenization import ModelTokenizer
from birr.core.config import FormatConfig, LLMModelConfig
from birr.batch_inference.data_models import RawInputItem
from pdelfin.birrpipeline import build_page_query
TEST_INSTANCES = [json.dumps(build_page_query(os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"edgar.pdf"
), "test.pdf", 1, 1024, 4096)),
]
TEST_INSTANCES = [
RawInputItem.from_message_dicts(i, json.loads(json_string)["chat_messages"]).messages
for i, json_string in enumerate(TEST_INSTANCES)
]
MODEL_NAME = "Qwen/Qwen2-0.5B"
CONFIG_FILE = "/home/ubuntu/mise/birr/configs/inference/qwen2-vl-test.yaml"
with open(CONFIG_FILE, "r") as f:
file_contents = f.read()
config_dict = yaml.safe_load(file_contents)
model_config = LLMModelConfig(name_or_path=MODEL_NAME)
format_config = FormatConfig(**config_dict.get("format", {}))
tokenizer = ModelTokenizer(model_config, format_config)
formatted = tokenizer.batch_format(TEST_INSTANCES)
print(formatted)
for formatted_instance in formatted:
print(f"========================================\n{formatted_instance}\n========================================")