From cbf44413d81b0acb87cddd17744e09f9f3eebc74 Mon Sep 17 00:00:00 2001 From: Sara Zan Date: Fri, 21 Oct 2022 13:58:17 +0200 Subject: [PATCH] feat: add `__cointains__` to `Span` (#3446) * add __contains__ * add tests --- haystack/schema.py | 41 ++++++++++++++++++++++++++++++++++++++ test/others/test_schema.py | 29 +++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/haystack/schema.py b/haystack/schema.py index 4349a9283..71cbc35ae 100644 --- a/haystack/schema.py +++ b/haystack/schema.py @@ -308,6 +308,47 @@ class Span: :param end: Position where the spand ends """ + def __contains__(self, value): + """ + Checks for inclusion of the given value into the interval defined by Span. + ``` + assert 10 in Span(5, 15) # True + assert 20 in Span(1, 15) # False + ``` + Includes the left edge, but not the right edge. + ``` + assert 5 in Span(5, 15) # True + assert 15 in Span(5, 15) # False + ``` + Works for numbers and all values that can be safely converted into floats. + ``` + assert 10.0 in Span(5, 15) # True + assert "10" in Span(5, 15) # True + ``` + It also works for Span objects, returning True only if the given + Span is fully contained into the original Span. + As for numerical values, the left edge is included, the right edge is not. + ``` + assert Span(10, 11) in Span(5, 15) # True + assert Span(5, 10) in Span(5, 15) # True + assert Span(10, 15) in Span(5, 15) # False + assert Span(5, 15) in Span(5, 15) # False + assert Span(5, 14) in Span(5, 15) # True + assert Span(0, 1) in Span(5, 15) # False + assert Span(0, 10) in Span(5, 15) # False + assert Span(10, 20) in Span(5, 15) # False + ``` + """ + if isinstance(value, Span): + return self.start <= value.start and self.end > value.end + try: + value = float(value) + return self.start <= value < self.end + except Exception as e: + raise ValueError( + f"Cannot use 'in' with a value of type {type(value)}. Use numeric values or Span objects." + ) from e + @dataclass class Answer: diff --git a/test/others/test_schema.py b/test/others/test_schema.py index 8de117061..987eb83dc 100644 --- a/test/others/test_schema.py +++ b/test/others/test_schema.py @@ -467,3 +467,32 @@ def test_deserialize_speech_answer(): context_audio=SAMPLES_PATH / "audio" / "the context for this answer is here.wav", ) assert speech_answer == SpeechAnswer.from_dict(speech_answer.to_dict()) + + +def test_span_in(): + assert 10 in Span(5, 15) + assert not 20 in Span(1, 15) + + +def test_span_in_edges(): + assert 5 in Span(5, 15) + assert not 15 in Span(5, 15) + + +def test_span_in_other_values(): + assert 10.0 in Span(5, 15) + assert "10" in Span(5, 15) + with pytest.raises(ValueError): + "hello" in Span(5, 15) + + +def test_assert_span_vs_span(): + assert Span(10, 11) in Span(5, 15) + assert Span(5, 10) in Span(5, 15) + assert not Span(10, 15) in Span(5, 15) + assert not Span(5, 15) in Span(5, 15) + assert Span(5, 14) in Span(5, 15) + + assert not Span(0, 1) in Span(5, 15) + assert not Span(0, 10) in Span(5, 15) + assert not Span(10, 20) in Span(5, 15)