fix: add tiktoken fallback mechanism. (#3929)

* feat: migrate to tiktoken when tokenizing for OpenAI

* refactor: add OpenAI optional egg

* fix: add Python 3.7 fallback support for tiktoken

* refactor: change both tokenization implementations and fix mypy

* refactor: remove dummy-class

* refactor: add tiktoken as core dependency and minor refactoring

* refactor: sort imports

* refactor: remove out-of-scope PR change

* refactor: reintroduce corner case check

* refactor: remove unused egg

* refactor: remove unused exception after tiktoken as core dep

* refactor: reduce ifs and include log warning

* refactor: remove timeout linting ignore

* refactor: revert change due to mypy

* refactor: disable pylint import error

* fix: add arm64 fallback to HF tokenizer

* fix: add aarch64 fallback mechanism

* refactor: improve log message

* fix: change platform selection method

* refactor: consolidate archs
This commit is contained in:
Daniel Bichuetti 2023-01-25 07:37:29 -03:00 committed by GitHub
parent 5c53b2bd4a
commit afc1e1ccef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 5 deletions

View File

@ -2,6 +2,7 @@ import json
import logging
import sys
from typing import List, Optional, Tuple, Union
import platform
import requests
@ -12,14 +13,19 @@ from haystack.utils.reflection import retry_with_exponential_backoff
logger = logging.getLogger(__name__)
machine = platform.machine()
system = platform.system()
USE_TIKTOKEN = False
if sys.version_info >= (3, 8):
if sys.version_info >= (3, 8) and (machine in ["amd64", "x86_64"] or (machine == "arm64" and system == "Darwin")):
USE_TIKTOKEN = True
if USE_TIKTOKEN:
import tiktoken # pylint: disable=import-error
else:
logger.warning("OpenAI tiktoken module is not available for Python < 3.8. Falling back to GPT2TokenizerFast.")
logger.warning(
"OpenAI tiktoken module is not available for Python < 3.8,Linux ARM64 and AARCH64. Falling back to GPT2TokenizerFast."
)
from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast

View File

@ -3,6 +3,7 @@ import logging
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import platform
import numpy as np
import requests
@ -19,14 +20,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
machine = platform.machine()
system = platform.system()
USE_TIKTOKEN = False
if sys.version_info >= (3, 8):
if sys.version_info >= (3, 8) and (machine in ["amd64", "x86_64"] or (machine == "arm64" and system == "Darwin")):
USE_TIKTOKEN = True
if USE_TIKTOKEN:
import tiktoken # pylint: disable=import-error
else:
logger.warning("OpenAI tiktoken module is not available for Python < 3.8. Falling back to GPT2TokenizerFast.")
logger.warning(
"OpenAI tiktoken module is not available for Python < 3.8,Linux ARM64 and AARCH64. Falling back to GPT2TokenizerFast."
)
from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast

View File

@ -88,7 +88,7 @@ dependencies = [
"elasticsearch>=7.7,<8",
# OpenAI tokenizer
"tiktoken>=0.1.2; python_version >= '3.8'",
"tiktoken>=0.1.2; python_version >= '3.8' and (platform_machine == 'amd64' or platform_machine == 'x86_64' or (platform_machine == 'arm64' and platform_system == 'Darwin'))",
# context matching
"rapidfuzz>=2.0.15,<2.8.0", # FIXME https://github.com/deepset-ai/haystack/pull/3199