mirror of
				https://github.com/Unstructured-IO/unstructured.git
				synced 2025-10-30 01:17:43 +00:00 
			
		
		
		
	Use native nltk download (#3796)
This PR changes how we download NLTK data so that it uses the native NLTK downloader. We had moved to our own hosted NLTK dataset because of this CVE: https://nvd.nist.gov/vuln/detail/CVE-2024-39705 (ref: https://github.com/Unstructured-IO/unstructured/pull/3361). The latest versions of NLTK have fixed this issue; see the changelog: https://github.com/nltk/nltk/blob/develop/ChangeLog
This commit is contained in:
		
							parent
							
								
									9445a2dd01
								
							
						
					
					
						commit
						0fb814db61
					
				| @ -1,4 +1,4 @@ | ||||
| ## 0.16.9-dev0 | ||||
| ## 0.16.9 | ||||
| 
 | ||||
| ### Enhancements | ||||
| 
 | ||||
| @ -6,6 +6,8 @@ | ||||
| 
 | ||||
| ### Fixes | ||||
| 
 | ||||
| - **Fix NLTK Download** to not download from unstructured S3 Bucket | ||||
| 
 | ||||
| ## 0.16.8 | ||||
| 
 | ||||
| ### Enhancements | ||||
|  | ||||
| @ -8,6 +8,7 @@ from unstructured.nlp import tokenize | ||||
| 
 | ||||
| 
 | ||||
| def test_nltk_packages_download_if_not_present(): | ||||
|     tokenize._download_nltk_packages_if_not_present.cache_clear() | ||||
|     with patch.object(nltk, "find", side_effect=LookupError): | ||||
|         with patch.object(tokenize, "download_nltk_packages") as mock_download: | ||||
|             tokenize._download_nltk_packages_if_not_present() | ||||
| @ -16,6 +17,7 @@ def test_nltk_packages_download_if_not_present(): | ||||
| 
 | ||||
| 
 | ||||
| def test_nltk_packages_do_not_download_if(): | ||||
|     tokenize._download_nltk_packages_if_not_present.cache_clear() | ||||
|     with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download: | ||||
|         tokenize._download_nltk_packages_if_not_present() | ||||
| 
 | ||||
|  | ||||
| @ -1 +1 @@ | ||||
| __version__ = "0.16.9-dev0"  # pragma: no cover | ||||
| __version__ = "0.16.9"  # pragma: no cover | ||||
|  | ||||
| @ -1,11 +1,6 @@ | ||||
| from __future__ import annotations | ||||
| 
 | ||||
| import hashlib | ||||
| import os | ||||
| import sys | ||||
| import tarfile | ||||
| import tempfile | ||||
| import urllib.request | ||||
| from functools import lru_cache | ||||
| from typing import Final, List, Tuple | ||||
| 
 | ||||
| @ -16,86 +11,10 @@ from nltk import word_tokenize as _word_tokenize | ||||
| 
 | ||||
| CACHE_MAX_SIZE: Final[int] = 128 | ||||
| 
 | ||||
| NLTK_DATA_FILENAME = "nltk_data_3.8.2.tar.gz" | ||||
| NLTK_DATA_URL = f"https://utic-public-cf.s3.amazonaws.com/{NLTK_DATA_FILENAME}" | ||||
| NLTK_DATA_SHA256 = "ba2ca627c8fb1f1458c15d5a476377a5b664c19deeb99fd088ebf83e140c1663" | ||||
| 
 | ||||
| 
 | ||||
| # NOTE(robinson) - mimic default dir logic from NLTK | ||||
| # https://github.com/nltk/nltk/ | ||||
| # 	blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046 | ||||
| def get_nltk_data_dir() -> str | None: | ||||
|     """Locates the directory the nltk data will be saved too. The directory | ||||
|     set by the NLTK environment variable takes highest precedence. Otherwise | ||||
|     the default is determined by the rules indicated below. Returns None when | ||||
|     the directory is not writable. | ||||
| 
 | ||||
|         On Windows, the default download directory is | ||||
|         ``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the | ||||
|         directory containing Python, e.g. ``C:\\Python311``. | ||||
| 
 | ||||
|         On all other platforms, the default directory is the first of | ||||
|         the following which exists or which can be created with write | ||||
|         permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``, | ||||
|         ``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``. | ||||
|     """ | ||||
|     # Check if we are on GAE where we cannot write into filesystem. | ||||
|     if "APPENGINE_RUNTIME" in os.environ: | ||||
|         return | ||||
| 
 | ||||
|     # Check if we have sufficient permissions to install in a | ||||
|     # variety of system-wide locations. | ||||
|     for nltkdir in nltk.data.path: | ||||
|         if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir): | ||||
|             return nltkdir | ||||
| 
 | ||||
|     # On Windows, use %APPDATA% | ||||
|     if sys.platform == "win32" and "APPDATA" in os.environ: | ||||
|         homedir = os.environ["APPDATA"] | ||||
| 
 | ||||
|     # Otherwise, install in the user's home directory. | ||||
|     else: | ||||
|         homedir = os.path.expanduser("~/") | ||||
|         if homedir == "~/": | ||||
|             raise ValueError("Could not find a default download directory") | ||||
| 
 | ||||
|     # NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already | ||||
|     # present in the tar file so we don't have to do that here. | ||||
|     return homedir | ||||
| 
 | ||||
| 
 | ||||
| def download_nltk_packages(): | ||||
|     nltk_data_dir = get_nltk_data_dir() | ||||
| 
 | ||||
|     if nltk_data_dir is None: | ||||
|         raise OSError("NLTK data directory does not exist or is not writable.") | ||||
| 
 | ||||
|     # Check if the path ends with "nltk_data" and remove it if it does | ||||
|     if nltk_data_dir.endswith("nltk_data"): | ||||
|         nltk_data_dir = os.path.dirname(nltk_data_dir) | ||||
| 
 | ||||
|     def sha256_checksum(filename: str, block_size: int = 65536): | ||||
|         sha256 = hashlib.sha256() | ||||
|         with open(filename, "rb") as f: | ||||
|             for block in iter(lambda: f.read(block_size), b""): | ||||
|                 sha256.update(block) | ||||
|         return sha256.hexdigest() | ||||
| 
 | ||||
|     with tempfile.TemporaryDirectory() as temp_dir_path: | ||||
|         tgz_file_path = os.path.join(temp_dir_path, NLTK_DATA_FILENAME) | ||||
|         urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file_path) | ||||
| 
 | ||||
|         file_hash = sha256_checksum(tgz_file_path) | ||||
|         if file_hash != NLTK_DATA_SHA256: | ||||
|             os.remove(tgz_file_path) | ||||
|             raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}") | ||||
| 
 | ||||
|         # Extract the contents | ||||
|         if not os.path.exists(nltk_data_dir): | ||||
|             os.makedirs(nltk_data_dir) | ||||
| 
 | ||||
|         with tarfile.open(tgz_file_path, "r:gz") as tar: | ||||
|             tar.extractall(path=nltk_data_dir) | ||||
|     nltk.download("averaged_perceptron_tagger_eng", quiet=True) | ||||
|     nltk.download("punkt_tab", quiet=True) | ||||
| 
 | ||||
| 
 | ||||
| def check_for_nltk_package(package_name: str, package_category: str) -> bool: | ||||
| @ -109,10 +28,13 @@ def check_for_nltk_package(package_name: str, package_category: str) -> bool: | ||||
|     try: | ||||
|         nltk.find(f"{package_category}/{package_name}", paths=paths) | ||||
|         return True | ||||
|     except LookupError: | ||||
|     except (LookupError, OSError): | ||||
|         return False | ||||
| 
 | ||||
| 
 | ||||
| # We cache this because we do not want to attempt | ||||
| # downloading the packages multiple times | ||||
| @lru_cache() | ||||
| def _download_nltk_packages_if_not_present(): | ||||
|     """If required NLTK packages are not available, download them.""" | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Nathan Van Gheem
						Nathan Van Gheem