mirror of
https://github.com/allenai/olmocr.git
synced 2025-09-25 16:30:28 +00:00
Adding prompt length histogram to a script
This commit is contained in:
parent
adc702c918
commit
57d9a21eeb
@ -10,6 +10,9 @@ import smart_open
|
|||||||
|
|
||||||
from pdelfin.prompts import build_finetuning_prompt
|
from pdelfin.prompts import build_finetuning_prompt
|
||||||
|
|
||||||
|
# Import Plotly for plotting
|
||||||
|
import plotly.express as px
|
||||||
|
|
||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
"""Configure logging for the script."""
|
"""Configure logging for the script."""
|
||||||
@ -57,9 +60,11 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
|
|||||||
Args:
|
Args:
|
||||||
input_file (str): Path or URL to the input JSONL file.
|
input_file (str): Path or URL to the input JSONL file.
|
||||||
output_file (str): Path or URL to the output JSONL file.
|
output_file (str): Path or URL to the output JSONL file.
|
||||||
|
rewrite_prompt_str (bool): Flag to rewrite the prompt string.
|
||||||
"""
|
"""
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
|
prompt_lengths = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
|
with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
|
||||||
@ -89,15 +94,21 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
|
|||||||
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
|
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
|
||||||
|
|
||||||
if transformed is not None:
|
if transformed is not None:
|
||||||
|
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
|
||||||
|
prompt_length = len(prompt_text)
|
||||||
|
prompt_lengths.append(prompt_length)
|
||||||
|
|
||||||
outfile.write(json.dumps(transformed) + '\n')
|
outfile.write(json.dumps(transformed) + '\n')
|
||||||
processed_count += 1
|
processed_count += 1
|
||||||
else:
|
else:
|
||||||
error_count += 1
|
error_count += 1
|
||||||
|
|
||||||
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
|
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
|
||||||
|
return prompt_lengths
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
logging.error(f"Failed to process file {input_file}: {e}")
|
logging.error(f"Failed to process file {input_file}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def construct_output_file_path(input_file_path, input_dir, output_dir):
|
def construct_output_file_path(input_file_path, input_dir, output_dir):
|
||||||
@ -230,6 +241,7 @@ def main():
|
|||||||
tasks.append((input_file, output_file))
|
tasks.append((input_file, output_file))
|
||||||
|
|
||||||
# Process files in parallel
|
# Process files in parallel
|
||||||
|
all_prompt_lengths = []
|
||||||
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
|
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
|
||||||
future_to_file = {
|
future_to_file = {
|
||||||
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
|
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
|
||||||
@ -239,12 +251,29 @@ def main():
|
|||||||
for future in as_completed(future_to_file):
|
for future in as_completed(future_to_file):
|
||||||
input_file = future_to_file[future]
|
input_file = future_to_file[future]
|
||||||
try:
|
try:
|
||||||
future.result()
|
prompt_lengths = future.result()
|
||||||
|
all_prompt_lengths.extend(prompt_lengths)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logging.error(f"File {input_file} generated an exception: {exc}")
|
logging.error(f"File {input_file} generated an exception: {exc}")
|
||||||
|
|
||||||
logging.info("All files have been processed.")
|
logging.info("All files have been processed.")
|
||||||
|
|
||||||
|
# Plot histogram of prompt lengths
|
||||||
|
if all_prompt_lengths:
|
||||||
|
fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
|
||||||
|
fig.update_xaxes(title="Prompt Length")
|
||||||
|
fig.update_yaxes(title="Frequency")
|
||||||
|
try:
|
||||||
|
fig.write_image("prompt_lengths_histogram.png")
|
||||||
|
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to save the histogram image: {e}")
|
||||||
|
logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
|
||||||
|
fig.write_html("prompt_lengths_histogram.html")
|
||||||
|
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
|
||||||
|
else:
|
||||||
|
logging.warning("No prompt lengths were collected; histogram will not be generated.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user