List configs to list

This commit is contained in:
Jake Poznanski 2024-10-24 03:07:32 +00:00
parent ffe470bf0e
commit f13d0a5741
2 changed files with 5 additions and 5 deletions

View File

@ -93,7 +93,7 @@ def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim
}
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Process each example in the batch using the helper function
processed_examples = []
for i in range(len(batch["response"])):

View File

@ -74,8 +74,8 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
partial(
batch_prepare_data_for_qwen2_training,
processor=processor,
target_longest_image_dim=target_longest_image_dim,
target_anchor_text_len=target_anchor_text_len,
target_longest_image_dim=list(target_longest_image_dim),
target_anchor_text_len=list(target_anchor_text_len),
)
)
@ -86,8 +86,8 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
partial(
batch_prepare_data_for_qwen2_training,
processor=processor,
target_longest_image_dim=source.target_longest_image_dim,
target_anchor_text_len=source.target_anchor_text_len,
target_longest_image_dim=list(source.target_longest_image_dim),
target_anchor_text_len=list(source.target_anchor_text_len),
)
)
for source in config.valid_data.sources