# Mirror of https://github.com/deepset-ai/haystack.git
# Synced 2025-07-03 07:04:01 +00:00 · 29 lines · 905 B · Python
from unittest.mock import patch
|
||
|
|
||
|
from haystack.utils import get_device
|
||
|
|
||
|
|
||
|
@patch("torch.cuda.is_available", return_value=True)
def test_get_device_cuda(torch_cuda_is_available):
    """get_device() should return the first CUDA device when CUDA is reported available.

    ``torch.cuda.is_available`` is replaced by a mock preconfigured (via the
    ``return_value`` keyword of ``@patch``) to return ``True``; the mock itself
    is injected as ``torch_cuda_is_available``.
    """
    expected = "cuda:0"
    assert get_device() == expected
|
||
|
|
||
|
|
||
|
@patch("torch.backends.mps.is_available", return_value=True)
@patch("torch.cuda.is_available", return_value=False)
def test_get_device_mps(torch_cuda_is_available, torch_backends_mps_is_available):
    """get_device() should return the MPS device when only MPS is reported available.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the first
    argument is the ``torch.cuda.is_available`` mock (returns ``False``) and the
    second is the ``torch.backends.mps.is_available`` mock (returns ``True``).
    """
    expected = "mps:0"
    assert get_device() == expected
|
||
|
|
||
|
|
||
|
@patch("torch.backends.mps.is_available", return_value=False)
@patch("torch.cuda.is_available", return_value=False)
def test_get_device_cpu(torch_cuda_is_available, torch_backends_mps_is_available):
    """get_device() should return the CPU when neither CUDA nor MPS is reported available.

    Both availability checks are replaced by mocks returning ``False``; as with
    any stack of ``@patch`` decorators, the bottom one supplies the first argument.
    """
    expected = "cpu:0"
    assert get_device() == expected
|