mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-14 12:31:10 +00:00

* Getting device_map working to support 8bit loading and multi device inference * Update to take account the device specified by the user * add release notes * Add device_map support for ExtractiveReader * Update test * Update to model that doesn't have issues * Update test * Update pytest approx * Update release notes * Start supporting device map * Update ExtractiveReader to use new ComponentDevice * Update similarity ranker to follow extractive reader implementation * Fixing pylint * Make mypy mostly happy * Add new unit test to test device_map * Adding unit tests * Some refactoring * Add more tests * Add more tests * Add another unit test * Update first_device property to return a ComponentDevice to be able to use the to methods * Updating tests for test_device * Update tests and now explicitly modify device_map in model_kwargs * Update haystack/utils/hf.py Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Make mypy happy * mypy * Remove unneeded optional flag * Update ExtractiveReader with new logic * Update ranker to follow new logic * Removing unneeded code * Make mypy happy * fxi pylint * Fix test * Adding unit tests for device_map="auto" * Add unit tests for ranker * PR comments * Make util method * Adding unit tests * Fix type annotation * Fix pylint * Fix test --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
23 lines
926 B
Python
23 lines
926 B
Python
import logging
|
|
from haystack.utils.hf import resolve_hf_device_map
|
|
from haystack.utils.device import ComponentDevice
|
|
|
|
|
|
def test_resolve_hf_device_map_only_device():
    """Check the `device_map` produced when no device and no `device_map` are given.

    With `device=None` and empty `model_kwargs`, the returned kwargs must carry
    a `device_map` equal to `ComponentDevice.resolve_device(None).to_hf()`.
    """
    resolved_kwargs = resolve_hf_device_map(device=None, model_kwargs={})
    expected_device_map = ComponentDevice.resolve_device(None).to_hf()
    assert resolved_kwargs["device_map"] == expected_device_map
|
|
|
|
|
|
def test_resolve_hf_device_map_only_device_map():
    """Check that a `device_map` passed via `model_kwargs` is kept when `device` is None."""
    given_kwargs = {"device_map": "cpu"}
    resolved_kwargs = resolve_hf_device_map(device=None, model_kwargs=given_kwargs)
    assert resolved_kwargs["device_map"] == "cpu"
|
|
|
|
|
|
def test_resolve_hf_device_map_device_and_device_map(caplog):
    """Check the outcome when both `device` and a `device_map` in `model_kwargs` are given.

    The captured log text must contain the "both provided" message, and the
    returned kwargs must still hold the `device_map` taken from `model_kwargs`.
    """
    expected_message = "The parameters `device` and `device_map` from `model_kwargs` are both provided."
    cpu_device = ComponentDevice.from_str("cpu")

    # Capture records at WARNING level while the resolver runs.
    with caplog.at_level(logging.WARNING):
        resolved_kwargs = resolve_hf_device_map(device=cpu_device, model_kwargs={"device_map": "cuda:0"})
        assert expected_message in caplog.text

    assert resolved_kwargs["device_map"] == "cuda:0"
|