diff --git a/haystack/components/audio/whisper_local.py b/haystack/components/audio/whisper_local.py index 920a94732..2aca04daa 100644 --- a/haystack/components/audio/whisper_local.py +++ b/haystack/components/audio/whisper_local.py @@ -4,15 +4,12 @@ import logging import tempfile from pathlib import Path -from haystack import component, Document, default_to_dict, ComponentError +from haystack import component, Document, default_to_dict, ComponentError, default_from_dict from haystack.dataclasses import ByteStream from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice -with LazyImport( - "Run 'pip install transformers[torch]' to install torch and " - "'pip install \"openai-whisper>=20231106\"' to install whisper." -) as whisper_import: - import torch +with LazyImport("Run 'pip install \"openai-whisper>=20231106\"' to install whisper.") as whisper_import: import whisper @@ -33,14 +30,14 @@ class LocalWhisperTranscriber: def __init__( self, model: WhisperLocalModel = "large", - device: Optional[str] = None, + device: Optional[ComponentDevice] = None, whisper_params: Optional[Dict[str, Any]] = None, ): """ :param model: Name of the model to use. Set it to one of the following values: :type model: Literal["tiny", "small", "medium", "large", "large-v2"] - :param device: Name of the torch device to use for inference. If None, CPU is used. - :type device: Optional[str] + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. """ whisper_import.check() if model not in get_args(WhisperLocalModel): @@ -49,7 +46,7 @@ class LocalWhisperTranscriber: ) self.model = model self.whisper_params = whisper_params or {} - self.device = torch.device(device) if device else torch.device("cpu") + self.device = ComponentDevice.resolve_device(device) self._model = None def warm_up(self) -> None: @@ -57,13 +54,23 @@ class LocalWhisperTranscriber: Loads the model. 
def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this component to a dictionary.

    The device is stored in its dictionary form (``self.device.to_dict()``) so that
    `from_dict` can rebuild it.

    :returns: The serialized component, as produced by `default_to_dict`.
    """
    return default_to_dict(
        self,
        model=self.model,
        device=self.device.to_dict(),
        whisper_params=self.whisper_params,
    )

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
    """
    Create a `LocalWhisperTranscriber` instance from a dictionary.

    :param data: Dictionary in the layout produced by `to_dict`. The ``device`` init
        parameter is optional: when it is missing or ``None`` it is left untouched and
        the constructor resolves the device itself.
    :returns: The deserialized component.
    """
    # Work on a shallow copy: the caller's dictionary must keep its serialized
    # (plain-data) form instead of having a live ComponentDevice written into it.
    init_params = dict(data.get("init_parameters", {}))

    # `device` is an optional init parameter, so payloads that omit it (or store
    # None) are valid and must not raise a KeyError here.
    serialized_device = init_params.get("device")
    if serialized_device is not None:
        init_params["device"] = ComponentDevice.from_dict(serialized_device)

    return default_from_dict(cls, {**data, "init_parameters": init_params})
+ + Before this change: + ```python + from haystack.components.audio import LocalWhisperTranscriber + + transcriber = LocalWhisperTranscriber(device="cuda:0") + ``` + + After this change: + ```python + from haystack.utils.device import ComponentDevice, Device + from haystack.components.audio import LocalWhisperTranscriber + + device = ComponentDevice.from_single(Device.gpu(id=0)) + # or + # device = ComponentDevice.from_str("cuda:0") + + transcriber = LocalWhisperTranscriber(device=device) + ``` diff --git a/test/components/audio/test_whisper_local.py b/test/components/audio/test_whisper_local.py index b2089a1bc..b4b1cd9cd 100644 --- a/test/components/audio/test_whisper_local.py +++ b/test/components/audio/test_whisper_local.py @@ -7,6 +7,7 @@ import torch from haystack.dataclasses import Document, ByteStream from haystack.components.audio import LocalWhisperTranscriber +from haystack.utils.device import ComponentDevice, Device SAMPLES_PATH = Path(__file__).parent.parent.parent / "test_files" @@ -18,7 +19,7 @@ class TestLocalWhisperTranscriber: model="large-v2" ) # Doesn't matter if it's huge, the model is not loaded in init. 
def test_to_dict_with_custom_init_parameters(self):
    # Serialization must echo every explicitly supplied init parameter.
    component = LocalWhisperTranscriber(
        model="tiny",
        device=ComponentDevice.from_str("cuda:0"),
        whisper_params={"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    )

    expected_init_parameters = {
        "model": "tiny",
        "device": ComponentDevice.from_str("cuda:0").to_dict(),
        "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
    }
    expected = {
        "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
        "init_parameters": expected_init_parameters,
    }

    serialized = component.to_dict()
    assert serialized == expected

def test_from_dict(self):
    # Deserialization must rebuild the device and leave the model unloaded.
    init_parameters = {
        "model": "tiny",
        "device": ComponentDevice.from_single(Device.cpu()).to_dict(),
        "whisper_params": {},
    }
    payload = {
        "type": "haystack.components.audio.whisper_local.LocalWhisperTranscriber",
        "init_parameters": init_parameters,
    }

    component = LocalWhisperTranscriber.from_dict(payload)

    assert component.model == "tiny"
    assert component.device == ComponentDevice.from_single(Device.cpu())
    assert component.whisper_params == {}
    assert component._model is None