haystack/test/conftest.py
Stefano Fiorucci 188b2a7f06
feat: support for tools in OpenAIChatGenerator (#8666)
* move chatmsg>openai conversion to chatmsg dataclass

* implementation and tests cleanup

* release note

* try fixing azure chat generator

* add serde test for toolinvoker

* small fix
2024-12-20 14:20:54 +00:00

65 lines
1.8 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from pathlib import Path
from test.tracing.utils import SpyingTracer
from typing import Generator
from unittest.mock import Mock, patch
import pytest
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from haystack import tracing
from haystack.testing.test_utils import set_all_seeds
# Fix the RNG seeds once, when pytest imports this conftest, so test runs are repeatable
# (which generators get seeded is decided by haystack.testing.test_utils.set_all_seeds).
set_all_seeds(0)
# Tracing is disabled by default to avoid failures in CI; the `spying_tracer` fixture
# turns it on only for the tests that request it.
tracing.disable_tracing()
@pytest.fixture()
def mock_tokenizer() -> Mock:
    """
    Provide a minimal stand-in for a tokenizer.

    ``encode`` turns text into tokens with ``str.split()`` (no argument), so any
    run of whitespace separates tokens and leading/trailing whitespace is dropped.
    ``decode`` rebuilds a string by joining the tokens with single spaces.
    """

    def _encode(text):
        return text.split()

    def _decode(tokens):
        return " ".join(tokens)

    # Keyword arguments to Mock become plain attributes, exactly as if assigned afterwards.
    return Mock(encode=_encode, decode=_decode)
@pytest.fixture()
def test_files_path() -> Path:
    """
    Return the path of the ``test_files`` directory that sits next to this conftest module.
    """
    conftest_dir = Path(__file__).parent
    return conftest_dir.joinpath("test_files")
@pytest.fixture(autouse=True)
def request_blocker(request: pytest.FixtureRequest, monkeypatch):
    """
    Prevent accidental outgoing HTTP requests from non-integration tests.

    Being ``autouse``, this fixture runs for every test. Tests that carry the
    ``integration`` marker (found via ``get_closest_marker``, i.e. on the test or
    an enclosing node) are left alone. For every other test,
    ``urllib3.connectionpool.HTTPConnectionPool.urlopen`` is monkeypatched with a
    function that raises ``RuntimeError`` naming the request that was about to be
    made, so anything that sends its traffic through urllib3 (``requests``, for
    example) fails fast instead of reaching the network.
    """
    # NOTE(review): only urllib3 is patched; HTTP clients that do not use urllib3
    # (e.g. httpx) are not intercepted here — confirm that is intended.
    is_integration_test = request.node.get_closest_marker("integration") is not None
    if not is_integration_test:

        def _refuse_urlopen(self, method, url, *args, **kwargs):
            # Report the method and target instead of performing any network I/O.
            raise RuntimeError(f"The test was about to {method} {self.scheme}://{self.host}{url}")

        monkeypatch.setattr("urllib3.connectionpool.HTTPConnectionPool.urlopen", _refuse_urlopen)
@pytest.fixture()
def spying_tracer() -> Generator[SpyingTracer, None, None]:
    """
    Enable tracing with a fresh ``SpyingTracer`` for the duration of a single test.

    The spy is installed with ``tracing.enable_tracing`` and yielded to the test.
    After the test, tracing is switched off again.
    """
    spy = SpyingTracer()
    tracing.enable_tracing(spy)
    yield spy
    # Teardown: disable tracing so this spy does not affect other tests.
    tracing.disable_tracing()