From 8cfeed095db538d0b9cfa12c138998cb466e39ed Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 17 May 2023 21:31:08 +0200 Subject: [PATCH] build: Remove mmh3 dependency (#4896) * build: Remove mmh3 dependency * resolve circular import * pylint * make mmh3.py sibling of schema.py * pylint import order * pylint * undo example changes * increase coverage in modeling module * increase coverage further * rename new unit tests --- haystack/__init__.py | 2 +- haystack/mmh3.py | 344 ++++++++++++++++++++++++++++++++ haystack/schema.py | 6 +- haystack/utils/squad_data.py | 5 +- pyproject.toml | 1 - test/agents/test_agent.py | 16 +- test/modeling/test_processor.py | 14 +- test/others/test_squad_data.py | 34 +++- test/utils/test_mmh3.py | 10 + 9 files changed, 421 insertions(+), 11 deletions(-) create mode 100644 haystack/mmh3.py create mode 100644 test/utils/test_mmh3.py diff --git a/haystack/__init__.py b/haystack/__init__.py index b527cee87..8db3d8311 100644 --- a/haystack/__init__.py +++ b/haystack/__init__.py @@ -31,7 +31,6 @@ generalimport( "magic", "markdown", "mlflow", - "mmh3", "more_itertools", "networkx", "nltk", @@ -94,6 +93,7 @@ from haystack.schema import Document, Answer, Label, MultiLabel, Span, Evaluatio from haystack.nodes.base import BaseComponent from haystack.pipelines.base import Pipeline from haystack.environment import set_pytorch_secure_model_loading +from haystack.mmh3 import hash128 # Enables torch's secure model loading through setting an env var. 
"""Pure-Python implementation of the 128-bit MurmurHash3 variants.

Based on https://github.com/wc-duck/pymmh3/blob/master/pymmh3.py

Only Python 3 is supported: input bytes are decoded with ``int.from_bytes``.
"""
from typing import Union

_MASK32 = 0xFFFFFFFF
_MASK64 = 0xFFFFFFFFFFFFFFFF

# Multiplicative constants of the x64 variant.
_X64_C1 = 0x87C37B91114253D5
_X64_C2 = 0x4CF5AD432745937F

# Multiplicative constants of the x86 variant.
_X86_C1 = 0x239B961B
_X86_C2 = 0xAB0E9789
_X86_C3 = 0x38B34AE5
_X86_C4 = 0xA1E38B93


def xrange(a, b, c):
    """Return ``range(a, b, c)``; kept so existing imports of this name keep working."""
    return range(a, b, c)


def xencode(x):
    """Return ``x`` unchanged if it is ``bytes``/``bytearray``, otherwise ``x.encode()``."""
    if isinstance(x, (bytes, bytearray)):
        return x
    return x.encode()


def _rotl64(value, shift):
    """Rotate ``value`` left by ``shift`` bits and truncate the result to 64 bits."""
    return (value << shift | value >> (64 - shift)) & _MASK64


def _rotl32(value, shift):
    """Rotate ``value`` left by ``shift`` bits and truncate the result to 32 bits."""
    return (value << shift | value >> (32 - shift)) & _MASK32


def _mix64(k, mul_in, shift, mul_out):
    """Scramble one 64-bit input lane: multiply, rotate left, multiply."""
    k = (k * mul_in) & _MASK64
    k = _rotl64(k, shift)
    return (k * mul_out) & _MASK64


def _mix32(k, mul_in, shift, mul_out):
    """Scramble one 32-bit input lane: multiply, rotate left, multiply."""
    k = (k * mul_in) & _MASK32
    k = _rotl32(k, shift)
    return (k * mul_out) & _MASK32


def _fmix64(k):
    """Apply the 64-bit finalization mix to ``k``."""
    k ^= k >> 33
    k = (k * 0xFF51AFD7ED558CCD) & _MASK64
    k ^= k >> 33
    k = (k * 0xC4CEB9FE1A85EC53) & _MASK64
    k ^= k >> 33
    return k


def _fmix32(h):
    """Apply the 32-bit finalization mix to ``h``."""
    h ^= h >> 16
    h = (h * 0x85EBCA6B) & _MASK32
    h ^= h >> 13
    h = (h * 0xC2B2AE35) & _MASK32
    h ^= h >> 16
    return h


def _hash128_x64(key, seed):
    """Compute the x64 variant of the 128-bit murmur3 hash of the bytearray ``key``."""
    length = len(key)
    body_len = (length // 16) * 16  # floor division: no round trip through float

    h1 = seed
    h2 = seed

    # body: every 16-byte block is read as two little-endian 64-bit lanes
    for offset in range(0, body_len, 16):
        k1 = int.from_bytes(key[offset : offset + 8], "little")
        k2 = int.from_bytes(key[offset + 8 : offset + 16], "little")

        h1 ^= _mix64(k1, _X64_C1, 31, _X64_C2)
        h1 = _rotl64(h1, 27)
        h1 = (h1 + h2) & _MASK64
        h1 = (h1 * 5 + 0x52DCE729) & _MASK64

        h2 ^= _mix64(k2, _X64_C2, 33, _X64_C1)
        h2 = _rotl64(h2, 31)
        h2 = (h1 + h2) & _MASK64
        h2 = (h2 * 5 + 0x38495AB5) & _MASK64

    # tail: the remaining 0..15 bytes; the upper lane is folded in before the lower one
    tail = key[body_len:]
    tail_size = len(tail)

    if tail_size > 8:
        h2 ^= _mix64(int.from_bytes(tail[8:], "little"), _X64_C2, 33, _X64_C1)
    if tail_size > 0:
        h1 ^= _mix64(int.from_bytes(tail[:8], "little"), _X64_C1, 31, _X64_C2)

    # finalization
    h1 ^= length
    h2 ^= length

    h1 = (h1 + h2) & _MASK64
    h2 = (h1 + h2) & _MASK64

    h1 = _fmix64(h1)
    h2 = _fmix64(h2)

    h1 = (h1 + h2) & _MASK64
    h2 = (h1 + h2) & _MASK64

    return h2 << 64 | h1


def _hash128_x86(key, seed):
    """Compute the x86 variant of the 128-bit murmur3 hash of the bytearray ``key``."""
    length = len(key)
    body_len = (length // 16) * 16  # floor division: no round trip through float

    h1 = seed
    h2 = seed
    h3 = seed
    h4 = seed

    # body: every 16-byte block is read as four little-endian 32-bit lanes
    for offset in range(0, body_len, 16):
        k1 = int.from_bytes(key[offset : offset + 4], "little")
        k2 = int.from_bytes(key[offset + 4 : offset + 8], "little")
        k3 = int.from_bytes(key[offset + 8 : offset + 12], "little")
        k4 = int.from_bytes(key[offset + 12 : offset + 16], "little")

        h1 ^= _mix32(k1, _X86_C1, 15, _X86_C2)
        h1 = _rotl32(h1, 19)
        h1 = (h1 + h2) & _MASK32
        h1 = (h1 * 5 + 0x561CCD1B) & _MASK32

        h2 ^= _mix32(k2, _X86_C2, 16, _X86_C3)
        h2 = _rotl32(h2, 17)
        h2 = (h2 + h3) & _MASK32
        h2 = (h2 * 5 + 0x0BCAA747) & _MASK32

        h3 ^= _mix32(k3, _X86_C3, 17, _X86_C4)
        h3 = _rotl32(h3, 15)
        h3 = (h3 + h4) & _MASK32
        h3 = (h3 * 5 + 0x96CD1C35) & _MASK32

        h4 ^= _mix32(k4, _X86_C4, 18, _X86_C1)
        h4 = _rotl32(h4, 13)
        h4 = (h1 + h4) & _MASK32
        h4 = (h4 * 5 + 0x32AC3B17) & _MASK32

    # tail: the remaining 0..15 bytes, folded in from the highest lane down
    tail = key[body_len:]
    tail_size = len(tail)

    if tail_size > 12:
        h4 ^= _mix32(int.from_bytes(tail[12:], "little"), _X86_C4, 18, _X86_C1)
    if tail_size > 8:
        h3 ^= _mix32(int.from_bytes(tail[8:12], "little"), _X86_C3, 17, _X86_C4)
    if tail_size > 4:
        h2 ^= _mix32(int.from_bytes(tail[4:8], "little"), _X86_C2, 16, _X86_C3)
    if tail_size > 0:
        h1 ^= _mix32(int.from_bytes(tail[:4], "little"), _X86_C1, 15, _X86_C2)

    # finalization
    h1 ^= length
    h2 ^= length
    h3 ^= length
    h4 ^= length

    h1 = (h1 + h2) & _MASK32
    h1 = (h1 + h3) & _MASK32
    h1 = (h1 + h4) & _MASK32
    h2 = (h1 + h2) & _MASK32
    h3 = (h1 + h3) & _MASK32
    h4 = (h1 + h4) & _MASK32

    h1 = _fmix32(h1)
    h2 = _fmix32(h2)
    h3 = _fmix32(h3)
    h4 = _fmix32(h4)

    h1 = (h1 + h2) & _MASK32
    h1 = (h1 + h3) & _MASK32
    h1 = (h1 + h4) & _MASK32
    h2 = (h1 + h2) & _MASK32
    h3 = (h1 + h3) & _MASK32
    h4 = (h1 + h4) & _MASK32

    return h4 << 96 | h3 << 64 | h2 << 32 | h1


def hash128(key: Union[str, bytes, bytearray], seed: int = 0x0, x64arch: bool = True) -> int:
    """Implements 128bit murmur3 hash.

    :param key: Data to hash. ``bytes``/``bytearray`` are hashed as-is; anything
        else is converted by calling its ``.encode()`` method.
    :param seed: Initial value of the internal state words. It is used as given
        (not masked or validated).
    :param x64arch: If ``True`` (default) use the x64 variant (two 64-bit state
        words), otherwise the x86 variant (four 32-bit state words). The two
        variants are different hash functions.
    :return: The hash as a non-negative integer below ``2**128`` (for seeds that
        fit the state word width).

    >>> hash128("")
    0
    """
    data = bytearray(xencode(key))

    if x64arch:
        return _hash128_x64(data, seed)
    return _hash128_x86(data, seed)
) - return "{:02x}".format(mmh3.hash128(final_hash_key, signed=False)) + return "{:02x}".format(hash128(final_hash_key)) def to_dict(self, field_map: Optional[Dict[str, Any]] = None) -> Dict: """ diff --git a/haystack/utils/squad_data.py b/haystack/utils/squad_data.py index 80ee9ac45..037be09e8 100644 --- a/haystack/utils/squad_data.py +++ b/haystack/utils/squad_data.py @@ -5,13 +5,12 @@ import json import random import pandas as pd from tqdm.auto import tqdm -import mmh3 from haystack import is_imported +from haystack.mmh3 import hash128 from haystack.schema import Document, Label, Answer from haystack.modeling.data_handler.processor import _read_squad_file - logger = logging.getLogger(__name__) @@ -112,7 +111,7 @@ class SquadData: title = document.get("title", "") for paragraph in document["paragraphs"]: context = paragraph["context"] - document_id = paragraph.get("document_id", "{:02x}".format(mmh3.hash128(str(context), signed=False))) + document_id = paragraph.get("document_id", "{:02x}".format(hash128(str(context)))) for question in paragraph["qas"]: q = question["question"] id = question["id"] diff --git a/pyproject.toml b/pyproject.toml index 61e896916..7518afb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ dependencies = [ "dill", # pickle extension for (de-)serialization "tqdm", # progress bars in model download and training scripts "networkx", # graphs library - "mmh3", # fast hashing function (murmurhash3) "quantulum3", # quantities extraction from text "posthog", # telemetry "azure-ai-formrecognizer>=3.2.0b2", # forms reader diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index f1b1a4825..7bdaff5ba 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -7,7 +7,7 @@ from test.conftest import MockRetriever, MockPromptNode from unittest import mock import pytest -from haystack import BaseComponent, Answer +from haystack import BaseComponent, Answer, Document from haystack.agents import Agent, 
AgentStep from haystack.agents.base import Tool, ToolsManager from haystack.nodes import PromptModel, PromptNode, PromptTemplate @@ -276,6 +276,20 @@ def test_update_hash(): assert agent.hash == "5ac8eca2f92c9545adcce3682b80d4c5" +@pytest.mark.unit +def test_tool_fails_processing_dict_result(): + tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description") + with pytest.raises(ValueError): + tool._process_result({"answer": "answer"}) + + +@pytest.mark.unit +def test_tool_processes_answer_result_and_document_result(): + tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description") + assert tool._process_result(Answer(answer="answer")) == "answer" + assert tool._process_result(Document(content="content")) == "content" + + def test_invalid_agent_template(): pn = PromptNode() with pytest.raises(ValueError, match="some_non_existing_template not supported"): diff --git a/test/modeling/test_processor.py b/test/modeling/test_processor.py index 2f053fefc..9a45de953 100644 --- a/test/modeling/test_processor.py +++ b/test/modeling/test_processor.py @@ -1,10 +1,11 @@ import copy import logging +from pathlib import Path import pytest from transformers import AutoTokenizer -from haystack.modeling.data_handler.processor import SquadProcessor +from haystack.modeling.data_handler.processor import SquadProcessor, _is_json # during inference (parameter return_baskets = False) we do not convert labels @@ -300,6 +301,17 @@ def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None): ], f"Processing labels for {model} has changed." 
+@pytest.mark.unit +def test_is_json_identifies_json_objects(): + """Test that _is_json correctly identifies json objects""" + # Paths to json files should be considered json + assert _is_json(Path("processor_config.json")) + # dicts should be considered json + assert _is_json({"a": 1}) + # non-serializable objects should not be considered json + assert not _is_json(AutoTokenizer) + + @pytest.mark.integration def test_dataset_from_dicts_auto_determine_max_answers(samples_path, caplog=None): """ diff --git a/test/others/test_squad_data.py b/test/others/test_squad_data.py index 7aa156458..216c9ff83 100644 --- a/test/others/test_squad_data.py +++ b/test/others/test_squad_data.py @@ -1,4 +1,6 @@ import pandas as pd +import pytest + from haystack.utils.squad_data import SquadData from haystack.utils.augment_squad import augment_squad from haystack.schema import Document, Label, Answer @@ -22,7 +24,8 @@ def test_squad_augmentation(samples_path): assert original_squad.count(unit="paragraph") == augmented_squad.count(unit="paragraph") * multiplication_factor -def test_squad_to_df(): +@pytest.mark.unit +def test_squad_data_converts_df_to_data(): df = pd.DataFrame( [["title", "context", "question", "id", "answer", 1, False]], columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible"], @@ -51,6 +54,35 @@ def test_squad_to_df(): assert result == expected_result +@pytest.mark.unit +def test_squad_data_converts_data_to_df(): + data = [ + { + "title": "title", + "paragraphs": [ + { + "context": "context", + "document_id": "document_id", + "qas": [ + { + "question": "question", + "id": "id", + "answers": [{"text": "answer", "answer_start": 1}], + "is_impossible": False, + } + ], + } + ], + } + ] + expected_result = pd.DataFrame( + [["title", "context", "question", "id", "answer", 1, False, "document_id"]], + columns=["title", "context", "question", "id", "answer_text", "answer_start", "is_impossible", "document_id"], + ) + result = 
SquadData.to_df(data) + assert result.equals(expected_result) + + def test_to_label_object(): squad_data_list = [ { diff --git a/test/utils/test_mmh3.py b/test/utils/test_mmh3.py new file mode 100644 index 000000000..127b1161a --- /dev/null +++ b/test/utils/test_mmh3.py @@ -0,0 +1,10 @@ +import pytest + +from haystack.mmh3 import hash128 + + +@pytest.mark.unit +def test_mmh3(): + content = "This is the document text" * 100 + hashed_content = hash128(content) + assert hashed_content == 305042678480070366459393623793278501577