Config work

This commit is contained in:
Jake Poznanski 2024-10-16 18:37:52 +00:00
parent 3c1b7de293
commit d4f64ed82a
2 changed files with 15 additions and 49 deletions

View File

@@ -7,43 +7,33 @@ wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
max_length: 8192
train_data:
seed: 1337
sources:
# These tend to be really big, so it's only practical to host them as parquets on weka, otherwise you may OOM or just never finish dataloading
- name: openai_batch_data_v5_1_train
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_parquet/*.parquet
- name: openai_batch_data_v5_1_train
parquet_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_parquet/*.parquet
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
valid_data:
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so you can load from s3 it's no big deal
- name: openai_batch_data_v5_1_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_eval/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000
- name: openai_batch_data_v5_1_iabooks_eval
query_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_eval/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: 1024
target_anchor_text_len: 6000

View File

@@ -22,28 +22,6 @@ class ModelConfig:
model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)
@dataclass
class FormatConfig:
"""Configuration for formatting the text that is input to the model."""
new_line_symbol: str = field(
help="The symbol to use for new lines in the text; default is '\\n'.",
default="\n",
)
system_message: Optional[str] = field(
help="The system message to use for formatting the text; default is no system message.",
default=None,
)
instruction_template: str = field(
help="The template to use for formatting the input text", default="Original:"
)
response_template: str = field(help="The template to use for formatting the output text", default="Rewrite:")
chat_template: Optional[str] = field(
help="The template to use for formatting the chat text. If None, the default chat template will be used.",
default=None,
)
@dataclass
class GenerateConfig:
max_length: int = field(help="The maximum length of the generated text", default=4096)
@@ -75,9 +53,9 @@ class AwsConfig:
@dataclass
class SourceConfig:
name: str = field(help="The name of the source")
parquet_path: Optional[str] = field(help="The s3/glob path to a bunch of parquet files for a preprocessed dataset.", default=None)
query_glob_path: Optional[str] = field(help="The s3 bucket pointing to the inputs sent to OpenAI to generate the silver data", default=None)
response_glob_path: Optional[str] = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai", default=None)
response_glob_path: str = field(help="The s3 bucket pointing to the batch api response json's sent back from open ai")
target_longest_image_dim: int = field(help="Dimensions to render the pdf page image to")
target_anchor_text_len: int = field(help="Maximum amount of anchor text (aka prompt hint)")
@dataclass
@@ -141,7 +119,6 @@ class TrainConfig:
lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
format: FormatConfig = field(default=FormatConfig(), help="Configuration for formatting the input/output text")
train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
@@ -158,5 +135,4 @@ class DemoConfig:
share: bool = field(default=False, help="Share the demo publicly.")
model: ModelConfig = field(default=ModelConfig())
format: FormatConfig = field(default=FormatConfig())
generate: GenerateConfig = field(default=GenerateConfig())