From afc1e1ccef525b6a913b8d008a504073bde00a8c Mon Sep 17 00:00:00 2001 From: Daniel Bichuetti Date: Wed, 25 Jan 2023 07:37:29 -0300 Subject: [PATCH] fix: add tiktoken fallback mechanism. (#3929) * feat: migrate to tiktoken when tokenizing for OpenAI * refactor: add OpenAI optional egg * fix: add Python 3.7 fallback support for tiktoken * refactor: change both tokenization implementations and fix mypy * refactor: remove dummy-class * refactor: add tiktoken as core dependency and minor refactoring * refactor: sort imports * refactor: remove out-of-scope PR change * refactor: reintroduce corner case check * refactor: remove unused egg * refactor: remove unused exception after tiktoken as core dep * refactor: reduce ifs and include log warning * refactor: remove timeout linting ignore * refactor: revert change due to mypy * refactor: disable pylint import error * fix: add arm64 fallback to HF tokenizer * fix: add aarch64 fallback mechanism * refactor: improve log message * fix: change platform selection method * refactor: consolidate archs --- haystack/nodes/answer_generator/openai.py | 10 ++++++++-- haystack/nodes/retriever/_openai_encoder.py | 10 ++++++++-- pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/haystack/nodes/answer_generator/openai.py b/haystack/nodes/answer_generator/openai.py index c44769920..ab6f2f0f0 100644 --- a/haystack/nodes/answer_generator/openai.py +++ b/haystack/nodes/answer_generator/openai.py @@ -2,6 +2,7 @@ import json import logging import sys from typing import List, Optional, Tuple, Union +import platform import requests @@ -12,14 +13,19 @@ from haystack.utils.reflection import retry_with_exponential_backoff logger = logging.getLogger(__name__) +machine = platform.machine() +system = platform.system() + USE_TIKTOKEN = False -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 8) and (machine in ["amd64", "x86_64"] or (machine == "arm64" and system == "Darwin")): USE_TIKTOKEN = True if 
USE_TIKTOKEN: import tiktoken # pylint: disable=import-error else: - logger.warning("OpenAI tiktoken module is not available for Python < 3.8. Falling back to GPT2TokenizerFast.") + logger.warning( + "OpenAI tiktoken module is not available for Python < 3.8, Linux ARM64 and AARCH64. Falling back to GPT2TokenizerFast." + ) from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast diff --git a/haystack/nodes/retriever/_openai_encoder.py b/haystack/nodes/retriever/_openai_encoder.py index e5063bef7..3569e3be0 100644 --- a/haystack/nodes/retriever/_openai_encoder.py +++ b/haystack/nodes/retriever/_openai_encoder.py @@ -3,6 +3,7 @@ import logging import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import platform import numpy as np import requests @@ -19,14 +20,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +machine = platform.machine() +system = platform.system() + USE_TIKTOKEN = False -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 8) and (machine in ["amd64", "x86_64"] or (machine == "arm64" and system == "Darwin")): USE_TIKTOKEN = True if USE_TIKTOKEN: import tiktoken # pylint: disable=import-error else: - logger.warning("OpenAI tiktoken module is not available for Python < 3.8. Falling back to GPT2TokenizerFast.") + logger.warning( + "OpenAI tiktoken module is not available for Python < 3.8, Linux ARM64 and AARCH64. Falling back to GPT2TokenizerFast." 
+ ) from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast diff --git a/pyproject.toml b/pyproject.toml index c8cbbc6cf..ec33231bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ dependencies = [ "elasticsearch>=7.7,<8", # OpenAI tokenizer - "tiktoken>=0.1.2; python_version >= '3.8'", + "tiktoken>=0.1.2; python_version >= '3.8' and (platform_machine == 'amd64' or platform_machine == 'x86_64' or (platform_machine == 'arm64' and platform_system == 'Darwin'))", # context matching "rapidfuzz>=2.0.15,<2.8.0", # FIXME https://github.com/deepset-ai/haystack/pull/3199