# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import argparse
import json
import re

from sklearn import __version__ as sklearn_version
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
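
# Note: the TF-IDF vectorization and cosine-similarity computation below rely on
# scikit-learn; if it is not installed, `pip install scikit-learn` should suffice
# (the version actually in use is printed when the script starts).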


# Sample JSON dataset
example_data = [
    {"instruction": "What is the capital of Italy?",
     "input": "", "output": "The capital of Italy is Rome."
     },
    {"instruction": "What's the capital city of Italy?",
     "input": "", "output": "The capital city is Rome."
     },
    {"instruction": "Identify the main verb in the sentence: 'The cat sleeps on the couch.'",
     "input": "", "output": "The verb is 'sleeps'."
     },
    {"instruction": "Identify the verb in the following sentence: The cat sleeps on the couch.",
     "input": "", "output": "The verb in the sentence is \"sleeps.\""
     },
    # ...
]
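
# A dataset passed in via --json_file is assumed to follow the same structure as
# example_data above: a JSON list of objects whose string values (here under the
# "instruction", "input", and "output" keys) are each scanned for near-duplicates.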


def preprocess_text(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    return text
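
# For example (illustrative only):
#   preprocess_text("The cat sleeps on the couch.")
# returns "the cat sleeps on the couch", so trivial differences in casing and
# punctuation do not affect the similarity comparison below.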


def find_near_duplicates(json_data, threshold=0.75, key="instruction"):
    """
    The higher the threshold, the more similar the texts have to be to match.
    Returns the (possibly filtered) data and a list of (entry_1, entry_2, similarity) tuples.
    """

    # Extract the text stored under `key`; keep one entry per item so that the
    # row indices of the similarity matrix line up with the indices of json_data
    text = [preprocess_text(item[key]) for item in json_data]
    near_duplicates = []
    indices_to_remove = set()

    if not any(text):
        # Nothing to compare for this key, so return the data unchanged
        return json_data, near_duplicates

    # Vectorize the text data using character 1- to 3-grams
    vectorizer = TfidfVectorizer(stop_words=None, analyzer='char', ngram_range=(1, 3))
    tfidf_matrix = vectorizer.fit_transform(text)

    # Compute cosine similarity between each pair of entries
    cos_sim_matrix = cosine_similarity(tfidf_matrix)

    # Find pairs of near-duplicate entries based on the threshold
    for i in range(len(cos_sim_matrix)):
        for j in range(i+1, len(cos_sim_matrix)):
            if cos_sim_matrix[i, j] > threshold:
                # Skip empty or single-character entries
                if len(json_data[i][key]) <= 1 or len(json_data[j][key]) <= 1:
                    continue
                near_duplicates.append((json_data[i], json_data[j], cos_sim_matrix[i, j]))
                if key in ("input", "output"):  # Don't remove duplicates based on the instruction
                    indices_to_remove.add(j)  # Mark the second entry for removal

    # Remove the near-duplicate entries
    filtered_json_data = [item for index, item in enumerate(json_data) if index not in indices_to_remove]

    return filtered_json_data, near_duplicates
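
# Minimal usage sketch (illustrative; not executed anywhere in this script):
#
#   filtered, dups = find_near_duplicates(example_data, threshold=0.75, key="instruction")
#   for first, second, similarity in dups:
#       print(f"{similarity:.2f}: {first['instruction']} <-> {second['instruction']}")
#
# Each reported pair consists of two entries whose character-n-gram TF-IDF vectors
# have a cosine similarity above the threshold, which would most likely flag the
# two differently phrased "capital of Italy" questions in example_data.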


def find_print_and_remove_near_duplicates(json_data, remove_duplicates=False, threshold=0.75):
    """
    Searches for near-duplicates under each key of the first JSON object, across the
    whole list of JSON objects, and prints any duplicate pairs that are found.
    If remove_duplicates is True, near-duplicates found under the 'input' and 'output'
    keys are also removed, and the (possibly filtered) data is returned.
    """
    for key in json_data[0].keys():
        if remove_duplicates:
            json_data, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
        else:
            _, near_duplicates = find_near_duplicates(json_data, key=key, threshold=threshold)
        separator = 50 * '='
        print(f"\n\n{separator}\nSearching '{key}' for duplicates ...\n{separator}")
        if not near_duplicates:
            print("No duplicates found")
        else:
            for dup in near_duplicates:
                print(
                    f"Duplicate pair found with similarity {dup[2]:.2f}:\n"
                    f"1. {dup[0][key]}\n2. {dup[1][key]}\n"
                )
    return json_data
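
# Example command-line usage (file names below are placeholders):
#
#   python find-near-duplicates.py --json_file instruction-data.json --threshold 0.9
#   python find-near-duplicates.py --json_file instruction-data.json \
#       --remove_duplicates --json_output_file instruction-data-deduped.json
#
# Without --json_file, the small example_data list defined above is used.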


if __name__ == "__main__":
    print("scikit-learn version:", sklearn_version)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json_file",
        type=str,
        help="Path to the dataset JSON file"
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.9,
        help="A sensitivity threshold between 0 and 1 where 1 is strictest"
    )
    parser.add_argument(
        "--remove_duplicates",
        action='store_true',
        default=False,
        help=(
            "Removes duplicates based on the 'input' or 'output' keys "
            "(but not the 'instruction') and saves the cleaned JSON file as --json_output_file"
        )
    )
    parser.add_argument(
        "--json_output_file",
        type=str,
        help="Path to the output JSON file for the cleaned dataset"
    )

    args = parser.parse_args()

    if args.remove_duplicates and not args.json_output_file:
        raise ValueError(
            "Provide an output file via --json_output_file "
            "to save the cleaned JSON data."
        )

    if not args.json_file:
        json_data = example_data
    else:
        with open(args.json_file, "r") as file:
            json_data = json.load(file)

    json_data = find_print_and_remove_near_duplicates(
        json_data=json_data,
        remove_duplicates=args.remove_duplicates,
        threshold=args.threshold
    )

    if args.remove_duplicates:
        with open(args.json_output_file, "w") as file:
            json.dump(json_data, file, indent=4)