haystack/test/utils/test_device.py
Sebastian Husch Lee c294b8ac8c
feat: Add auto device checks and model_kwargs to TransformersSimilarityRanker (#6561)
* Add device checking and model_kwargs like we do in ExtractiveReader

* Add release notes

* Make a utility function for the device checking

* Better warning message and updated ExtractiveReader to use the util function

* Add unit tests for get_device

* Fix pylint
2023-12-18 15:13:42 +01:00

29 lines
905 B
Python

from unittest.mock import patch
from haystack.utils import get_device
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_cuda(torch_cuda_is_available, torch_backends_mps_is_available):
    """CUDA is selected when available, even if MPS is also reported available.

    Both backends are mocked so the outcome does not depend on the hardware of
    the machine running the test (e.g. a real MPS device on Apple silicon).
    The mocks are injected bottom-up: the CUDA mock is the first argument.
    """
    torch_cuda_is_available.return_value = True
    # Pin MPS to True as well so this test also asserts CUDA takes precedence.
    torch_backends_mps_is_available.return_value = True
    device = get_device()
    assert device == "cuda:0"
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_mps(torch_cuda_is_available, torch_backends_mps_is_available):
    """Without CUDA but with MPS available, the first MPS device is chosen."""
    # Decorators apply bottom-up, so the CUDA mock arrives first.
    torch_backends_mps_is_available.return_value = True
    torch_cuda_is_available.return_value = False

    assert get_device() == "mps:0"
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_cpu(torch_cuda_is_available, torch_backends_mps_is_available):
    """With neither CUDA nor MPS available, the function falls back to the CPU."""
    # Decorators apply bottom-up, so the CUDA mock arrives first.
    torch_backends_mps_is_available.return_value = False
    torch_cuda_is_available.return_value = False

    assert get_device() == "cpu:0"