diff --git a/pdelfin/silver_data/convertsilver_birr.py b/pdelfin/silver_data/convertsilver_birr.py index 170508e..70cc50c 100644 --- a/pdelfin/silver_data/convertsilver_birr.py +++ b/pdelfin/silver_data/convertsilver_birr.py @@ -10,6 +10,9 @@ import smart_open from pdelfin.prompts import build_finetuning_prompt +# Import Plotly for plotting +import plotly.express as px + def setup_logging(): """Configure logging for the script.""" @@ -57,9 +60,11 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool): Args: input_file (str): Path or URL to the input JSONL file. output_file (str): Path or URL to the output JSONL file. + rewrite_prompt_str (bool): Flag to rewrite the prompt string. """ processed_count = 0 error_count = 0 + prompt_lengths = [] try: with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \ @@ -89,15 +94,21 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool): transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text) if transformed is not None: + prompt_text = transformed["chat_messages"][0]["content"][0]["text"] + prompt_length = len(prompt_text) + prompt_lengths.append(prompt_length) + outfile.write(json.dumps(transformed) + '\n') processed_count += 1 else: error_count += 1 logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.") + return prompt_lengths except Exception as e: logging.exception(e) logging.error(f"Failed to process file {input_file}: {e}") + return [] def construct_output_file_path(input_file_path, input_dir, output_dir): @@ -230,6 +241,7 @@ def main(): tasks.append((input_file, output_file)) # Process files in parallel + all_prompt_lengths = [] with ProcessPoolExecutor(max_workers=max_jobs) as executor: future_to_file = { executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file @@ -239,12 +251,29 @@ def main(): for future in as_completed(future_to_file): input_file = future_to_file[future] try: - future.result() + prompt_lengths = future.result() + all_prompt_lengths.extend(prompt_lengths) except Exception as exc: logging.error(f"File {input_file} generated an exception: {exc}") logging.info("All files have been processed.") + # Plot histogram of prompt lengths + if all_prompt_lengths: + fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths") + fig.update_xaxes(title="Prompt Length") + fig.update_yaxes(title="Frequency") + try: + fig.write_image("prompt_lengths_histogram.png") + logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.") + except Exception as e: + logging.error(f"Failed to save the histogram image: {e}") + logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).") + fig.write_html("prompt_lengths_histogram.html") + logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.") + else: + logging.warning("No prompt lengths were collected; histogram will not be generated.") + if __name__ == "__main__": main()