mirror of
https://github.com/Azure-Samples/graphrag-accelerator.git
synced 2025-06-27 04:39:57 +00:00
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from src.logger.blob_workflow_callbacks import BlobWorkflowCallbacks
|
|
|
|
|
|
@pytest.fixture
def mock_blob_service_client():
    """Swap ``BlobServiceClient`` in the module under test for a mock.

    The name ``BlobServiceClient`` as seen by
    ``src.logger.blob_workflow_callbacks`` is replaced for the lifetime of the
    fixture, and the replacement mock is yielded to the requesting test or
    fixture. The original object is restored at teardown.
    """
    patcher = patch("src.logger.blob_workflow_callbacks.BlobServiceClient")
    mocked_client_cls = patcher.start()
    try:
        yield mocked_client_cls
    finally:
        # Undo the patch even if the consuming test raised.
        patcher.stop()
|
|
|
|
|
|
@pytest.fixture
def workflow_callbacks(mock_blob_service_client):
    """Yield a ``BlobWorkflowCallbacks`` whose real constructor never ran.

    ``BlobWorkflowCallbacks.__init__`` is patched to a mock returning ``None``,
    so instantiation skips the real constructor (not visible here). The
    private attributes the callbacks use are then assigned by hand, with
    ``_blob_service_client`` set to the mocked ``BlobServiceClient`` from the
    ``mock_blob_service_client`` fixture.

    Because the instance is yielded inside the ``with`` block, the
    ``__init__`` patch stays in effect until the consuming test finishes.
    """
    preset_state = {
        "_blob_service_client": mock_blob_service_client,
        "_index_name": "mock_index_name",
        "_container_name": "logs",
        "_blob_name": "logs/logs.txt",
        "_num_workflow_steps": 4,
        "_processed_workflow_steps": [],
        "_workflow_name": "",
    }
    with patch(
        "src.logger.blob_workflow_callbacks.BlobWorkflowCallbacks.__init__",
        return_value=None,
    ):
        callbacks = BlobWorkflowCallbacks()
        for attr_name, attr_value in preset_state.items():
            setattr(callbacks, attr_name, attr_value)
        yield callbacks
|
|
|
|
|
|
def test_on_workflow_start(workflow_callbacks):
    """``on_workflow_start`` should trigger at least one ``append_block`` call.

    The blob client is the mock returned by the mocked service client's
    ``get_blob_client``, so the test checks that the callback wrote to it.
    """
    workflow_callbacks.on_workflow_start("test_workflow", object())

    # Read the mock's shared ``return_value`` rather than calling
    # ``get_blob_client()`` here: calling it would record an extra call on the
    # mock. ``assert_called`` also reports a descriptive failure message.
    service_client = workflow_callbacks._blob_service_client
    service_client.get_blob_client.return_value.append_block.assert_called()
|
|
|
|
|
|
def test_on_workflow_end(workflow_callbacks):
    """``on_workflow_end`` should trigger at least one ``append_block`` call."""
    workflow_callbacks.on_workflow_end("test_workflow", object())

    # Inspect ``return_value`` instead of calling ``get_blob_client()``, which
    # would add a spurious call to the mock's call record.
    service_client = workflow_callbacks._blob_service_client
    service_client.get_blob_client.return_value.append_block.assert_called()
|
|
|
|
|
|
# def test_on_workflow_step_start(workflow_callbacks):
|
|
# workflow_callbacks.on_workflow_step_start("test_step", object())
|
|
# assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called
|
|
|
|
|
|
# def test_on_workflow_step_end(workflow_callbacks):
|
|
# workflow_callbacks.on_workflow_step_end("test_step", object())
|
|
# assert workflow_callbacks._blob_service_client.get_blob_client().append_block.called
|
|
|
|
|
|
def test_on_error(workflow_callbacks):
    """``on_error`` should trigger at least one ``append_block`` call."""
    workflow_callbacks.on_error("test_error", Exception("test_exception"))

    # Inspect ``return_value`` instead of calling ``get_blob_client()``, which
    # would add a spurious call to the mock's call record.
    service_client = workflow_callbacks._blob_service_client
    service_client.get_blob_client.return_value.append_block.assert_called()
|