diff --git a/scripts/pii_rule_comparison.py b/scripts/pii_rule_comparison.py
index 0716ba8..bd11c9e 100644
--- a/scripts/pii_rule_comparison.py
+++ b/scripts/pii_rule_comparison.py
@@ -2,22 +2,46 @@
 """
 Compare PII Detection Rules and Calculate IoU
 
-This script processes JSONL attribute files from two different folders,
-applies different rules to each for PII detection, and calculates the
+This script processes documents and their attributes from S3 or local storage,
+applies different rules for PII detection, and calculates the
 Intersection over Union (IoU) to measure how well they overlap.
 
+How it works:
+1. Documents are stored in one location (--docs-folder)
+2. Attributes are stored in separate folders (--ref-attr-folder and --hyp-attr-folder)
+3. The script merges documents with their attributes by matching filenames and document IDs
+4. PII detection rules are applied to each set of merged documents
+5. IoU and other metrics are calculated to compare the results
+
+Folder structure:
+- s3://bucket/path/documents/ - Contains the main document JSONL files
+- s3://bucket/path/attributes/attr_name/ - Contains attributes that can be matched with documents by ID
+
+Document and attribute matching:
+- Files are matched by basename (example.jsonl in documents matches example.jsonl in attributes)
+- Within each file, documents are matched by their "id" field
+- When a match is found, attributes from the attribute file are merged into the document
+
 Example usage:
 python pii_rule_comparison.py \
-    --ref-folder s3://bucket/workspace/attributes/model_a \
-    --hyp-folder s3://bucket/workspace/attributes/model_b \
+    --docs-folder s3://bucket/path/documents \
+    --ref-attr-folder s3://bucket/path/attributes/model_a \
+    --hyp-attr-folder s3://bucket/path/attributes/model_b \
     --ref-rule "gpt_4_1_contains_pii:any" \
     --hyp-rule "gpt_4_1_contains_email_addresses:any" \
-    --output-file iou_results.json
+    --output-file iou_results.json \
+    --recursive
+
+Rule expression syntax:
+- Simple rule: "attribute_name:rule_type" where rule_type is "any" or "all"
+- Boolean expressions: "not rule1:any and rule2:all"
+- Parentheses for grouping: "(rule1:any or rule2:any) and not rule3:all"
 """
 
 import argparse
 import boto3
 import gzip
+import io
 import json
 import logging
 import os
@@ -90,8 +114,9 @@ class BinaryNode(ExpressionNode):
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Compare PII detection rules and calculate IoU")
-    parser.add_argument("--ref-folder", required=True, help="Reference attribute folder path (local or s3://)")
-    parser.add_argument("--hyp-folder", required=True, help="Hypothesis attribute folder path (local or s3://)")
+    parser.add_argument("--docs-folder", required=True, help="Documents folder path containing JSONL files (local or s3://)")
+    parser.add_argument("--ref-attr-folder", required=True, help="Reference attributes folder path (local or s3://)")
+    parser.add_argument("--hyp-attr-folder", required=True, help="Hypothesis attributes folder path (local or s3://)")
     parser.add_argument("--ref-rule", required=True, help="""Reference rule expression.
 Can be a simple rule in format 'attribute_name:rule_type', where rule_type is 'any' or 'all'.
 Or a boolean expression like
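
# A minimal, self-contained sketch of how the rule grammar above can be
# evaluated as an expression tree. The class and helper names here are
# simplified stand-ins, not the script's actual ExpressionNode/BinaryNode
# definitions, and eval_simple mirrors (but is not) apply_simple_rule.

def eval_simple(doc, attribute_name, rule_type):
    """Reduce an attribute's span values with any()/all(), as 'name:any'/'name:all' do."""
    spans = doc.get("attributes", {}).get(attribute_name, [])
    values = [s[2] for s in spans if len(s) >= 3 and s[2] is not None]
    if not values:
        return False
    return any(values) if rule_type == "any" else all(values)

class Rule:
    def __init__(self, name, rule_type):
        self.name, self.rule_type = name, rule_type
    def evaluate(self, doc):
        return eval_simple(doc, self.name, self.rule_type)

class Not:
    def __init__(self, child):
        self.child = child
    def evaluate(self, doc):
        return not self.child.evaluate(doc)

class And:
    def __init__(self, left, right):
        self.left, self.right = left, right
    def evaluate(self, doc):
        return self.left.evaluate(doc) and self.right.evaluate(doc)

class Or:
    def __init__(self, left, right):
        self.left, self.right = left, right
    def evaluate(self, doc):
        return self.left.evaluate(doc) or self.right.evaluate(doc)

# "rule1:any and not rule2:all" parses to And(Rule("rule1", "any"), Not(Rule("rule2", "all"))):
doc = {"attributes": {"rule1": [[0, 5, True]], "rule2": [[0, 5, True], [5, 9, False]]}}
expr = And(Rule("rule1", "any"), Not(Rule("rule2", "all")))
print(expr.evaluate(doc))  # True: rule1 has a True span and rule2 is not all-True
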
@@ -103,6 +128,7 @@ def parse_args():
     parser.add_argument("--output-file", default="iou_results.json", help="Output JSON file to save results")
     parser.add_argument("--aws-profile", help="AWS profile for S3 access")
     parser.add_argument("--recursive", action="store_true", help="Recursively process folder structure")
+    parser.add_argument("--debug", action="store_true", help="Enable debug logging for more detailed output")
 
     return parser.parse_args()
 
 def parse_s3_path(s3_path):
@@ -170,8 +196,35 @@ def load_jsonl_file(file_path, s3_client=None):
         if file_path.endswith(".gz"):
             decompressed = gzip.decompress(raw_data)
         elif file_path.endswith((".zst", ".ztd", ".zstd")):
-            dctx = zstd.ZstdDecompressor()
-            decompressed = dctx.decompress(raw_data)
+            try:
+                # First try with standard decompression
+                dctx = zstd.ZstdDecompressor()
+                decompressed = dctx.decompress(raw_data)
+            except zstd.ZstdError as e:
+                # If that fails, try with stream decompression
+                logger.warning(f"Standard zstd decompression failed for {file_path}, trying stream decompression: {e}")
+
+                try:
+                    # Try with content-size not required
+                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)  # Use a large window size
+                    decompressor = dctx.stream_reader(io.BytesIO(raw_data))
+                    decompressed = decompressor.read()
+                except Exception as inner_e:
+                    # If both methods fail, try with chunking
+                    logger.warning(f"Stream decompression also failed, trying chunked reading: {inner_e}")
+
+                    # Chunked reading approach
+                    buffer = io.BytesIO()
+                    dctx = zstd.ZstdDecompressor(max_window_size=2147483648)
+                    with dctx.stream_reader(io.BytesIO(raw_data)) as reader:
+                        while True:
+                            chunk = reader.read(16384)  # Read in 16KB chunks
+                            if not chunk:
+                                break
+                            buffer.write(chunk)
+
+                    buffer.seek(0)
+                    decompressed = buffer.read()
         else:
             decompressed = raw_data
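
# Why the layered zstd fallback above is needed: the one-shot
# ZstdDecompressor.decompress() requires the frame header to carry the
# decompressed size, and frames produced by streaming compressors usually
# omit it. A small self-contained repro, assuming the zstandard package:

import io
import zstandard as zstd

payload = b'{"id": "doc-1"}\n' * 1000

# The streaming compressor does not know the total size up front, so the
# resulting frame has no embedded content size.
cctx = zstd.ZstdCompressor()
cobj = cctx.compressobj()
frame = cobj.compress(payload) + cobj.flush()

dctx = zstd.ZstdDecompressor()
try:
    dctx.decompress(frame)  # one-shot path needs the embedded size
except zstd.ZstdError as e:
    print(f"one-shot decompression failed as expected: {e}")

# stream_reader never needs the content size, which is why it works as a fallback.
with dctx.stream_reader(io.BytesIO(frame)) as reader:
    assert reader.read() == payload
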
@@ -183,6 +236,129 @@
         logger.error(f"Error loading file {file_path}: {e}")
         return []
 
+def load_documents_and_attributes(docs_folder, attr_folder, s3_client=None, recursive=False):
+    """
+    Load documents and merge them with their attributes.
+
+    Args:
+        docs_folder: Path to the documents folder
+        attr_folder: Path to the attributes folder
+        s3_client: S3 client for S3 paths
+        recursive: Whether to process folders recursively
+
+    Returns:
+        List of documents with their attributes merged in
+    """
+    try:
+        # List all document files
+        logger.info(f"Finding document files in: {docs_folder}")
+        doc_files = list_jsonl_files(docs_folder, s3_client, recursive)
+        logger.info(f"Found {len(doc_files)} document files")
+
+        if not doc_files:
+            logger.warning(f"No document files found in {docs_folder}. Check the path and permissions.")
+
+        # List all attribute files
+        logger.info(f"Finding attribute files in: {attr_folder}")
+        attr_files = list_jsonl_files(attr_folder, s3_client, recursive)
+        logger.info(f"Found {len(attr_files)} attribute files")
+
+        if not attr_files:
+            logger.warning(f"No attribute files found in {attr_folder}. Check the path and permissions.")
+
+        # Create a mapping from document filename to attribute filename
+        # based on matching basenames
+        attr_file_map = {}
+
+        for attr_path in attr_files:
+            if attr_path.startswith("s3://"):
+                _, attr_key = parse_s3_path(attr_path)
+                basename = os.path.basename(attr_key)
+            else:
+                basename = os.path.basename(attr_path)
+
+            attr_file_map[basename] = attr_path
+            logger.debug(f"Mapped attribute file basename: {basename} -> {attr_path}")
+
+        # Load and merge documents with their attributes
+        merged_docs = []
+        matched_files = 0
+        docs_with_matched_attrs = 0
+
+        for doc_path in doc_files:
+            try:
+                if doc_path.startswith("s3://"):
+                    _, doc_key = parse_s3_path(doc_path)
+                    basename = os.path.basename(doc_key)
+                else:
+                    basename = os.path.basename(doc_path)
+
+                # Load documents
+                docs = load_jsonl_file(doc_path, s3_client)
+                if not docs:
+                    logger.warning(f"No documents loaded from {basename} (path: {doc_path})")
+                    continue
+
+                logger.info(f"Loaded {len(docs)} documents from {basename}")
+
+                # Find matching attribute file
+                if basename in attr_file_map:
+                    matched_files += 1
+                    attr_path = attr_file_map[basename]
+                    attrs = load_jsonl_file(attr_path, s3_client)
+
+                    if not attrs:
+                        logger.warning(f"No attributes loaded from {os.path.basename(attr_path)} (path: {attr_path})")
+                        merged_docs.extend(docs)
+                        continue
+
+                    logger.info(f"Loaded {len(attrs)} attributes from {os.path.basename(attr_path)}")
+
+                    # Create a mapping from document ID to attributes
+                    attr_by_id = {attr["id"]: attr for attr in attrs if "id" in attr}
+                    logger.info(f"Created mapping for {len(attr_by_id)} attribute entries by ID")
+
+                    # Count documents with matched attributes
+                    docs_matched_in_file = 0
+
+                    # Merge documents with their attributes
+                    for doc in docs:
+                        if "id" in doc and doc["id"] in attr_by_id:
+                            docs_matched_in_file += 1
+
+                            # If document doesn't have attributes field, create it
+                            if "attributes" not in doc:
+                                doc["attributes"] = {}
+
+                            # If attributes document has attributes field, merge them
+                            if "attributes" in attr_by_id[doc["id"]]:
+                                doc["attributes"].update(attr_by_id[doc["id"]]["attributes"])
+                            else:
+                                logger.debug(f"Attribute document {doc['id']} has no 'attributes' field")
+
+                    docs_with_matched_attrs += docs_matched_in_file
+                    logger.info(f"Matched attributes for {docs_matched_in_file}/{len(docs)} documents in {basename}")
+                    merged_docs.extend(docs)
+                else:
+                    logger.warning(f"No matching attribute file found for {basename}")
+                    merged_docs.extend(docs)
+            except Exception as e:
+                logger.error(f"Error processing document file {doc_path}: {e}")
+                continue
+
+        logger.info(f"Total documents processed: {len(merged_docs)}")
+        logger.info(f"Files with matched attributes: {matched_files}/{len(doc_files)}")
+        logger.info(f"Documents with matched attributes: {docs_with_matched_attrs}")
+
+        # Check if no attributes were merged
+        if docs_with_matched_attrs == 0:
+            logger.warning("No documents had matching attributes! This may indicate a problem with file naming or ID matching.")
+
+        return merged_docs
+    except Exception as e:
+        logger.error(f"Error in load_documents_and_attributes: {e}")
+        raise
+
 def apply_rule(doc, rule):
     """
     Apply a rule to determine if a document meets the PII criteria.
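
# The shapes load_documents_and_attributes expects, illustrated with invented
# values: document and attribute records live in files that share a basename,
# join on "id", and attribute values are span lists of [start, end, value].

doc_line = {"id": "doc-0001", "text": "Contact me at alice@example.com"}
attr_line = {
    "id": "doc-0001",  # must equal the document's "id"
    "attributes": {"gpt_4_1_contains_email_addresses": [[14, 31, True]]},
}

# The merge step boils down to this (mirroring the update() call above):
merged = dict(doc_line)
merged.setdefault("attributes", {}).update(attr_line["attributes"])
print(merged["attributes"])  # {'gpt_4_1_contains_email_addresses': [[14, 31, True]]}
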
@@ -214,11 +390,20 @@
     Returns:
         True if the document matches the rule, False otherwise
     """
+    # Check if document has attributes
     if "attributes" not in doc or not doc["attributes"]:
+        logger.debug(f"Document {doc.get('id', 'unknown')} has no attributes")
         return False
 
     attributes = doc["attributes"]
-    if attribute_name not in attributes or not attributes[attribute_name]:
+
+    # Check if the specific attribute exists
+    if attribute_name not in attributes:
+        logger.debug(f"Document {doc.get('id', 'unknown')} doesn't have attribute: {attribute_name}")
+        return False
+
+    if not attributes[attribute_name]:
+        logger.debug(f"Document {doc.get('id', 'unknown')} has empty attribute: {attribute_name}")
         return False
 
     # Extract the boolean values from the attribute spans
@@ -226,12 +411,20 @@
     values = [span[2] for span in attributes[attribute_name] if len(span) >= 3 and span[2] is not None]
 
     if not values:
+        logger.debug(f"Document {doc.get('id', 'unknown')} has no valid values for attribute: {attribute_name}")
         return False
 
+    # Apply the rule
     if rule_type == "any":
-        return any(values)
+        result = any(values)
+        if result:
+            logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (found True in {len(values)} values)")
+        return result
     elif rule_type == "all":
-        return all(values)
+        result = all(values)
+        if result:
+            logger.debug(f"Document {doc.get('id', 'unknown')} matched rule '{attribute_name}:{rule_type}' (all {len(values)} values are True)")
+        return result
     else:
         raise ValueError(f"Unknown rule type: {rule_type}")
 
@@ -567,35 +760,6 @@ def is_complex_expression(rule):
     """Check if the rule is a complex boolean expression."""
     return isinstance(rule, ExpressionNode)
 
-def get_matching_files(ref_files, hyp_files):
-    """
-    Find files that exist in both reference and hypothesis folders,
-    matching by their relative paths.
-
-    Returns dict mapping ref_path -> hyp_path for matched files
-    """
-    # First, convert to relative paths for matching
-    def get_relative_path(path, base_folder):
-        if path.startswith("s3://"):
-            _, full_key = parse_s3_path(path)
-            _, base_key = parse_s3_path(base_folder)
-            return full_key[len(base_key):].lstrip("/") if full_key.startswith(base_key) else full_key
-        else:
-            return os.path.relpath(path, base_folder)
-
-    ref_base = args.ref_folder
-    hyp_base = args.hyp_folder
-
-    ref_relative = {get_relative_path(path, ref_base): path for path in ref_files}
-    hyp_relative = {get_relative_path(path, hyp_base): path for path in hyp_files}
-
-    # Find matching files
-    matched_files = {}
-    for rel_path in ref_relative:
-        if rel_path in hyp_relative:
-            matched_files[ref_relative[rel_path]] = hyp_relative[rel_path]
-
-    return matched_files
 
 def calculate_iou(ref_ids, hyp_ids):
     """Calculate Intersection over Union of two sets of document IDs."""
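
# The set-based IoU that calculate_iou computes over document IDs, as a
# worked example (illustrative IDs only):

ref_ids = {"doc-1", "doc-2", "doc-3"}   # documents matched by the reference rule
hyp_ids = {"doc-2", "doc-3", "doc-4"}   # documents matched by the hypothesis rule

intersection = ref_ids & hyp_ids        # {"doc-2", "doc-3"}
union = ref_ids | hyp_ids               # {"doc-1", "doc-2", "doc-3", "doc-4"}
iou = len(intersection) / len(union) if union else 0.0
print(iou)  # 0.5
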
@@ -715,24 +879,19 @@ def get_expression_summary(expression):
     return str(expression)
 
-def compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client=None):
+def compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule):
     """
-    Compare two JSONL files using the specified rules and calculate IoU.
+    Compare two sets of documents using the specified rules and calculate IoU.
 
     Args:
-        ref_path: Path to reference JSONL file
-        hyp_path: Path to hypothesis JSONL file
+        ref_docs: List of reference documents
+        hyp_docs: List of hypothesis documents
         ref_rule: Rule expression for reference (tuple or ExpressionNode)
         hyp_rule: Rule expression for hypothesis (tuple or ExpressionNode)
-        s3_client: S3 client for S3 paths
 
     Returns:
         Dictionary with comparison results
     """
-    # Load the files
-    ref_docs = load_jsonl_file(ref_path, s3_client)
-    hyp_docs = load_jsonl_file(hyp_path, s3_client)
-
     # Extract document IDs and create ID-to-document maps
     ref_id_to_doc = {doc["id"]: doc for doc in ref_docs if "id" in doc}
     hyp_id_to_doc = {doc["id"]: doc for doc in hyp_docs if "id" in doc}
@@ -788,8 +947,6 @@ def compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client=None):
     hyp_rule_str = get_expression_summary(hyp_rule)
 
     return {
-        "ref_file": ref_path,
-        "hyp_file": hyp_path,
         "total_docs": len(common_ids),
         "ref_rule": ref_rule_str,
         "hyp_rule": hyp_rule_str,
@@ -850,9 +1007,16 @@ def main():
     global args
     args = parse_args()
 
+    # Set up logging based on arguments
+    if args.debug:
+        logger.setLevel(logging.DEBUG)
+        logger.debug("Debug logging enabled")
+
     # Set up S3 client if needed
     s3_client = None
-    if args.ref_folder.startswith("s3://") or args.hyp_folder.startswith("s3://"):
+    if (args.docs_folder.startswith("s3://") or
+        args.ref_attr_folder.startswith("s3://") or
+        args.hyp_attr_folder.startswith("s3://")):
         session = boto3.Session(profile_name=args.aws_profile) if args.aws_profile else boto3.Session()
         s3_client = session.client("s3")
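
# compare_documents reports precision/recall/F1/IoU derived from the usual
# confusion-matrix counts. A sketch of those formulas (mirroring the per-run
# aggregation that previously lived in main(), removed below) with
# illustrative numbers:

def summarize(tp, fp, fn):
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
    return {"precision": precision, "recall": recall, "f1": f1, "iou": iou}

# e.g. 80 docs matched by both rules, 20 only by hypothesis, 10 only by reference:
print(summarize(tp=80, fp=20, fn=10))
# precision 0.80, recall ~0.889, f1 ~0.842, iou ~0.727
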
@@ -870,88 +1034,45 @@ def main():
     logger.info(f"Reference rule parsed as: {ref_rule_str}")
     logger.info(f"Hypothesis rule parsed as: {hyp_rule_str}")
 
-    # List JSONL files in both folders
-    logger.info(f"Finding JSONL files in reference folder: {args.ref_folder}")
-    ref_files = list_jsonl_files(args.ref_folder, s3_client, args.recursive)
+    # Load documents and merge with attributes
+    logger.info("Loading documents and merging with reference attributes...")
+    ref_docs = load_documents_and_attributes(args.docs_folder, args.ref_attr_folder, s3_client, args.recursive)
 
-    logger.info(f"Finding JSONL files in hypothesis folder: {args.hyp_folder}")
-    hyp_files = list_jsonl_files(args.hyp_folder, s3_client, args.recursive)
+    logger.info("Loading documents and merging with hypothesis attributes...")
+    hyp_docs = load_documents_and_attributes(args.docs_folder, args.hyp_attr_folder, s3_client, args.recursive)
 
-    logger.info(f"Found {len(ref_files)} files in reference folder and {len(hyp_files)} files in hypothesis folder")
+    # Compare the documents
+    logger.info("Comparing documents using reference and hypothesis rules...")
+    comparison_result = compare_documents(ref_docs, hyp_docs, ref_rule, hyp_rule)
 
-    # Find matching files
-    matched_files = get_matching_files(ref_files, hyp_files)
-    logger.info(f"Found {len(matched_files)} matching files between folders")
-
-    if not matched_files:
-        logger.error("No matching files found between reference and hypothesis folders")
-        sys.exit(1)
-
-    # Process each pair of files
-    results = []
+    # Prepare overall statistics
     overall_stats = {
-        "total_files": len(matched_files),
-        "total_docs": 0,
-        "ref_matches": 0,
-        "hyp_matches": 0,
-        "true_positives": 0,
-        "false_positives": 0,
-        "false_negatives": 0,
-        # Initialize rule stats counters
-        "ref_rule_stats": defaultdict(int),
-        "hyp_rule_stats": defaultdict(int)
+        "total_docs": comparison_result["total_docs"],
+        "ref_matches": comparison_result["ref_matches"],
+        "hyp_matches": comparison_result["hyp_matches"],
+        "true_positives": comparison_result["true_positives"],
+        "false_positives": comparison_result["false_positives"],
+        "false_negatives": comparison_result["false_negatives"],
+        "precision": comparison_result["precision"],
+        "recall": comparison_result["recall"],
+        "f1": comparison_result["f1"],
+        "iou": comparison_result["iou"],
+        "ref_rule_stats": comparison_result["ref_rule_stats"],
+        "hyp_rule_stats": comparison_result["hyp_rule_stats"]
     }
 
-    for i, (ref_path, hyp_path) in enumerate(matched_files.items()):
-        logger.info(f"Processing file pair {i+1}/{len(matched_files)}: {os.path.basename(ref_path)}")
-
-        file_result = compare_files(ref_path, hyp_path, ref_rule, hyp_rule, s3_client)
-        results.append(file_result)
-
-        # Accumulate overall statistics
-        overall_stats["total_docs"] += file_result["total_docs"]
-        overall_stats["ref_matches"] += file_result["ref_matches"]
-        overall_stats["hyp_matches"] += file_result["hyp_matches"]
-        overall_stats["true_positives"] += file_result["true_positives"]
-        overall_stats["false_positives"] += file_result["false_positives"]
-        overall_stats["false_negatives"] += file_result["false_negatives"]
-
-        # Accumulate rule statistics
-        for key, value in file_result["ref_rule_stats"].items():
-            overall_stats["ref_rule_stats"][key] += value
-        for key, value in file_result["hyp_rule_stats"].items():
-            overall_stats["hyp_rule_stats"][key] += value
-
-    # Calculate overall metrics
-    tp = overall_stats["true_positives"]
-    fp = overall_stats["false_positives"]
-    fn = overall_stats["false_negatives"]
-
-    overall_stats["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
-    overall_stats["recall"] = tp / (tp + fn) if (tp + fn) > 0 else 0
-    overall_stats["f1"] = (
-        2 * overall_stats["precision"] * overall_stats["recall"] /
-        (overall_stats["precision"] + overall_stats["recall"])
-        if (overall_stats["precision"] + overall_stats["recall"]) > 0 else 0
-    )
-    overall_stats["iou"] = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0
-
-    # Convert defaultdicts to regular dicts for JSON serialization
-    overall_stats["ref_rule_stats"] = dict(overall_stats["ref_rule_stats"])
-    overall_stats["hyp_rule_stats"] = dict(overall_stats["hyp_rule_stats"])
-
     # Prepare final output
     output = {
         "config": {
-            "ref_folder": args.ref_folder,
-            "hyp_folder": args.hyp_folder,
+            "docs_folder": args.docs_folder,
+            "ref_attr_folder": args.ref_attr_folder,
+            "hyp_attr_folder": args.hyp_attr_folder,
             "ref_rule": args.ref_rule,
             "ref_rule_parsed": ref_rule_str,
             "hyp_rule": args.hyp_rule,
             "hyp_rule_parsed": hyp_rule_str
         },
-        "overall": overall_stats,
-        "file_results": results
+        "overall": overall_stats
     }
 
     # Save results
@@ -960,6 +1081,9 @@ def main():
 
     # Print summary
     logger.info("\n--- COMPARISON SUMMARY ---")
+    logger.info(f"Documents Folder: {args.docs_folder}")
+    logger.info(f"Reference Attributes Folder: {args.ref_attr_folder}")
+    logger.info(f"Hypothesis Attributes Folder: {args.hyp_attr_folder}")
    logger.info(f"Reference Rule Expression: {args.ref_rule}")
     logger.info(f"  Parsed as: {ref_rule_str}")
     logger.info(f"Hypothesis Rule Expression: {args.hyp_rule}")
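
# Illustrative shape of the saved results file after this change (per-file
# results are gone since all documents are compared as one merged pool).
# Every number below is invented, and the rule-stats contents depend on the
# rule expressions passed in:

example_output = {
    "config": {
        "docs_folder": "s3://bucket/path/documents",
        "ref_attr_folder": "s3://bucket/path/attributes/model_a",
        "hyp_attr_folder": "s3://bucket/path/attributes/model_b",
        "ref_rule": "gpt_4_1_contains_pii:any",
        "ref_rule_parsed": "gpt_4_1_contains_pii:any",
        "hyp_rule": "gpt_4_1_contains_email_addresses:any",
        "hyp_rule_parsed": "gpt_4_1_contains_email_addresses:any",
    },
    "overall": {
        "total_docs": 1000,
        "ref_matches": 90,
        "hyp_matches": 100,
        "true_positives": 80,
        "false_positives": 20,
        "false_negatives": 10,
        "precision": 0.8,
        "recall": 0.889,
        "f1": 0.842,
        "iou": 0.727,
        "ref_rule_stats": {},
        "hyp_rule_stats": {},
    },
}
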
@@ -984,4 +1108,29 @@ def main():
     logger.info(f"Detailed results saved to: {args.output_file}")
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
+
+# Example commands with actual S3 paths:
+"""
+# Example for AI2 OE data with resume detection:
+python scripts/pii_rule_comparison.py \
+    --docs-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/documents/ \
+    --ref-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/chatgpt_pii_vision/ \
+    --hyp-attr-folder s3://ai2-oe-data/jakep/s2pdf_dedupe_minhash_v1_mini/attributes/model_pii_tagging/ \
+    --ref-rule "gpt_4_1_contains_pii:any and not gpt_4_1_is_public_document:all" \
+    --hyp-rule "google_gemma-3-4b-it_is_resume_cv:any" \
+    --output-file pii_resume_comparison.json \
+    --recursive \
+    --debug
+
+# Example for Dolma data with PII detection:
+python scripts/pii_rule_comparison.py \
+    --docs-folder s3://allenai-dolma/documents/v1.5 \
+    --ref-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/gpt4_1 \
+    --hyp-attr-folder s3://allenai-dolma/attributes/v1.5/pii_detection/custom_rule \
+    --ref-rule "contains_pii:any" \
+    --hyp-rule "(contains_email_addresses:any or contains_phone_numbers:any) and not false_positive:any" \
+    --output-file pii_detection_comparison.json \
+    --recursive \
+    --aws-profile dolma
+"""
\ No newline at end of file