| 
									
										
										
										
											2024-03-19 09:26:26 -05:00
										 |  |  | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). | 
					
						
							|  |  |  | # Source for "Build a Large Language Model From Scratch" | 
					
						
							|  |  |  | #   - https://www.manning.com/books/build-a-large-language-model-from-scratch | 
					
						
							|  |  |  | # Code: https://github.com/rasbt/LLMs-from-scratch | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  | from importlib.metadata import PackageNotFoundError, import_module, version as get_version | 
					
						
							| 
									
										
										
										
											2024-06-18 19:20:45 -05:00
										 |  |  | from os.path import dirname, exists, join, realpath | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | from packaging.version import parse as version_parse | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  | from packaging.requirements import Requirement | 
					
						
							|  |  |  | from packaging.specifiers import SpecifierSet | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | import platform | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 21:09:31 -05:00
										 |  |  | if version_parse(platform.python_version()) < version_parse("3.9"): | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     print("[FAIL] We recommend Python 3.9 or newer but found version %s" % sys.version) | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | else: | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     print("[OK] Your Python version is %s" % platform.python_version()) | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_packages(pkgs): | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Returns a dictionary mapping package names (in lowercase) to their installed version. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2025-02-17 10:33:53 -07:00
										 |  |  |     PACKAGE_MODULE_OVERRIDES = { | 
					
						
							|  |  |  |         "tensorflow-cpu": ["tensorflow", "tensorflow_cpu"], | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     result = {} | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |     for p in pkgs: | 
					
						
							| 
									
										
										
										
											2025-02-17 10:33:53 -07:00
										 |  |  |         # Determine possible module names to try. | 
					
						
							|  |  |  |         module_names = PACKAGE_MODULE_OVERRIDES.get(p.lower(), [p]) | 
					
						
							|  |  |  |         version_found = None | 
					
						
							|  |  |  |         for module_name in module_names: | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2025-02-17 10:33:53 -07:00
										 |  |  |                 imported = import_module(module_name) | 
					
						
							|  |  |  |                 version_found = getattr(imported, "__version__", None) | 
					
						
							|  |  |  |                 if version_found is None: | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         version_found = get_version(module_name) | 
					
						
							|  |  |  |                     except PackageNotFoundError: | 
					
						
							|  |  |  |                         version_found = None | 
					
						
							|  |  |  |                 if version_found is not None: | 
					
						
							|  |  |  |                     break  # Stop if we successfully got a version. | 
					
						
							|  |  |  |             except ImportError: | 
					
						
							|  |  |  |                 # Also try replacing hyphens with underscores as a fallback. | 
					
						
							|  |  |  |                 alt_module = module_name.replace("-", "_") | 
					
						
							|  |  |  |                 if alt_module != module_name: | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         imported = import_module(alt_module) | 
					
						
							|  |  |  |                         version_found = getattr(imported, "__version__", None) | 
					
						
							|  |  |  |                         if version_found is None: | 
					
						
							|  |  |  |                             try: | 
					
						
							|  |  |  |                                 version_found = get_version(alt_module) | 
					
						
							|  |  |  |                             except PackageNotFoundError: | 
					
						
							|  |  |  |                                 version_found = None | 
					
						
							|  |  |  |                         if version_found is not None: | 
					
						
							|  |  |  |                             break | 
					
						
							|  |  |  |                     except ImportError: | 
					
						
							|  |  |  |                         continue | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |         if version_found is None: | 
					
						
							|  |  |  |             version_found = "0.0" | 
					
						
							|  |  |  |         result[p.lower()] = version_found | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     return result | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_requirements_dict(): | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     """
 | 
					
						
							|  |  |  |     Parses requirements.txt and returns a dictionary mapping package names (lowercase) | 
					
						
							|  |  |  |     to a specifier string (e.g. ">=2.18.0,<3.0"). It uses packaging.requirements.Requirement | 
					
						
							|  |  |  |     to properly handle environment markers. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |     PROJECT_ROOT = dirname(realpath(__file__)) | 
					
						
							| 
									
										
										
										
											2024-03-18 08:00:49 -05:00
										 |  |  |     PROJECT_ROOT_UP_TWO = dirname(dirname(PROJECT_ROOT)) | 
					
						
							|  |  |  |     REQUIREMENTS_FILE = join(PROJECT_ROOT_UP_TWO, "requirements.txt") | 
					
						
							| 
									
										
										
										
											2024-06-18 19:20:45 -05:00
										 |  |  |     if not exists(REQUIREMENTS_FILE): | 
					
						
							|  |  |  |         REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     reqs = {} | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |     with open(REQUIREMENTS_FILE) as f: | 
					
						
							|  |  |  |         for line in f: | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |             # Remove inline comments and trailing whitespace. | 
					
						
							|  |  |  |             # This splits on the first '#' and takes the part before it. | 
					
						
							|  |  |  |             line = line.split("#", 1)[0].strip() | 
					
						
							|  |  |  |             if not line: | 
					
						
							| 
									
										
										
										
											2024-03-18 08:00:49 -05:00
										 |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |             try: | 
					
						
							|  |  |  |                 req = Requirement(line) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 print(f"Skipping line due to parsing error: {line} ({e})") | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |             # Evaluate the marker if present. | 
					
						
							|  |  |  |             if req.marker is not None and not req.marker.evaluate(): | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |             # Store the package name and its version specifier. | 
					
						
							|  |  |  |             spec = str(req.specifier) if req.specifier else ">=0" | 
					
						
							|  |  |  |             reqs[req.name.lower()] = spec | 
					
						
							|  |  |  |     return reqs | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  | def check_packages(reqs): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Checks the installed versions of packages against the requirements. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     installed = get_packages(reqs.keys()) | 
					
						
							|  |  |  |     for pkg_name, spec_str in reqs.items(): | 
					
						
							|  |  |  |         spec_set = SpecifierSet(spec_str) | 
					
						
							|  |  |  |         actual_ver = installed.get(pkg_name, "0.0") | 
					
						
							| 
									
										
										
										
											2024-06-17 21:09:31 -05:00
										 |  |  |         if actual_ver == "N/A": | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |             continue | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |         actual_ver_parsed = version_parse(actual_ver) | 
					
						
							|  |  |  |         # If the installed version is a pre-release, allow pre-releases in the specifier. | 
					
						
							|  |  |  |         if actual_ver_parsed.is_prerelease: | 
					
						
							|  |  |  |             spec_set.prereleases = True | 
					
						
							|  |  |  |         if actual_ver_parsed not in spec_set: | 
					
						
							|  |  |  |             print(f"[FAIL] {pkg_name} {actual_ver_parsed}, please install a version matching {spec_set}") | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |             print(f"[OK] {pkg_name} {actual_ver_parsed}") | 
					
						
							| 
									
										
										
										
											2023-07-23 13:18:13 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-18 11:58:37 -05:00
										 |  |  | def main(): | 
					
						
							| 
									
										
										
										
											2025-02-16 13:16:51 -06:00
										 |  |  |     reqs = get_requirements_dict() | 
					
						
							|  |  |  |     check_packages(reqs) | 
					
						
							| 
									
										
										
										
											2024-03-18 11:58:37 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-17 21:09:31 -05:00
										 |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2024-03-18 11:58:37 -05:00
										 |  |  |     main() |