mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-24 09:20:13 +00:00
26 lines
1.0 KiB
Python
26 lines
1.0 KiB
Python
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
import logging
|
|
from haystack.utils.hf import resolve_hf_device_map
|
|
from haystack.utils.device import ComponentDevice
|
|
|
|
|
|
def test_resolve_hf_device_map_only_device():
    """
    Check the `device_map` produced when no device and no `device_map` are given.

    With `device=None` and empty `model_kwargs`, the returned kwargs must carry a
    `device_map` equal to the Hugging Face form of `ComponentDevice.resolve_device(None)`.
    """
    resolved_kwargs = resolve_hf_device_map(device=None, model_kwargs={})

    # The reference value is computed after the call, mirroring the original evaluation order.
    expected_device_map = ComponentDevice.resolve_device(None).to_hf()
    assert resolved_kwargs["device_map"] == expected_device_map
|
|
|
|
|
|
def test_resolve_hf_device_map_only_device_map():
    """
    Check that a `device_map` supplied through `model_kwargs` is kept as-is.

    No explicit `device` is passed, so the returned `device_map` must still be "cpu".
    """
    user_kwargs = {"device_map": "cpu"}

    resolved_kwargs = resolve_hf_device_map(device=None, model_kwargs=user_kwargs)

    assert resolved_kwargs["device_map"] == "cpu"
|
|
|
|
|
|
def test_resolve_hf_device_map_device_and_device_map(caplog):
    """
    Check the behaviour when both `device` and a `device_map` in `model_kwargs` are given.

    The call is expected to log a warning mentioning that both parameters were
    provided, and the returned `device_map` must be the one from `model_kwargs`.
    """
    expected_warning = "The parameters `device` and `device_map` from `model_kwargs` are both provided."
    cpu_device = ComponentDevice.from_str("cpu")
    user_kwargs = {"device_map": "cuda:0"}

    # Capture records at WARNING level for the duration of the call.
    with caplog.at_level(logging.WARNING):
        resolved_kwargs = resolve_hf_device_map(device=cpu_device, model_kwargs=user_kwargs)
        assert expected_warning in caplog.text

    assert resolved_kwargs["device_map"] == "cuda:0"
|