diff --git a/tests/test_birr_outputs.py b/tests/test_birr_outputs.py
index 96f8b26..bf720ea 100644
--- a/tests/test_birr_outputs.py
+++ b/tests/test_birr_outputs.py
@@ -128,4 +128,51 @@ class TestBirrTokenization(unittest.TestCase):
 
         print(end_token_len - start_token_len + 1, " max image tokens")
         print(compute_number_of_image_tokens(1024, 1024))
-
\ No newline at end of file
+
+    def testBirrChatTemplate(self):
+        import json
+        import os
+
+        import yaml
+
+        from birr.batch_inference.data_models import RawInputItem
+        from birr.core.config import FormatConfig, LLMModelConfig
+        from birr.tokenization import ModelTokenizer
+
+        from pdelfin.birrpipeline import build_page_query
+
+        # Build a page query for a checked-in test PDF and serialize it the
+        # same way the batch-inference pipeline would.
+        TEST_INSTANCES = [
+            json.dumps(build_page_query(
+                os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
+                "test.pdf", 1, 1024, 4096,
+            )),
+        ]
+
+        # Convert each serialized query into birr's message objects.
+        TEST_INSTANCES = [
+            RawInputItem.from_message_dicts(i, json.loads(json_string)["chat_messages"]).messages
+            for i, json_string in enumerate(TEST_INSTANCES)
+        ]
+
+        MODEL_NAME = "Qwen/Qwen2-0.5B"
+
+        # NOTE: machine-specific path; this test only runs where the birr
+        # inference config is checked out at this location.
+        CONFIG_FILE = "/home/ubuntu/mise/birr/configs/inference/qwen2-vl-test.yaml"
+
+        with open(CONFIG_FILE, "r") as f:
+            config_dict = yaml.safe_load(f.read())
+
+        model_config = LLMModelConfig(name_or_path=MODEL_NAME)
+        format_config = FormatConfig(**config_dict.get("format", {}))
+
+        tokenizer = ModelTokenizer(model_config, format_config)
+
+        # Apply the chat template to each instance and print the results for
+        # manual inspection.
+        formatted = tokenizer.batch_format(TEST_INSTANCES)
+
+        for formatted_instance in formatted:
+            print(f"========================================\n{formatted_instance}\n========================================")
\ No newline at end of file