mirror of https://github.com/allenai/olmocr.git, synced 2026-01-07 12:51:39 +00:00
typos
parent ea3af0143c
commit 5916239cd8
@@ -9,7 +9,7 @@ from typing import Any, Dict, Optional
 from logging import Logger
 
 import boto3
-from datasets import Dataset, Features, Value, load_dataset
+from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
 
 from .core.config import DataConfig, SourceConfig
 
@@ -167,7 +167,7 @@ def make_dataset(
     logger = logger or get_logger(__name__)
     random.seed(train_data_config.seed)
 
-    dataset_splits: Dict[str, datasets.Dataset] = {}
+    dataset_splits: Dict[str, Dataset] = {}
     tmp_train_sets = []
 
     logger.info("Loading training data from %s sources", len(train_data_config.sources))
@@ -175,7 +175,7 @@ def make_dataset(
         tmp_train_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["train"] = datasets.concatenate_datasets(tmp_train_sets)
+    dataset_splits["train"] = concatenate_datasets(tmp_train_sets)
     logger.info(
         f"Loaded {len(dataset_splits['train'])} training samples from {len(train_data_config.sources)} sources"
     )
@@ -187,7 +187,7 @@ def make_dataset(
         tmp_validation_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["validation"] = datasets.concatenate_datasets(tmp_validation_sets)
+    dataset_splits["validation"] = concatenate_datasets(tmp_validation_sets)
     logger.info(
         f"Loaded {len(dataset_splits['validation'])} validation samples from {len(valid_data_config.sources)} sources"
     )
@@ -199,9 +199,9 @@ def make_dataset(
         tmp_test_sets.append(
             build_batch_query_response_vision_dataset(source.query_glob_path, source.response_glob_path)
         )
-    dataset_splits["test"] = datasets.concatenate_datasets(tmp_test_sets)
+    dataset_splits["test"] = concatenate_datasets(tmp_test_sets)
     logger.info(
         f"Loaded {len(dataset_splits['test'])} test samples from {len(test_data_config.sources)} sources"
     )
 
-    return datasets.DatasetDict(**dataset_splits)
+    return DatasetDict(**dataset_splits)
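
Aside: the pattern above builds one Dataset per source, stacks them with Hugging Face's concatenate_datasets (all pieces must share the same features), and returns the named splits as a DatasetDict; the commit simply switches from module-qualified datasets.* calls to the directly imported names. A minimal sketch of the same shape, using toy in-memory data rather than olmocr's real query/response loaders:

from datasets import Dataset, DatasetDict, concatenate_datasets

# Toy stand-ins for the per-source datasets (hypothetical data, not the repo's).
source_a = Dataset.from_dict({"query": ["q1", "q2"], "response": ["r1", "r2"]})
source_b = Dataset.from_dict({"query": ["q3"], "response": ["r3"]})

# Same shape as make_dataset: concatenate per-source sets, then group named splits.
splits = DatasetDict(train=concatenate_datasets([source_a, source_b]))
print(splits["train"].num_rows)  # 3
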
@@ -149,8 +149,8 @@ def run_train(config: TrainConfig):
     model = get_peft_model(model=model, peft_config=peft_config)
     log_trainable_parameters(model=model, logger=logger)
 
-    train_ds = train_ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
-    print(train_ds)
+    formatted_dataset = dataset.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor))
+    print(formatted_dataset)
     print("---------------")
 
     save_path = join_path("", config.save.path, run_name.run)
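
For reference, with_transform (used on both the old and new lines) registers a formatting function that the datasets library applies lazily, batch by batch, whenever rows are accessed, rather than preprocessing the whole dataset up front; the change itself is just the variable rename. A rough sketch of that mechanism, with a stand-in transform in place of batch_prepare_data_for_qwen2_training:

from datasets import Dataset

def upper_batch(batch):
    # Stand-in for the processor-based formatting; runs at access time.
    return {"text": [t.upper() for t in batch["text"]]}

ds = Dataset.from_dict({"text": ["a", "b"]})
formatted = ds.with_transform(upper_batch)
print(formatted[0])  # {'text': 'A'}, transformed on access
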
@@ -202,8 +202,8 @@ def run_train(config: TrainConfig):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=train_ds,
-        #eval_dataset=formatted_dataset["validation"], # pyright: ignore
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["validation"], # pyright: ignore
         tokenizer=processor.tokenizer,
         #data_collator=collator,
         #callbacks=[checkpoint_callback],
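
Note that the newly enabled eval_dataset only takes effect if evaluation is scheduled in TrainingArguments. A hedged sketch of the relevant fields (argument names are from the transformers API; the values are illustrative, not this repo's actual config):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",               # illustrative path
    evaluation_strategy="steps",    # "no" would disable eval_dataset entirely
    eval_steps=500,                 # evaluate every 500 optimizer steps
    per_device_eval_batch_size=8,
)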