# Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-28 10:26:27 +00:00)
#
# Upstream change summary:
# * Add device checking and model_kwargs like we do in ExtractiveReader
# * Add release notes
# * Make a utility function for the device checking
# * Better warning message and updated ExtractiveReader to use the util function
# * Add unit tests for get_device
# * Fix pylint
from unittest.mock import patch

from haystack.utils import get_device

@patch("torch.cuda.is_available")
def test_get_device_cuda(torch_cuda_is_available):
    """get_device() returns the first CUDA device when torch reports CUDA as available."""
    torch_cuda_is_available.return_value = True

    device = get_device()

    assert device == "cuda:0"
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_mps(torch_cuda_is_available, torch_backends_mps_is_available):
    """get_device() falls back to the MPS device when CUDA is unavailable but MPS is."""
    # Mocks are injected bottom-up: the innermost @patch (cuda) is the first argument.
    torch_cuda_is_available.return_value = False
    torch_backends_mps_is_available.return_value = True

    device = get_device()

    assert device == "mps:0"
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_cpu(torch_cuda_is_available, torch_backends_mps_is_available):
    """get_device() falls back to the CPU device when neither CUDA nor MPS is available."""
    # Mocks are injected bottom-up: the innermost @patch (cuda) is the first argument.
    torch_cuda_is_available.return_value = False
    torch_backends_mps_is_available.return_value = False

    device = get_device()

    assert device == "cpu:0"