#!/usr/bin/env python3
import argparse
import json
import random
import re
import time

import boto3
import requests
from tqdm import tqdm
from transformers import AutoTokenizer

# Allowed characters: alphanumeric, space, and basic punctuation ".,!?()"
ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
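# For example (illustrative only, not executed):
#   ALLOWED_RE.fullmatch("The quick brown fox, it said!")  # matches
#   ALLOWED_RE.fullmatch("naive approach; mostly")         # None (';' is not allowed)
#   ALLOWED_RE.fullmatch("café")                           # None (non-ASCII 'é')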


def get_random_line_from_s3(bucket, key):
    """
    Reads an S3 object line-by-line and returns a random line using reservoir sampling.
    """
    s3 = boto3.client("s3")
    response = s3.get_object(Bucket=bucket, Key=key)
    random_line = None
    count = 0
    for line in response["Body"].iter_lines():
        if not line:
            continue
        line_str = line.decode("utf-8")
        count += 1
        # Reservoir sampling with k=1: keep the current line with probability
        # 1/count, so every non-empty line seen so far is equally likely.
        if random.randint(1, count) == 1:
            random_line = line_str
    return random_line
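# Why reservoir sampling is uniform, worked for three lines: line 1 is kept with
# probability 1, survives line 2 with probability 1/2 and line 3 with probability
# 2/3, so P(line 1 kept) = 1 * 1/2 * 2/3 = 1/3 -- the same as lines 2 and 3.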


def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
    """
    Sends a count query to the infini-gram API for the given n-gram.
    Retries a few times in case of network issues.
    """
    url = "https://api.infini-gram.io/"
    payload = {
        "index": index,
        "query_type": "count",
        "query": ngram,
    }
    for _ in range(retries):
        try:
            response = requests.post(url, json=payload, timeout=10)
            if response.status_code == 200:
                result = response.json()
                if "count" in result:
                    return result["count"]
        except Exception:
            # Network hiccup; back off briefly before retrying.
            time.sleep(1)
    return 0
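# Example (makes a live HTTP call; the count depends on the chosen index):
#   n = query_infinigram("according to the report", index="v4_rpj_llama_s4")
# n > 0 means the exact token sequence occurs in the indexed corpus; 0 can mean
# either no occurrences or an API failure after all retries.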


def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
    """
    Tokenizes the document using the Llama2 tokenizer and samples random n-grams.
    Each n-gram is accepted only if its decoded string contains nothing but
    alphanumeric characters, spaces, and the punctuation marks ".,!?()".
    (A word-split-boundary check based on the offset mapping exists below but is
    currently commented out, so n-grams may start at any token position.)

    Each valid n-gram is then queried using the infini-gram API.
    The function returns the document id, the number of matching n-grams (i.e. API count > 0),
    the total number of valid n-grams sampled, and a list of tuples (flag, ngram_string).
    """
    text = doc.get("text", "")
    doc_id = doc.get("id", "Unknown")
    # Get tokenized representation with offset mapping to determine word boundaries.
    tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
    token_ids = tokenized["input_ids"]
    # offsets = tokenized["offset_mapping"]

    if len(token_ids) < ngram_size:
        return doc_id, 0, 0, []

    # Determine valid starting indices based on word-split boundaries.
    # (Currently disabled: valid_positions stays empty, so the fallback below
    # uses every possible start position.)
    valid_positions = []
    # for i in range(len(token_ids) - ngram_size + 1):
    #     start_offset = offsets[i][0]
    #     if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "):
    #         valid_positions.append(i)

    if not valid_positions:
        # Fallback: if no valid positions are found, use all possible positions.
        valid_positions = list(range(len(token_ids) - ngram_size + 1))

    valid_ngram_details = []
    attempts = 0
    max_attempts = num_samples * 10  # Limit to prevent infinite loops.
    while len(valid_ngram_details) < num_samples and attempts < max_attempts:
        idx = random.choice(valid_positions)
        ngram_token_ids = token_ids[idx : idx + ngram_size]
        ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
        # Only accept n-grams that contain only allowed characters and that average
        # more than three characters per token, filtering out degenerate fragments.
        if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
            count = query_infinigram(ngram_str, index=index)
            flag = "YES" if count > 0 else "NO"
            valid_ngram_details.append((flag, ngram_str))
        attempts += 1

    match_count = sum(1 for flag, _ in valid_ngram_details if flag == "YES")
    sample_count = len(valid_ngram_details)
    return doc_id, match_count, sample_count, valid_ngram_details
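# Example (hypothetical document; assumes a tokenizer is already loaded):
#   doc = {"id": "doc-0", "text": "The quick brown fox jumps over the lazy dog."}
#   doc_id, matches, samples, details = process_document(doc, tokenizer, ngram_size=5, num_samples=10)
# Note: start positions are sampled with replacement, so `details` may repeat an n-gram.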


def main():
    parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
    parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
    parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
    parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_dolma-v1_7_llama)")
    parser.add_argument("--ngram_size", type=int, default=10, help="Size of the n-gram to sample (default: 10)")
    parser.add_argument("--num_ngrams", type=int, default=100, help="Number of random n-grams to sample from each document (default: 100)")
    args = parser.parse_args()

    if not args.s3_path.startswith("s3://"):
        print("Error: s3_path must start with 's3://'")
        return
    # Split "s3://bucket/prefix" into bucket and key prefix.
    path_without_scheme = args.s3_path[5:]
    parts = path_without_scheme.split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
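    # For example, "s3://my-bucket/my-prefix/" parses to
    # bucket="my-bucket" and prefix="my-prefix/".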

    print("Listing .jsonl files from S3...")
    s3 = boto3.client("s3")
    # list_objects_v2 returns at most 1000 keys per call, so paginate to cover large prefixes.
    paginator = s3.get_paginator("list_objects_v2")
    files = [
        obj["Key"]
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
        for obj in page.get("Contents", [])
        if obj["Key"].endswith(".jsonl")
    ]
    if not files:
        print("No .jsonl files found in the given prefix.")
        return

    if args.N > len(files):
        print(f"Requested {args.N} files, but only found {len(files)}. Processing all available files.")
        args.N = len(files)
    random_files = random.sample(files, args.N)

    print("Loading Llama2 tokenizer...")
    # Note: meta-llama/Llama-2-7b-chat-hf is a gated repo on the Hugging Face Hub;
    # loading it requires accepting the license and authenticating (e.g., huggingface-cli login).
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    total_matches = 0
    total_ngrams_sampled = 0

    for key in tqdm(random_files, desc="Processing files"):
        line = get_random_line_from_s3(bucket, key)
        if not line:
            print(f"Skipping {key}: No valid lines found.")
            continue
        try:
            doc = json.loads(line)
        except Exception as e:
            print(f"Error parsing JSON in {key}: {e}")
            continue
        doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)

        # Print per-document n-gram summary
        print(f"\nDocument ID: {doc_id}")
        for flag, ngram in details:
            # Print the flag in a fixed-width field (4 characters) followed by the n-gram representation.
            print(f"{flag:4} {repr(ngram)}")
        percentage = (match_count / sample_count * 100) if sample_count else 0
        print(f"Matched n-grams: {match_count}/{sample_count} ({percentage:.2f}%)")

        total_matches += match_count
        total_ngrams_sampled += sample_count

    overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
    print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")


if __name__ == "__main__":
    main()
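# Example invocation (hypothetical bucket/prefix; the script name is whatever this
# file is saved as):
#   python ngram_match.py 5 s3://my-bucket/documents/ --ngram_size 8 --num_ngrams 50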