Jake Poznanski 2024-09-23 09:43:36 -07:00
parent ea3af0143c
commit 5916239cd8
2 changed files with 10 additions and 10 deletions

File 1 of 2

@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
 from logging import Logger
 import boto3
-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 from .core.config import DataConfig, SourceConfig
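
The import change swaps module-qualified calls for direct imports; both names are top-level exports of the `datasets` library, so the two styles resolve to the same objects. A minimal sketch (assuming `datasets` is installed):

```python
# Both import styles resolve to the same objects; the direct form just
# drops the module prefix at the call sites changed below.
import datasets
from datasets import DatasetDict, concatenate_datasets

assert datasets.concatenate_datasets is concatenate_datasets
assert datasets.DatasetDict is DatasetDict
```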
@@ -167,7 +167,7 @@ def make_dataset(
     logger = logger or get_logger(__name__)
     random.seed(train_data_config.seed)
-    dataset_splits: Dict[str, datasets.Dataset] = {}
+    dataset_splits: Dict[str, Dataset] = {}
     tmp_train_sets = []
     logger.info("Loading training data from %s sources", len(train_data_config.sources))
@@ -175,7 +175,7 @@ def make_dataset(
         tmp_train_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
     logger.info(
         f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
     )
@@ -187,7 +187,7 @@ def make_dataset(
         tmp_validation_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+    dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
     logger.info(
         f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
     )
@@ -199,9 +199,9 @@ def make_dataset(
         tmp_test_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+    dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
     logger.info(
         f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
     )
-    return datasets.DatasetDict(**dataset_splits)
+    return DatasetDict(**dataset_splits)
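
For context, a hedged, self-contained sketch of the pattern these hunks finish: per-source datasets are merged with `concatenate_datasets`, then the named splits are wrapped in a `DatasetDict`. The toy rows and the `text` column are hypothetical; the real per-source datasets come from `build_batch_query_response_vision_dataset`.

```python
from datasets import Dataset, DatasetDict, concatenate_datasets

# Stand-ins for the per-source datasets built in the loops above.
tmp_train_sets = [
    Dataset.from_dict({"text": ["a", "b"]}),
    Dataset.from_dict({"text": ["c"]}),
]
splits = {
    "train": concatenate_datasets(tmp_train_sets),
    "validation": Dataset.from_dict({"text": ["d"]}),
}
ds = DatasetDict(**splits)
print(len(ds["train"]))  # 3: the two sources merged end to end
```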

File 2 of 2

@@ -149,8 +149,8 @@ def run_train(config: TrainConfig):
         model = get_peft_model(model=model, peft_config=peft_config)
         log_trainable_parameters(model=model, logger=logger)
-    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(train_ds)
+    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(formatted_dataset)
     print("---------------")
     save_path = join_path("", config.save.path, run_name.run)
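
`with_transform` registers the preprocessing function without materializing anything: the transform runs lazily on each access, which is why the formatted dataset can be printed cheaply here. A toy sketch of the semantics (the real transform is `batch_prepare_data_for_qwen2_training` bound with `functools.partial`; `add_prefix` below is hypothetical):

```python
from functools import partial
from datasets import Dataset

def add_prefix(batch, prefix):
    # with_transform hands the function a batch dict and expects one back.
    return {"text": [prefix + t for t in batch["text"]]}

ds = Dataset.from_dict({"text": ["x", "y"]})
formatted = ds.with_transform(partial(add_prefix, prefix=">> "))
print(formatted[0])  # transform runs on access: {'text': '>> x'}
```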
@@ -202,8 +202,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_ds,
-        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["validation"], # pyright: ignore
         tokenizer=processor.tokenizer,
         #data_collator=collator,
         #callbacks=[checkpoint_callback],
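
With the validation split now wired in, `Trainer` can evaluate periodically during training. A hedged sketch of the relevant wiring: the `TrainingArguments` values are hypothetical, and `model`, `formatted_dataset`, and `processor` are assumed from the surrounding code rather than defined here.

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",       # hypothetical path
    eval_strategy="steps",  # `evaluation_strategy` on older transformers
    eval_steps=500,
)
trainer = Trainer(
    model=model,                                   # assumed defined earlier
    args=training_args,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["validation"],
    tokenizer=processor.tokenizer,
)
trainer.train()  # eval_dataset is evaluated every eval_steps steps
```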