feat: add support for torch xpu device (#9470)

* feat: add support for torch xpu device

* test: xpu based tests ci/cd

* test: add xpu support to device code tests

---------

Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
Sriniketh J 2025-06-17 17:45:19 +05:30 committed by GitHub
parent 7dbac5b3c9
commit 6198f0cba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 2 deletions

View File

@ -16,6 +16,7 @@ env:
PYTHON_VERSION: "3.9"
HATCH_VERSION: "1.14.1"
HAYSTACK_MPS_ENABLED: false
HAYSTACK_XPU_ENABLED: false
on:
workflow_dispatch: # Activate this workflow manually

View File

@ -401,6 +401,9 @@ jobs:
needs: unit-tests
runs-on: windows-latest
timeout-minutes: 30
env:
HAYSTACK_XPU_ENABLED: false
steps:
- uses: actions/checkout@v4

View File

@ -28,6 +28,7 @@ class DeviceType(Enum):
GPU = "cuda"
DISK = "disk"
MPS = "mps"
XPU = "xpu"
def __str__(self):
    # Display the member as its underlying device string (e.g. "cuda", "mps", "xpu"),
    # which matches the string PyTorch uses to identify the device type.
    return self.value
@ -126,6 +127,16 @@ class Device:
"""
return Device(DeviceType.MPS)
@staticmethod
def xpu() -> "Device":
    """
    Create a generic XPU device (Intel GPU, exposed through PyTorch's ``torch.xpu`` backend).

    :returns:
        The XPU device.
    """
    return Device(DeviceType.XPU)
@staticmethod
def from_str(string: str) -> "Device":
"""
@ -482,7 +493,7 @@ def _get_default_device() -> Device:
Return the default device for Haystack.
Precedence:
GPU > MPS > CPU. If PyTorch is not installed, only CPU is available.
GPU > XPU > MPS > CPU. If PyTorch is not installed, only CPU is available.
:returns:
The default device.
@ -496,12 +507,21 @@ def _get_default_device() -> Device:
and os.getenv("HAYSTACK_MPS_ENABLED", "true") != "false"
)
has_cuda = torch.cuda.is_available()
has_xpu = (
hasattr(torch, "xpu")
and hasattr(torch.xpu, "is_available")
and torch.xpu.is_available()
and os.getenv("HAYSTACK_XPU_ENABLED", "true") != "false"
)
except ImportError:
has_mps = False
has_cuda = False
has_xpu = False
if has_cuda:
return Device.gpu()
elif has_xpu:
return Device.xpu()
elif has_mps:
return Device.mps()
else:

View File

@ -22,12 +22,14 @@ def test_device_creation():
assert Device.cpu().type == DeviceType.CPU
assert Device.gpu().type == DeviceType.GPU
assert Device.mps().type == DeviceType.MPS
assert Device.xpu().type == DeviceType.XPU
assert Device.disk().type == DeviceType.DISK
assert Device.from_str("cpu") == Device.cpu()
assert Device.from_str("cuda:1") == Device.gpu(1)
assert Device.from_str("disk") == Device.disk()
assert Device.from_str("mps:0") == Device(DeviceType.MPS, 0)
assert Device.from_str("xpu:0") == Device(DeviceType.XPU, 0)
with pytest.raises(ValueError, match="Device id must be >= 0"):
Device.gpu(-1)
@ -115,23 +117,38 @@ def test_component_device_multiple():
assert multiple.first_device == ComponentDevice.from_single(Device.cpu())
@patch("torch.xpu.is_available")
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available):
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available, torch_xpu_is_available):
assert ComponentDevice.resolve_device(ComponentDevice.from_single(Device.cpu()))._single_device == Device.cpu()
torch_cuda_is_available.return_value = True
assert ComponentDevice.resolve_device(None)._single_device == Device.gpu(0)
torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = True
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.xpu()
torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
assert ComponentDevice.resolve_device(None)._single_device == Device.mps()
torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
os.environ["HAYSTACK_MPS_ENABLED"] = "false"
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = True
os.environ["HAYSTACK_XPU_ENABLED"] = "false"
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()