Adding prompt length histogram to a script

This commit is contained in:
Jake Poznanski 2024-10-08 18:22:56 +00:00
parent adc702c918
commit 57d9a21eeb

View File

@@ -10,6 +10,9 @@ import smart_open
from pdelfin.prompts import build_finetuning_prompt
# Import Plotly for plotting
import plotly.express as px
def setup_logging():
"""Configure logging for the script."""
@@ -57,9 +60,11 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
Args:
input_file (str): Path or URL to the input JSONL file.
output_file (str): Path or URL to the output JSONL file.
rewrite_prompt_str (bool): Flag to rewrite the prompt string.
"""
processed_count = 0
error_count = 0
prompt_lengths = []
try:
with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
@@ -89,15 +94,21 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
if transformed is not None:
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
prompt_length = len(prompt_text)
prompt_lengths.append(prompt_length)
outfile.write(json.dumps(transformed) + '\n')
processed_count += 1
else:
error_count += 1
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
return prompt_lengths
except Exception as e:
logging.exception(e)
logging.error(f"Failed to process file {input_file}: {e}")
return []
def construct_output_file_path(input_file_path, input_dir, output_dir):
@@ -230,6 +241,7 @@ def main():
tasks.append((input_file, output_file))
# Process files in parallel
all_prompt_lengths = []
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
@@ -239,12 +251,29 @@ def main():
for future in as_completed(future_to_file):
input_file = future_to_file[future]
try:
future.result()
prompt_lengths = future.result()
all_prompt_lengths.extend(prompt_lengths)
except Exception as exc:
logging.error(f"File {input_file} generated an exception: {exc}")
logging.info("All files have been processed.")
# Plot histogram of prompt lengths
if all_prompt_lengths:
fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
fig.update_xaxes(title="Prompt Length")
fig.update_yaxes(title="Frequency")
try:
fig.write_image("prompt_lengths_histogram.png")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
except Exception as e:
logging.error(f"Failed to save the histogram image: {e}")
logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
fig.write_html("prompt_lengths_histogram.html")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
else:
logging.warning("No prompt lengths were collected; histogram will not be generated.")
if __name__ == "__main__":
main()