Use native nltk download (#3796)

This PR changes how we download NLTK data to use the native nltk
downloader.

We had moved to our own hosted NLTK dataset because of this CVE:
https://nvd.nist.gov/vuln/detail/CVE-2024-39705

Ref: https://github.com/Unstructured-IO/unstructured/pull/3361

Latest versions of NLTK have fixed this issue:
https://github.com/nltk/nltk/blob/develop/ChangeLog
This commit is contained in:
Nathan Van Gheem 2024-12-02 14:30:28 -05:00 committed by GitHub
parent 9445a2dd01
commit 0fb814db61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 86 deletions

View File

@ -1,4 +1,4 @@
## 0.16.9-dev0
## 0.16.9
### Enhancements
@ -6,6 +6,8 @@
### Fixes
- **Fix NLTK Download** to not download from unstructured S3 Bucket
## 0.16.8
### Enhancements

View File

@ -8,6 +8,7 @@ from unstructured.nlp import tokenize
def test_nltk_packages_download_if_not_present():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()
@ -16,6 +17,7 @@ def test_nltk_packages_download_if_not_present():
def test_nltk_packages_do_not_download_if():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_packages_if_not_present()

View File

@ -1 +1 @@
__version__ = "0.16.9-dev0" # pragma: no cover
__version__ = "0.16.9" # pragma: no cover

View File

@ -1,11 +1,6 @@
from __future__ import annotations
import hashlib
import os
import sys
import tarfile
import tempfile
import urllib.request
from functools import lru_cache
from typing import Final, List, Tuple
@ -16,86 +11,10 @@ from nltk import word_tokenize as _word_tokenize
CACHE_MAX_SIZE: Final[int] = 128
NLTK_DATA_FILENAME = "nltk_data_3.8.2.tar.gz"
NLTK_DATA_URL = f"https://utic-public-cf.s3.amazonaws.com/{NLTK_DATA_FILENAME}"
NLTK_DATA_SHA256 = "ba2ca627c8fb1f1458c15d5a476377a5b664c19deeb99fd088ebf83e140c1663"
# NOTE(robinson) - mimic default dir logic from NLTK
# https://github.com/nltk/nltk/
# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
def get_nltk_data_dir() -> str | None:
"""Locates the directory the nltk data will be saved to. The directory
set by the NLTK environment variable takes highest precedence. Otherwise
the default is determined by the rules indicated below. Returns None when
the directory is not writable.
On Windows, the default download directory is
``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
directory containing Python, e.g. ``C:\\Python311``.
On all other platforms, the default directory is the first of
the following which exists or which can be created with write
permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
"""
# Check if we are on GAE where we cannot write into filesystem.
if "APPENGINE_RUNTIME" in os.environ:
return
# Check if we have sufficient permissions to install in a
# variety of system-wide locations.
for nltkdir in nltk.data.path:
if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
return nltkdir
# On Windows, use %APPDATA%
if sys.platform == "win32" and "APPDATA" in os.environ:
homedir = os.environ["APPDATA"]
# Otherwise, install in the user's home directory.
else:
homedir = os.path.expanduser("~/")
if homedir == "~/":
raise ValueError("Could not find a default download directory")
# NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
# present in the tar file so we don't have to do that here.
return homedir
def download_nltk_packages():
nltk_data_dir = get_nltk_data_dir()
if nltk_data_dir is None:
raise OSError("NLTK data directory does not exist or is not writable.")
# Check if the path ends with "nltk_data" and remove it if it does
if nltk_data_dir.endswith("nltk_data"):
nltk_data_dir = os.path.dirname(nltk_data_dir)
def sha256_checksum(filename: str, block_size: int = 65536):
sha256 = hashlib.sha256()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()
with tempfile.TemporaryDirectory() as temp_dir_path:
tgz_file_path = os.path.join(temp_dir_path, NLTK_DATA_FILENAME)
urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file_path)
file_hash = sha256_checksum(tgz_file_path)
if file_hash != NLTK_DATA_SHA256:
os.remove(tgz_file_path)
raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")
# Extract the contents
if not os.path.exists(nltk_data_dir):
os.makedirs(nltk_data_dir)
with tarfile.open(tgz_file_path, "r:gz") as tar:
tar.extractall(path=nltk_data_dir)
nltk.download("averaged_perceptron_tagger_eng", quiet=True)
nltk.download("punkt_tab", quiet=True)
def check_for_nltk_package(package_name: str, package_category: str) -> bool:
@ -109,10 +28,13 @@ def check_for_nltk_package(package_name: str, package_category: str) -> bool:
try:
nltk.find(f"{package_category}/{package_name}", paths=paths)
return True
except LookupError:
except (LookupError, OSError):
return False
# We cache this because we do not want to attempt
# downloading the packages multiple times
@lru_cache()
def _download_nltk_packages_if_not_present():
"""If required NLTK packages are not available, download them."""