feat!: Framework-agnostic device management (#6748)

* feat: Framework-agnostic device management

* Add release note

* Linting

* Fix test

* Add `first_device` property, expand release notes, validate `ComponentDevice` state
Madeesh Kannan 2024-01-17 10:41:34 +01:00 committed by GitHub
parent b8b8b5d5c6
commit 7376838922
13 changed files with 794 additions and 156 deletions

View File

@ -1,6 +1,6 @@
import pytest
from haystack import Document, Pipeline, ComponentError
from haystack import ComponentError, Document, Pipeline
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
@ -43,9 +43,7 @@ def spacy_annotations():
def test_ner_extractor_init():
extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model_name_or_path="dslim/bert-base-NER", device_id=-1
)
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
with pytest.raises(ComponentError, match=r"not initialized"):
extractor.run(documents=[])
@ -57,9 +55,7 @@ def test_ner_extractor_init():
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model_name_or_path="dslim/bert-base-NER"
)
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
extractor.warm_up()
_extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)
@ -67,7 +63,7 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model_name_or_path="en_core_web_trf")
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
extractor.warm_up()
_extract_and_check_predictions(extractor, raw_texts, spacy_annotations, batch_size)
@ -78,9 +74,7 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size):
pipeline = Pipeline()
pipeline.add_component(
name="ner_extractor",
instance=NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model_name_or_path="dslim/bert-base-NER"
),
instance=NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER"),
)
outputs = pipeline.run(

View File

@ -4,8 +4,9 @@ from dataclasses import dataclass
from enum import Enum, EnumMeta
from typing import Any, Dict, List, Optional, Union
from ... import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from ...lazy_imports import LazyImport
from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils.device import ComponentDevice
with LazyImport(message="Run 'pip install transformers[torch]'") as transformers_import:
from transformers import AutoModelForTokenClassification, AutoTokenizer
@ -85,7 +86,7 @@ class NamedEntityExtractor:
backend: Union[str, NamedEntityExtractorBackend],
model: str,
pipeline_kwargs: Optional[Dict[str, Any]] = None,
device_id: int = -1,
device: Optional[ComponentDevice] = None,
) -> None:
"""
Construct a Named Entity extractor component.
@ -102,24 +103,25 @@ class NamedEntityExtractor:
pipeline can override these arguments.
Dependent on the backend.
:param device_id:
Identifier of the device on which the backend
is executed.
:param device:
The device on which the model is loaded. If `None`,
the default device is automatically selected.
To execute on the CPU, pass a value of `-1`.
To execute on the GPU, pass the GPU identifier.
If a device/device map is specified in `pipeline_kwargs`,
it overrides this parameter (only applicable to the HuggingFace
backend).
"""
if isinstance(backend, str):
backend = NamedEntityExtractorBackend(backend)
self._backend: _NerBackend
device = ComponentDevice.resolve_device(device)
if backend == NamedEntityExtractorBackend.HUGGING_FACE:
self._backend = _HfBackend(model_name_or_path=model, device_id=device_id, pipeline_kwargs=pipeline_kwargs)
self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
elif backend == NamedEntityExtractorBackend.SPACY:
self._backend = _SpacyBackend(
model_name_or_path=model, device_id=device_id, pipeline_kwargs=pipeline_kwargs
)
self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
else:
raise ComponentError(f"Unknown NER backend '{type(backend).__name__}' for extractor")
@ -152,13 +154,15 @@ class NamedEntityExtractor:
self,
backend=self._backend.type,
model=self._backend.model_name,
device_id=self._backend.device_id,
device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
try:
init_params = data["init_parameters"]
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
return default_from_dict(cls, data)
except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
@ -190,10 +194,16 @@ class _NerBackend(ABC):
Base class for NER backends.
"""
def __init__(self, type: NamedEntityExtractorBackend, pipeline_kwargs: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
type: NamedEntityExtractorBackend,
device: ComponentDevice,
pipeline_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self._type = type
self._device = device
self._pipeline_kwargs = pipeline_kwargs if pipeline_kwargs is not None else {}
@abstractmethod
@ -233,12 +243,12 @@ class _NerBackend(ABC):
"""
@property
@abstractmethod
def device_id(self) -> int:
def device(self) -> ComponentDevice:
"""
Returns the identifier of the device on which
the backend's model is loaded.
"""
return self._device
@property
def type(self) -> NamedEntityExtractorBackend:
@ -254,7 +264,7 @@ class _HfBackend(_NerBackend):
"""
def __init__(
self, *, model_name_or_path: str, device_id: int, pipeline_kwargs: Optional[Dict[str, Any]] = None
self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[Dict[str, Any]] = None
) -> None:
"""
Construct a Hugging Face NER backend.
@ -262,23 +272,21 @@ class _HfBackend(_NerBackend):
:param model_name_or_path:
Name of the model or a path to the Hugging Face
model on the local disk.
:param device_id:
Identifier of the device on which the backend
is executed.
:param device:
The device on which the model is loaded. If `None`,
the default device is automatically selected.
To execute on the CPU, pass a value of `-1`.
To execute on the GPU, pass the GPU identifier.
If a device/device map is specified in `pipeline_kwargs`,
it overrides this parameter.
:param pipeline_kwargs:
Keyword arguments passed to the pipeline. The
pipeline can override these arguments.
"""
super().__init__(NamedEntityExtractorBackend.HUGGING_FACE, pipeline_kwargs)
super().__init__(NamedEntityExtractorBackend.HUGGING_FACE, device, pipeline_kwargs)
transformers_import.check()
self._model_name_or_path = model_name_or_path
self._device_id = device_id
self.tokenizer: Optional[AutoTokenizer] = None
self.model: Optional[AutoModelForTokenClassification] = None
self.pipeline: Optional[HfPipeline] = None
@ -292,9 +300,9 @@ class _HfBackend(_NerBackend):
"model": self.model,
"tokenizer": self.tokenizer,
"aggregation_strategy": "simple",
"device": self.device_id,
}
pipeline_params.update({k: v for k, v in self._pipeline_kwargs.items() if k not in pipeline_params})
self.device.update_hf_kwargs(pipeline_params, overwrite=False)
self.pipeline = pipeline(**pipeline_params)
def annotate(self, texts: List[str], *, batch_size: int = 1) -> List[List[NamedEntityAnnotation]]:
@ -324,10 +332,6 @@ class _HfBackend(_NerBackend):
def model_name(self) -> str:
return self._model_name_or_path
@property
def device_id(self) -> int:
return self._device_id
class _SpacyBackend(_NerBackend):
"""
@ -335,7 +339,7 @@ class _SpacyBackend(_NerBackend):
"""
def __init__(
self, *, model_name_or_path: str, device_id: int, pipeline_kwargs: Optional[Dict[str, Any]] = None
self, *, model_name_or_path: str, device: ComponentDevice, pipeline_kwargs: Optional[Dict[str, Any]] = None
) -> None:
"""
Construct a spaCy NER backend.
@ -343,25 +347,23 @@ class _SpacyBackend(_NerBackend):
:param model_name_or_path:
Name of the model or a path to the spaCy
model on the local disk.
:param device_id:
Identifier of the device on which the backend
is executed.
To execute on the CPU, pass a value of `-1`.
To execute on the GPU, pass the GPU identifier.
:param device:
The device on which the model is loaded. If `None`,
the default device is automatically selected.
:param pipeline_kwargs:
Keyword arguments passed to the pipeline. The
pipeline can override these arguments.
"""
super().__init__(NamedEntityExtractorBackend.SPACY, pipeline_kwargs)
super().__init__(NamedEntityExtractorBackend.SPACY, device, pipeline_kwargs)
spacy_import.check()
self._model_name_or_path = model_name_or_path
self._device_id = device_id
self.pipeline: Optional[SpacyPipeline] = None
if self.device.has_multiple_devices:
raise ValueError("spaCy backend for named entity extractor only supports inference on single devices")
def initialize(self):
# We need to initialize the model on the GPU if needed.
with self._select_device():
@ -402,10 +404,6 @@ class _SpacyBackend(_NerBackend):
def model_name(self) -> str:
return self._model_name_or_path
@property
def device_id(self) -> int:
return self._device_id
@contextmanager
def _select_device(self):
"""
@ -418,10 +416,11 @@ class _SpacyBackend(_NerBackend):
# the active device in spaCy/Thinc, we can't do much
# about it as a consumer unless we start poking into their
# internals.
device_id = self._device.to_spacy()
try:
if self._device_id >= 0:
spacy.require_gpu(self._device_id)
if device_id >= 0:
spacy.require_gpu(device_id)
yield
finally:
if self._device_id >= 0:
if device_id >= 0:
spacy.require_cpu()

View File

@ -1,9 +1,10 @@
import logging
from typing import Any, Dict, List, Literal, Optional, Union
from haystack import component, default_to_dict, default_from_dict
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.hf_utils import StopWordsCriteria
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice
logger = logging.getLogger(__name__)
@ -41,7 +42,7 @@ class HuggingFaceLocalGenerator:
self,
model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[str] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
@ -58,9 +59,8 @@ class HuggingFaceLocalGenerator:
If the task is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
If not specified, the component will attempt to infer the task from the model name,
calling the Hugging Face Hub API.
:param device: The device on which the model is loaded. (e.g., "cpu", "cuda:0").
If `device` or `device_map` is specified in the `huggingface_pipeline_kwargs`,
this parameter will be ignored.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
@ -92,12 +92,9 @@ class HuggingFaceLocalGenerator:
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)
if (
device is not None
and "device" not in huggingface_pipeline_kwargs
and "device_map" not in huggingface_pipeline_kwargs
):
huggingface_pipeline_kwargs["device"] = device
device = ComponentDevice.resolve_device(device)
device.update_hf_kwargs(huggingface_pipeline_kwargs, overwrite=False)
# task identification and validation
if task is None:

View File

@ -1,10 +1,10 @@
import logging
from pathlib import Path
from typing import List, Union, Dict, Any, Optional
from typing import Any, Dict, List, Optional, Union
from haystack import ComponentError, Document, component, default_to_dict, default_from_dict
from haystack import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import get_device
from haystack.utils import ComponentDevice
logger = logging.getLogger(__name__)
@ -38,7 +38,7 @@ class TransformersSimilarityRanker:
def __init__(
self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[str] = "cpu",
device: Optional[ComponentDevice] = None,
token: Union[bool, str, None] = None,
top_k: int = 10,
meta_fields_to_embed: Optional[List[str]] = None,
@ -53,7 +53,8 @@ class TransformersSimilarityRanker:
:param model: The name or path of a pre-trained cross-encoder model
from the Hugging Face Hub.
:param device: The torch device (for example, cuda:0, cpu, mps) to which you want to limit model inference.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
@ -75,7 +76,7 @@ class TransformersSimilarityRanker:
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.top_k = top_k
self.device = device
self.device = ComponentDevice.resolve_device(device)
self.token = token
self._model = None
self.tokenizer = None
@ -90,6 +91,11 @@ class TransformersSimilarityRanker:
self.score_threshold = score_threshold
self.model_kwargs = model_kwargs or {}
if self.device.has_multiple_devices:
raise ValueError(
f"{type(TransformersSimilarityRanker).__name__} currently only supports inference on single devices"
)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
@ -101,11 +107,9 @@ class TransformersSimilarityRanker:
Warm up the model and tokenizer used for scoring the Documents.
"""
if self._model is None:
if self.device is None:
self.device = get_device()
self._model = AutoModelForSequenceClassification.from_pretrained(
self.model, token=self.token, **self.model_kwargs
).to(self.device)
).to(self.device.to_torch())
self._model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
@ -115,7 +119,7 @@ class TransformersSimilarityRanker:
"""
serialization_dict = default_to_dict(
self,
device=self.device,
device=self.device.to_dict(),
model=self.model,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
top_k=self.top_k,
@ -151,6 +155,10 @@ class TransformersSimilarityRanker:
torch_and_transformers_import.check()
init_params = data.get("init_parameters", {})
model_kwargs = init_params.get("model_kwargs", {})
serialized_device = init_params.get("device", {})
init_params["device"] = ComponentDevice.from_dict(serialized_device)
# convert string to torch.dtype
# 1. torch_dtype and bnb_4bit_compute_dtype can be specified in model_kwargs
for key, value in model_kwargs.items():
@ -224,7 +232,7 @@ class TransformersSimilarityRanker:
features = self.tokenizer(
query_doc_pairs, padding=True, truncation=True, return_tensors="pt"
).to( # type: ignore
self.device
self.device.to_torch()
)
with torch.inference_mode():
similarity_scores = self._model(**features).logits.squeeze(dim=1) # type: ignore

View File

@ -1,17 +1,17 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import logging
import math
import warnings
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from haystack import component, default_to_dict, default_from_dict, ComponentError, Document, ExtractedAnswer
from haystack import ComponentError, Document, ExtractedAnswer, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import get_device
from haystack.utils import ComponentDevice
with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from tokenizers import Encoding
import torch
from tokenizers import Encoding
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
logger = logging.getLogger(__name__)
@ -39,7 +39,7 @@ class ExtractiveReader:
def __init__(
self,
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[str] = None,
device: Optional[ComponentDevice] = None,
token: Union[bool, str, None] = None,
top_k: int = 20,
score_threshold: Optional[float] = None,
@ -57,7 +57,8 @@ class ExtractiveReader:
:param model: A Hugging Face transformers question answering model.
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
Default: `'deepset/roberta-base-squad2-distilled'`
:param device: Pytorch device string. Uses GPU by default, if available.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) is used.
@ -89,7 +90,7 @@ class ExtractiveReader:
torch_and_transformers_import.check()
self.model_name_or_path = str(model)
self.model = None
self.device = device
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.max_seq_length = max_seq_length
self.top_k = top_k
@ -102,6 +103,9 @@ class ExtractiveReader:
self.model_kwargs = model_kwargs or {}
self.overlap_threshold = overlap_threshold
if self.device.has_multiple_devices:
raise ValueError(f"{type(ExtractiveReader).__name__} currently only supports inference on single devices")
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
@ -115,7 +119,7 @@ class ExtractiveReader:
serialization_dict = default_to_dict(
self,
model=self.model_name_or_path,
device=self.device,
device=self.device.to_dict(),
token=self.token if not isinstance(self.token, str) else None,
max_seq_length=self.max_seq_length,
top_k=self.top_k,
@ -152,6 +156,10 @@ class ExtractiveReader:
torch_and_transformers_import.check()
init_params = data.get("init_parameters", {})
model_kwargs = init_params.get("model_kwargs", {})
serialized_device = init_params.get("device", {})
init_params["device"] = ComponentDevice.from_dict(serialized_device)
# convert string to torch.dtype
# 1. torch_dtype and bnb_4bit_compute_dtype can be specified in model_kwargs
for key, value in model_kwargs.items():
@ -164,7 +172,6 @@ class ExtractiveReader:
data["init_parameters"]["model_kwargs"]["quantization_config"]["bnb_4bit_compute_dtype"] = getattr(
torch, bnb_4bit_compute_dtype.strip("torch.")
)
return default_from_dict(cls, data)
def warm_up(self):
@ -172,11 +179,9 @@ class ExtractiveReader:
Loads model and tokenizer
"""
if self.model is None:
if self.device is None:
self.device = get_device()
self.model = AutoModelForQuestionAnswering.from_pretrained(
self.model_name_or_path, token=self.token, **self.model_kwargs
).to(self.device)
).to(self.device.to_torch())
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
def _flatten_documents(
@ -218,8 +223,8 @@ class ExtractiveReader:
stride=stride,
)
input_ids = encodings_pt.input_ids.to(self.device)
attention_mask = encodings_pt.attention_mask.to(self.device)
input_ids = encodings_pt.input_ids.to(self.device.to_torch())
attention_mask = encodings_pt.attention_mask.to(self.device.to_torch())
query_ids = [query_ids[index] for index in encodings_pt.overflow_to_sample_mapping]
document_ids = [document_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping]
@ -227,7 +232,7 @@ class ExtractiveReader:
encodings = encodings_pt.encodings
sequence_ids = torch.tensor(
[[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings]
).to(self.device)
).to(self.device.to_torch())
return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids
@ -256,7 +261,7 @@ class ExtractiveReader:
# The mask here onwards is the same for all instances in the batch
# As such we do away with the batch dimension
mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device)
mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=self.device.to_torch())
mask = torch.triu(mask) # End shouldn't be before start
masked_logits = torch.where(mask, logits, -torch.inf)
probabilities = torch.sigmoid(masked_logits * self.calibration_factor)

View File

@ -1,4 +1,4 @@
from haystack.utils.expit import expit
from haystack.utils.requests_utils import request_with_retry
from haystack.utils.filters import document_matches_filter
from haystack.utils.device import get_device
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap

View File

@ -1,34 +1,512 @@
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
with LazyImport(message="Using cpu, to use cuda or mps backends run 'pip install transformers[torch]'") as torch_import:
with LazyImport(
message="PyTorch must be installed to use torch.device or use GPU support in HuggingFace transformers. Run 'pip install transformers[torch]'"
) as torch_import:
import torch
def get_device() -> str:
class DeviceType(Enum):
"""
Detect and return an available torch device as string. Priority is given as follows if the device is available:
1) GPU, 2) MPS, 3) CPU
Represents device types supported by Haystack. This also includes
devices that are not directly used by models - for example, the disk
device is exclusively used in device maps for frameworks that support
offloading model weights to disk.
"""
CPU = "cpu"
GPU = "cuda"
DISK = "disk"
MPS = "mps"
def __str__(self):
return self.value
@staticmethod
def from_str(string: str) -> "DeviceType":
"""
Create a device type from a string.
:param string:
The string to convert.
:returns:
The device type.
"""
map = {e.value: e for e in DeviceType}
type = map.get(string)
if type is None:
raise ValueError(f"Unknown device type string '{string}'")
return type
@dataclass
class Device:
"""
A generic representation of a device.
:param type:
The device type.
:param id:
The optional device id.
"""
type: DeviceType
id: Optional[int] = field(default=None)
def __init__(self, type: DeviceType, id: Optional[int] = None):
"""
Create a generic device.
:param type:
The device type.
:param id:
The device id.
"""
if id is not None and id < 0:
raise ValueError(f"Device id must be >= 0, got {id}")
self.type = type
self.id = id
def __str__(self):
if self.id is None:
return str(self.type)
else:
return f"{self.type}:{self.id}"
@staticmethod
def cpu() -> "Device":
"""
Create a generic CPU device.
:returns:
The CPU device.
"""
return Device(DeviceType.CPU)
@staticmethod
def gpu(id: int = 0) -> "Device":
"""
Create a generic GPU device.
:param id:
The GPU id.
:returns:
The GPU device.
"""
return Device(DeviceType.GPU, id)
@staticmethod
def disk() -> "Device":
"""
Create a generic disk device.
:returns:
The disk device.
"""
return Device(DeviceType.DISK)
@staticmethod
def mps() -> "Device":
"""
Create a generic Apple Metal Performance Shader device.
:returns:
The MPS device.
"""
return Device(DeviceType.MPS)
@staticmethod
def from_str(string: str) -> "Device":
device_type_str, device_id = _split_device_string(string)
return Device(DeviceType.from_str(device_type_str), device_id)
@dataclass
class DeviceMap:
"""
A generic mapping from strings to devices. The semantics of the
strings depend on the target framework. Primarily used to deploy
HuggingFace models to multiple devices.
:param mapping:
Dictionary mapping strings to devices.
"""
mapping: Dict[str, Device] = field(default_factory=dict, hash=False)
def __getitem__(self, key: str) -> Device:
return self.mapping[key]
def __setitem__(self, key: str, value: Device):
self.mapping[key] = value
def __contains__(self, key: str) -> bool:
return key in self.mapping
def __len__(self) -> int:
return len(self.mapping)
def __iter__(self):
return iter(self.mapping.items())
def to_dict(self) -> Dict[str, str]:
"""
Serialize the mapping to a JSON-serializable dictionary.
:returns:
The serialized mapping.
"""
return {key: str(device) for key, device in self.mapping.items()}
@property
def first_device(self) -> Optional[Device]:
"""
Return the first device in the mapping, if any.
:returns:
The first device.
"""
if not self.mapping:
return None
else:
return next(iter(self.mapping.values()))
@staticmethod
def from_dict(dict: Dict[str, str]) -> "DeviceMap":
"""
Create a generic device map from a JSON-serialized dictionary.
:param dict:
The serialized mapping.
:returns:
The generic device map.
"""
mapping = {}
for key, device_str in dict.items():
mapping[key] = Device.from_str(device_str)
return DeviceMap(mapping)
@staticmethod
def from_hf(hf_device_map: Dict[str, Union[int, str]]) -> "DeviceMap":
"""
Create a generic device map from a HuggingFace device map.
:param hf_device_map:
The HuggingFace device map.
:returns:
The deserialized device map.
"""
mapping = {}
for key, device in hf_device_map.items():
if isinstance(device, int):
mapping[key] = Device(DeviceType.GPU, device)
elif isinstance(device, str):
device_type, device_id = _split_device_string(device)
mapping[key] = Device(DeviceType.from_str(device_type), device_id)
else:
raise ValueError(
f"Couldn't convert HuggingFace device map - unexpected device '{str(device)}' for '{key}'"
)
return DeviceMap(mapping)
@dataclass(frozen=True)
class ComponentDevice:
"""
A representation of a device for a component. This can be either
a single device or a device map.
"""
_single_device: Optional[Device] = field(default=None)
_multiple_devices: Optional[DeviceMap] = field(default=None)
@classmethod
def from_str(cls, device_str: str) -> "ComponentDevice":
"""
Create a component device representation from a device string.
The device string can only represent a single device.
:param device_str:
The device string.
:returns:
The component device representation.
"""
device = Device.from_str(device_str)
return cls.from_single(device)
@classmethod
def from_single(cls, device: Device) -> "ComponentDevice":
"""
Create a component device representation from a single device.
Disks cannot be used as single devices.
:param device:
The device.
:returns:
The component device representation.
"""
if device.type == DeviceType.DISK:
raise ValueError("The disk device can only be used as a part of device maps")
return cls(_single_device=device)
@classmethod
def from_multiple(cls, device_map: DeviceMap) -> "ComponentDevice":
"""
Create a component device representation from a device map.
:param device_map:
The device map.
:returns:
The component device representation.
"""
return cls(_multiple_devices=device_map)
def _validate(self):
"""
Validate the component device representation.
"""
if not (self._single_device is not None) ^ (self._multiple_devices is not None):
raise ValueError(
"The component device can neither be empty nor contain both a single device and a device map"
)
def to_torch(self) -> "torch.device":
"""
Convert the component device representation to PyTorch format.
Device maps are not supported.
:returns:
The PyTorch device representation.
"""
self._validate()
if self._single_device is None:
raise ValueError("Only single devices can be converted to PyTorch format")
torch_import.check()
assert self._single_device is not None
return torch.device(str(self._single_device))
def to_torch_str(self) -> str:
"""
Convert the component device representation to PyTorch string format.
Device maps are not supported.
:returns:
The PyTorch device string representation.
"""
self._validate()
if self._single_device is None:
raise ValueError("Only single devices can be converted to PyTorch format")
assert self._single_device is not None
return str(self._single_device)
def to_spacy(self) -> int:
"""
Convert the component device representation to spaCy format.
Device maps are not supported.
:returns:
The spaCy device representation.
"""
self._validate()
if self._single_device is None:
raise ValueError("Only single devices can be converted to spaCy format")
assert self._single_device is not None
if self._single_device.type == DeviceType.GPU:
assert self._single_device.id is not None
return self._single_device.id
else:
return -1
def to_hf(self) -> Union[Union[int, str], Dict[str, Union[int, str]]]:
"""
Convert the component device representation to HuggingFace format.
:returns:
The HuggingFace device representation.
"""
self._validate()
def convert_device(device: Device, *, gpu_id_only: bool = False) -> Union[int, str]:
if gpu_id_only and device.type == DeviceType.GPU:
assert device.id is not None
return device.id
else:
return str(device)
if self._single_device is not None:
return convert_device(self._single_device)
assert self._multiple_devices is not None
return {key: convert_device(device, gpu_id_only=True) for key, device in self._multiple_devices.mapping.items()}
def update_hf_kwargs(self, hf_kwargs: Dict[str, Any], *, overwrite: bool) -> Dict[str, Any]:
"""
Convert the component device representation to HuggingFace format
and add it as canonical keyword arguments to the keyword arguments
dictionary.
:param hf_kwargs:
The HuggingFace keyword arguments dictionary.
:param overwrite:
Whether to overwrite existing device arguments.
:returns:
The HuggingFace keyword arguments dictionary.
"""
self._validate()
if not overwrite and any(x in hf_kwargs for x in ("device", "device_map")):
return hf_kwargs
converted = self.to_hf()
key = "device_map" if self.has_multiple_devices else "device"
hf_kwargs[key] = converted
return hf_kwargs
@property
def has_multiple_devices(self) -> bool:
"""
Whether this component device representation contains multiple
devices.
"""
self._validate()
return self._multiple_devices is not None
@property
def first_device(self) -> Optional[Device]:
"""
Return either the single device or the first device in the
device map, if any.
:returns:
The first device.
"""
self._validate()
if self._single_device is not None:
return self._single_device
assert self._multiple_devices is not None
return self._multiple_devices.first_device
@staticmethod
def resolve_device(device: Optional["ComponentDevice"] = None) -> "ComponentDevice":
"""
Select a device for a component. If a device is specified,
it's used. Otherwise, the default device is used.
:param device:
The provided device, if any.
:returns:
The resolved device.
"""
if not isinstance(device, ComponentDevice) and device is not None:
raise ValueError(
f"Invalid component device type '{type(device).__name__}'. Must either be None or ComponentDevice."
)
if device is None:
device = ComponentDevice.from_single(_get_default_device())
return device
def to_dict(self) -> Dict[str, Any]:
"""
Convert the component device representation to a JSON-serializable
dictionary.
:returns:
The dictionary representation.
"""
if self._single_device is not None:
return {"type": "single", "device": str(self._single_device)}
elif self._multiple_devices is not None:
return {"type": "multiple", "device_map": self._multiple_devices.to_dict()}
else:
# Unreachable
assert False
@classmethod
def from_dict(cls, dict: Dict[str, Any]) -> "ComponentDevice":
"""
Create a component device representation from a JSON-serialized
dictionary.
:param dict:
The serialized representation.
:returns:
The deserialized component device.
"""
if dict["type"] == "single":
return cls.from_str(dict["device"])
elif dict["type"] == "multiple":
return cls.from_multiple(DeviceMap.from_dict(dict["device_map"]))
else:
raise ValueError(f"Unknown component device type '{dict['type']}' in serialized data")
def _get_default_device() -> Device:
"""
Return the default device for Haystack. Precedence:
GPU > MPS > CPU. If PyTorch is not installed, only CPU
is available.
:returns:
The default device.
"""
try:
torch_import.check()
except ImportError as e:
logger.warning(e.msg)
return "cpu:0"
if torch.cuda.is_available():
device = "cuda:0"
elif (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and os.getenv("HAYSTACK_MPS_ENABLED", "true") != "false"
):
device = "mps:0"
has_mps = (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and os.getenv("HAYSTACK_MPS_ENABLED", "true") != "false"
)
has_cuda = torch.cuda.is_available()
except ImportError:
has_mps = False
has_cuda = False
if has_cuda:
return Device.gpu()
elif has_mps:
return Device.mps()
else:
device = "cpu:0"
return Device.cpu()
return device
def _split_device_string(string: str) -> Tuple[str, Optional[int]]:
"""
Split a device string into device type and device id.
:param string:
The device string to split.
:returns:
The device type and device id, if any.
"""
if ":" in string:
device_type, device_id_str = string.split(":")
try:
device_id = int(device_id_str)
except ValueError:
raise ValueError(f"Device id must be an integer, got {device_id_str}")
else:
device_type = string
device_id = None
return device_type, device_id

View File

@ -0,0 +1,31 @@
---
upgrade:
- |
Implement framework-agnostic device representations. The main impetus behind this change is to move away from stringified representations
of devices that are not portable between different frameworks. It also enables support for multi-device inference in a generic manner.
Going forward, components can expose a single, optional device parameter in their constructor (`Optional[ComponentDevice]`):
```python
from haystack.utils import ComponentDevice, Device, DeviceMap
class MyComponent(Component):
def __init__(self, device: Optional[ComponentDevice] = None):
# If device is None, automatically select a device.
self.device = ComponentDevice.resolve_device(device)
def warm_up(self):
# Call the framework-specific conversion method.
self.model = AutoModel.from_pretrained("deepset/bert-base-cased-squad2", device=self.device.to_hf())
# Automatically selects a device.
c = MyComponent(device=None)
# Uses the first GPU available.
c = MyComponent(device=ComponentDevice.from_str("cuda:0"))
# Uses the CPU.
c = MyComponent(device=ComponentDevice.from_single(Device.cpu()))
# Allow the component to use multiple devices using a device map.
c = MyComponent(device=ComponentDevice.from_multiple(DeviceMap({
"layer1": Device.cpu(),
"layer2": Device.gpu(1),
"layer3": Device.disk()
})))
```
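A minimal sketch of the serialization and framework-specific conversion helpers introduced with `ComponentDevice`; the dictionary layout shown follows `ComponentDevice.to_dict`/`from_dict` in `haystack.utils.device`:
```python
from haystack.utils import ComponentDevice, Device, DeviceMap

# Single devices round-trip through the JSON-serializable dictionary format
# used when components are (de)serialized.
device = ComponentDevice.from_str("cuda:1")
assert device.to_dict() == {"type": "single", "device": "cuda:1"}
assert ComponentDevice.from_dict(device.to_dict()) == device

# Framework-specific conversions for single devices.
assert device.to_torch_str() == "cuda:1"  # PyTorch device string
assert device.to_hf() == "cuda:1"         # HuggingFace device argument
assert device.to_spacy() == 1             # spaCy GPU identifier

# Device maps convert to a HuggingFace `device_map` and are injected into
# pipeline keyword arguments unless the caller already specified a device.
multi = ComponentDevice.from_multiple(DeviceMap({"layer1": Device.cpu(), "layer2": Device.gpu(1)}))
assert multi.to_hf() == {"layer1": "cpu", "layer2": 1}
assert multi.update_hf_kwargs({}, overwrite=False) == {"device_map": {"layer1": "cpu", "layer2": 1}}
```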

View File

@ -2,6 +2,7 @@ import pytest
from haystack import ComponentError, DeserializationError
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.device import ComponentDevice
@pytest.mark.unit
@ -21,7 +22,9 @@ def test_named_entity_extractor_backend():
@pytest.mark.unit
def test_named_entity_extractor_serde():
extractor = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER", device_id=-1
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("cuda:1"),
)
serde_data = extractor.to_dict()
@ -29,7 +32,7 @@ def test_named_entity_extractor_serde():
assert type(new_extractor._backend) == type(extractor._backend)
assert new_extractor._backend.model_name == extractor._backend.model_name
assert new_extractor._backend.device_id == extractor._backend.device_id
assert new_extractor._backend.device == extractor._backend.device
with pytest.raises(DeserializationError, match=r"Couldn't deserialize"):
serde_data["init_parameters"].pop("backend")

View File

@ -1,11 +1,12 @@
# pylint: disable=too-many-public-methods
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import PreTrainedTokenizerFast
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
from haystack.utils import ComponentDevice, Device
class TestHuggingFaceLocalGenerator:
@ -18,6 +19,7 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
assert generator.generation_kwargs == {}
assert generator.pipeline is None
@ -31,10 +33,13 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": "test-token",
"device": ComponentDevice.resolve_device(None).to_hf(),
}
def test_init_custom_device(self):
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", device="cuda:0")
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base", task="text2text-generation", device=ComponentDevice.from_str("cuda:0")
)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
@ -50,6 +55,7 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
def test_init_task_in_huggingface_pipeline_kwargs(self):
@ -59,6 +65,7 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
@patch("haystack.components.generators.hugging_face_local.model_info")
@ -70,6 +77,7 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
}
def test_init_invalid_task(self):
@ -92,7 +100,7 @@ class TestHuggingFaceLocalGenerator:
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base",
task="text2text-generation",
device="cpu",
device=ComponentDevice.from_str("cpu"),
token="test-token",
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
)
@ -138,6 +146,7 @@ class TestHuggingFaceLocalGenerator:
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
},
"generation_kwargs": {},
"stop_words": None,
@ -148,7 +157,7 @@ class TestHuggingFaceLocalGenerator:
component = HuggingFaceLocalGenerator(
model="gpt2",
task="text-generation",
device="cuda:0",
device=ComponentDevice.from_str("cuda:0"),
token="test-token",
generation_kwargs={"max_new_tokens": 100},
stop_words=["coca", "cola"],
@ -187,7 +196,7 @@ class TestHuggingFaceLocalGenerator:
component = HuggingFaceLocalGenerator(
model="gpt2",
task="text-generation",
device="cuda:0",
device=ComponentDevice.from_str("cuda:0"),
token="test-token",
generation_kwargs={"max_new_tokens": 100},
stop_words=["coca", "cola"],
@ -275,7 +284,10 @@ class TestHuggingFaceLocalGenerator:
generator.warm_up()
pipeline_mock.assert_called_once_with(
model="google/flan-t5-base", task="text2text-generation", token="test-token"
model="google/flan-t5-base",
task="text2text-generation",
token="test-token",
device=ComponentDevice.resolve_device(None).to_hf(),
)
@patch("haystack.components.generators.hugging_face_local.pipeline")

View File

@ -1,10 +1,12 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from haystack import Document, ComponentError
from haystack import ComponentError, Document
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
from haystack.utils.device import ComponentDevice
class TestSimilarityRanker:
@ -14,7 +16,7 @@ class TestSimilarityRanker:
assert data == {
"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
"init_parameters": {
"device": "cpu",
"device": ComponentDevice.resolve_device(None).to_dict(),
"top_k": 10,
"token": None,
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
@ -30,7 +32,7 @@ class TestSimilarityRanker:
def test_to_dict_with_custom_init_parameters(self):
component = TransformersSimilarityRanker(
model="my_model",
device="cuda",
device=ComponentDevice.from_str("cuda:0"),
token="my_token",
top_k=5,
scale_score=False,
@ -42,7 +44,7 @@ class TestSimilarityRanker:
assert data == {
"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
"init_parameters": {
"device": "cuda",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"model": "my_model",
"token": None, # we don't serialize valid tokens,
"top_k": 5,
@ -68,7 +70,7 @@ class TestSimilarityRanker:
assert data == {
"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
"init_parameters": {
"device": "cpu",
"device": ComponentDevice.resolve_device(None).to_dict(),
"top_k": 10,
"token": None,
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
@ -90,7 +92,7 @@ class TestSimilarityRanker:
data = {
"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
"init_parameters": {
"device": "cuda",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"model": "my_model",
"token": None,
"top_k": 5,
@ -104,7 +106,7 @@ class TestSimilarityRanker:
}
component = TransformersSimilarityRanker.from_dict(data)
assert component.device == "cuda"
assert component.device == ComponentDevice.from_str("cuda:0")
assert component.model == "my_model"
assert component.token is None
assert component.top_k == 5

View File

@ -1,13 +1,14 @@
from math import ceil, exp
from typing import List
from unittest.mock import patch, Mock
import pytest
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import pipeline
from haystack.components.readers import ExtractiveReader
from haystack import Document, ExtractedAnswer
from haystack.components.readers import ExtractiveReader
from haystack.utils.device import ComponentDevice
@pytest.fixture
@ -52,7 +53,7 @@ def mock_tokenizer():
def mock_reader(mock_tokenizer):
class MockModel(torch.nn.Module):
def to(self, device):
assert device == "cpu:0"
assert device == torch.device("cpu")
self.device_set = True
return self
@ -72,7 +73,7 @@ def mock_reader(mock_tokenizer):
with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
model.return_value = MockModel()
reader = ExtractiveReader(model="mock-model", device="cpu:0")
reader = ExtractiveReader(model="mock-model", device=ComponentDevice.from_str("cpu"))
reader.warm_up()
return reader
@ -95,7 +96,7 @@ def test_to_dict():
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model": "my-model",
"device": None,
"device": ComponentDevice.resolve_device(None).to_dict(),
"token": None, # don't serialize valid tokens
"top_k": 20,
"score_threshold": None,
@ -118,7 +119,7 @@ def test_to_dict_empty_model_kwargs():
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model": "my-model",
"device": None,
"device": ComponentDevice.resolve_device(None).to_dict(),
"token": None, # don't serialize valid tokens
"top_k": 20,
"score_threshold": None,
@ -138,7 +139,7 @@ def test_from_dict():
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model": "my-model",
"device": None,
"device": ComponentDevice.resolve_device(None).to_dict(),
"token": None,
"top_k": 20,
"score_threshold": None,
@ -154,6 +155,7 @@ def test_from_dict():
component = ExtractiveReader.from_dict(data)
assert component.model_name_or_path == "my-model"
assert component.device == ComponentDevice.resolve_device(None)
assert component.token is None
assert component.top_k == 20
assert component.score_threshold is None
@ -551,7 +553,9 @@ def test_roberta():
@pytest.mark.integration
def test_matches_hf_pipeline():
reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu", overlap_threshold=None)
reader = ExtractiveReader(
"deepset/tinyroberta-squad2", device=ComponentDevice.from_str("cpu"), overlap_threshold=None
)
reader.warm_up()
answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[
"answers"

View File

@ -1,28 +1,133 @@
import os
from unittest.mock import patch
from haystack.utils import get_device
import pytest
from haystack.utils import ComponentDevice, Device, DeviceMap, DeviceType
@patch("torch.cuda.is_available")
def test_get_device_cuda(torch_cuda_is_available):
torch_cuda_is_available.return_value = True
device = get_device()
assert device == "cuda:0"
def test_device_type():
for e in DeviceType:
assert e == DeviceType.from_str(e.value)
with pytest.raises(ValueError, match="Unknown device type string"):
DeviceType.from_str("tpu")
def test_device_creation():
assert Device.cpu().type == DeviceType.CPU
assert Device.gpu().type == DeviceType.GPU
assert Device.mps().type == DeviceType.MPS
assert Device.disk().type == DeviceType.DISK
assert Device.from_str("cpu") == Device.cpu()
assert Device.from_str("cuda:1") == Device.gpu(1)
assert Device.from_str("disk") == Device.disk()
assert Device.from_str("mps:0") == Device(DeviceType.MPS, 0)
with pytest.raises(ValueError, match="Device id must be >= 0"):
Device.gpu(-1)
def test_device_map():
map = DeviceMap({"layer1": Device.cpu(), "layer2": Device.gpu(1), "layer3": Device.disk()})
assert all(x in map for x in ["layer1", "layer2", "layer3"])
assert len(map) == 3
assert map["layer1"] == Device.cpu()
assert map["layer2"] == Device.gpu(1)
assert map["layer3"] == Device.disk()
for k, d in map:
assert k in ["layer1", "layer2", "layer3"]
assert d in [Device.cpu(), Device.gpu(1), Device.disk()]
map["layer1"] = Device.gpu(0)
assert map["layer1"] == Device.gpu(0)
map["layer2"] = Device.cpu()
assert DeviceMap.from_hf({"layer1": 0, "layer2": "cpu", "layer3": "disk"}) == DeviceMap(
{"layer1": Device.gpu(0), "layer2": Device.cpu(), "layer3": Device.disk()}
)
with pytest.raises(ValueError, match="unexpected device"):
DeviceMap.from_hf({"layer1": 0.1})
assert map.first_device == Device.gpu(0)
assert DeviceMap({}).first_device is None
def test_component_device_empty_and_full():
with pytest.raises(ValueError, match="neither be empty nor contain"):
ComponentDevice().first_device
with pytest.raises(ValueError, match="neither be empty nor contain"):
ComponentDevice(Device.cpu(), DeviceMap({})).to_hf()
def test_component_device_single():
single = ComponentDevice.from_single(Device.gpu(1))
assert not single.has_multiple_devices
assert single._single_device == Device.gpu(1)
assert single._multiple_devices is None
with pytest.raises(ValueError, match="disk device can only be used as a part of device maps"):
ComponentDevice.from_single(Device.disk())
assert single.to_torch_str() == "cuda:1"
assert single.to_spacy() == 1
assert single.to_hf() == "cuda:1"
assert single.update_hf_kwargs({}, overwrite=False) == {"device": "cuda:1"}
assert single.update_hf_kwargs({"device": 0}, overwrite=True) == {"device": "cuda:1"}
assert single.first_device == single._single_device
def test_component_device_multiple():
multiple = ComponentDevice.from_multiple(
DeviceMap({"layer1": Device.cpu(), "layer2": Device.gpu(1), "layer3": Device.disk()})
)
assert multiple.has_multiple_devices
assert multiple._single_device is None
assert multiple._multiple_devices == DeviceMap(
{"layer1": Device.cpu(), "layer2": Device.gpu(1), "layer3": Device.disk()}
)
with pytest.raises(ValueError, match="Only single devices"):
multiple.to_torch()
with pytest.raises(ValueError, match="Only single devices"):
multiple.to_torch_str()
with pytest.raises(ValueError, match="Only single devices"):
multiple.to_spacy()
assert multiple.to_hf() == {"layer1": "cpu", "layer2": 1, "layer3": "disk"}
assert multiple.update_hf_kwargs({}, overwrite=False) == {
"device_map": {"layer1": "cpu", "layer2": 1, "layer3": "disk"}
}
assert multiple.update_hf_kwargs({"device_map": {None: None}}, overwrite=True) == {
"device_map": {"layer1": "cpu", "layer2": 1, "layer3": "disk"}
}
assert multiple.first_device == Device.cpu()
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_mps(torch_cuda_is_available, torch_backends_mps_is_available):
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available):
assert ComponentDevice.resolve_device(ComponentDevice.from_single(Device.cpu()))._single_device == Device.cpu()
torch_cuda_is_available.return_value = True
assert ComponentDevice.resolve_device(None)._single_device == Device.gpu(0)
torch_cuda_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
device = get_device()
assert device == "mps:0"
assert ComponentDevice.resolve_device(None)._single_device == Device.mps()
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_get_device_cpu(torch_cuda_is_available, torch_backends_mps_is_available):
torch_cuda_is_available.return_value = False
torch_backends_mps_is_available.return_value = False
device = get_device()
assert device == "cpu:0"
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
torch_cuda_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
os.environ["HAYSTACK_MPS_ENABLED"] = "false"
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()