haystack/test/utils/test_hf.py
Sebastian Husch Lee ceda4cd655
feat: Add support for device_map (#6679)
* Getting device_map working to support 8bit loading and multi device inference

* Update to take into account the device specified by the user

* add release notes

* Add device_map support for ExtractiveReader

* Update test

* Update to model that doesn't have issues

* Update test

* Update pytest approx

* Update release notes

* Start supporting device map

* Update ExtractiveReader to use new ComponentDevice

* Update similarity ranker to follow extractive reader implementation

* Fixing pylint

* Make mypy mostly happy

* Add new unit test to test device_map

* Adding unit tests

* Some refactoring

* Add more tests

* Add more tests

* Add another unit test

* Update first_device property to return a ComponentDevice to be able to use the to methods

* Updating tests for test_device

* Update tests and now explicitly modify device_map in model_kwargs

* Update haystack/utils/hf.py

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Make mypy happy

* mypy

* Remove unneeded optional flag

* Update ExtractiveReader with new logic

* Update ranker to follow new logic

* Removing unneeded code

* Make mypy happy

* fix pylint

* Fix test

* Adding unit tests for device_map="auto"

* Add unit tests for ranker

* PR comments

* Make util method

* Adding unit tests

* Fix type annotation

* Fix pylint

* Fix test

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
2024-01-30 13:47:57 +01:00


```python
import logging
from haystack.utils.hf import resolve_hf_device_map
from haystack.utils.device import ComponentDevice


def test_resolve_hf_device_map_only_device():
    # Neither `device` nor `device_map` is given: the default device is resolved
    # and converted to its Hugging Face representation.
    model_kwargs = resolve_hf_device_map(device=None, model_kwargs={})
    assert model_kwargs["device_map"] == ComponentDevice.resolve_device(None).to_hf()


def test_resolve_hf_device_map_only_device_map():
    # A `device_map` already present in `model_kwargs` is kept as-is.
    model_kwargs = resolve_hf_device_map(device=None, model_kwargs={"device_map": "cpu"})
    assert model_kwargs["device_map"] == "cpu"


def test_resolve_hf_device_map_device_and_device_map(caplog):
    # Both are given: a warning is logged and `device_map` takes precedence over `device`.
    with caplog.at_level(logging.WARNING):
        model_kwargs = resolve_hf_device_map(
            device=ComponentDevice.from_str("cpu"), model_kwargs={"device_map": "cuda:0"}
        )
        assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." in caplog.text
    assert model_kwargs["device_map"] == "cuda:0"
```
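Taken together, the three tests pin down a precedence rule: an explicit `device_map` in `model_kwargs` wins (with a warning if `device` was also passed), and otherwise the device is resolved and written into `device_map`. The following is a minimal, dependency-free sketch of that rule, not haystack's actual implementation of `resolve_hf_device_map`. The function `resolve_device_map_sketch` and its `default_device` parameter are hypothetical; devices are plain strings instead of `ComponentDevice`, and `default_device` stands in for `ComponentDevice.resolve_device(None).to_hf()`.

```python
import copy
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


def resolve_device_map_sketch(
    device: Optional[str],
    model_kwargs: Optional[Dict[str, Any]],
    default_device: str = "cpu",
) -> Dict[str, Any]:
    """Hypothetical stand-in for haystack's resolve_hf_device_map.

    `device` is a plain string (e.g. "cuda:0") rather than a ComponentDevice,
    and `default_device` replaces the real device-resolution logic.
    """
    # Work on a shallow copy so the caller's dict is not mutated
    # (a choice made for this sketch, not necessarily haystack's behavior).
    resolved = copy.copy(model_kwargs) if model_kwargs else {}

    if resolved.get("device_map"):
        # A non-empty device_map is already present: keep it, warn if device was also set.
        if device is not None:
            logger.warning(
                "The parameters `device` and `device_map` from `model_kwargs` are both provided. "
                "Using `device_map`."
            )
    else:
        # No device_map: use the given device, or fall back to the default one.
        resolved["device_map"] = device if device is not None else default_device
    return resolved


if __name__ == "__main__":
    print(resolve_device_map_sketch(None, {}))
    print(resolve_device_map_sketch("cpu", {"device_map": "cuda:0"}))
```

The first call prints `{'device_map': 'cpu'}` and the second prints `{'device_map': 'cuda:0'}` after logging the warning, mirroring the three tests above. Giving `device_map` priority matters because it is the parameter that enables 8-bit loading and multi-device inference, so silently overriding it with a single `device` would undo those settings.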