Mirror of https://github.com/FlagOpen/FlagEmbedding.git, synced 2025-12-30 00:36:17 +00:00
Merge pull request #1350 from qiyulei-mt/musa_support

Support the MUSA backend (Moore Threads GPUs) in FlagEmbedding: both AbsEmbedder and AbsReranker gain a guarded torch_musa import and musa:N branches in their device-resolution logic.

Commit 44e552575f
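The whole patch hinges on one idiom: torch_musa (Moore Threads' PyTorch plugin) is imported for its side effect of registering a torch.musa namespace, so availability can be probed without making the plugin a hard dependency. A minimal probe, runnable with or without the plugin installed:

import torch

try:
    import torch_musa  # side effect: registers torch.musa (skipped if absent)
except Exception:
    pass

# True only when the plugin is installed and a MUSA device is visible.
print(hasattr(torch, "musa") and torch.musa.is_available())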
@@ -13,6 +13,11 @@ import torch
 import numpy as np
 from transformers import is_torch_npu_available
 
+try:
+    import torch_musa
+except Exception:
+    pass
+
 logger = logging.getLogger(__name__)
 
 
@@ -106,6 +111,8 @@ class AbsEmbedder(ABC):
             return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
         elif is_torch_npu_available():
             return [f"npu:{i}" for i in range(torch.npu.device_count())]
+        elif hasattr(torch, "musa") and torch.musa.is_available():
+            return [f"musa:{i}" for i in range(torch.musa.device_count())]
         elif torch.backends.mps.is_available():
             try:
                 return [f"mps:{i}" for i in range(torch.mps.device_count())]
@@ -116,12 +123,18 @@ class AbsEmbedder(ABC):
         elif isinstance(devices, str):
             return [devices]
         elif isinstance(devices, int):
-            return [f"cuda:{devices}"]
+            if hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{devices}"]
+            else:
+                return [f"cuda:{devices}"]
         elif isinstance(devices, list):
             if isinstance(devices[0], str):
                 return devices
             elif isinstance(devices[0], int):
-                return [f"cuda:{device}" for device in devices]
+                if hasattr(torch, "musa") and torch.musa.is_available():
+                    return [f"musa:{device}" for device in devices]
+                else:
+                    return [f"cuda:{device}" for device in devices]
             else:
                 raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
         else:
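Taken together, the AbsEmbedder hunks yield the resolution order sketched below. This is a standalone tracing aid, not the library source: the enclosing method name (get_target_devices) and the surrounding None/cpu branches are inferred from the hunk context, and the MPS branch is condensed to the reranker's single-device form.

from typing import List, Optional, Union

import torch
from transformers import is_torch_npu_available

try:
    import torch_musa  # registers the torch.musa namespace when installed
except Exception:
    pass


def get_target_devices(devices: Optional[Union[str, int, List[str], List[int]]]) -> List[str]:
    # Method name inferred; the diff shows only the branch bodies.
    if devices is None:
        # Auto-detection order: CUDA, NPU, the new MUSA branch, MPS, else CPU.
        if torch.cuda.is_available():
            return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        elif is_torch_npu_available():
            return [f"npu:{i}" for i in range(torch.npu.device_count())]
        elif hasattr(torch, "musa") and torch.musa.is_available():
            return [f"musa:{i}" for i in range(torch.musa.device_count())]
        elif torch.backends.mps.is_available():
            return ["mps"]  # condensed; the embedder enumerates mps:i in a try block
        else:
            return ["cpu"]
    elif isinstance(devices, str):
        return [devices]
    elif isinstance(devices, int):
        # New behavior: bare integers map to MUSA ordinals when torch.musa
        # is available, and to CUDA ordinals otherwise.
        if hasattr(torch, "musa") and torch.musa.is_available():
            return [f"musa:{devices}"]
        else:
            return [f"cuda:{devices}"]
    elif isinstance(devices, list):
        if isinstance(devices[0], str):
            return devices
        elif isinstance(devices[0], int):
            # Same MUSA-first mapping, applied element-wise.
            if hasattr(torch, "musa") and torch.musa.is_available():
                return [f"musa:{device}" for device in devices]
            else:
                return [f"cuda:{device}" for device in devices]
        else:
            raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
    else:
        raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")

The same change is applied to the reranker base class: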
||||
@@ -12,6 +12,11 @@ import numpy as np
 from tqdm import tqdm, trange
 from transformers import is_torch_npu_available
 
+try:
+    import torch_musa
+except Exception:
+    pass
+
 logger = logging.getLogger(__name__)
 
 
@@ -107,6 +112,8 @@ class AbsReranker(ABC):
             return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
         elif is_torch_npu_available():
             return [f"npu:{i}" for i in range(torch.npu.device_count())]
+        elif hasattr(torch, "musa") and torch.musa.is_available():
+            return [f"musa:{i}" for i in range(torch.musa.device_count())]
         elif torch.backends.mps.is_available():
             return ["mps"]
         else:
@@ -114,12 +121,18 @@ class AbsReranker(ABC):
         elif isinstance(devices, str):
             return [devices]
         elif isinstance(devices, int):
-            return [f"cuda:{devices}"]
+            if hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{devices}"]
+            else:
+                return [f"cuda:{devices}"]
        elif isinstance(devices, list):
             if isinstance(devices[0], str):
                 return devices
             elif isinstance(devices[0], int):
-                return [f"cuda:{device}" for device in devices]
+                if hasattr(torch, "musa") and torch.musa.is_available():
+                    return [f"musa:{device}" for device in devices]
+                else:
+                    return [f"cuda:{device}" for device in devices]
             else:
                 raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
         else:
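On the user side, nothing in the public API changes; device selection simply gains the musa:N form. A hypothetical usage sketch, assuming FlagEmbedding's 1.3-style FlagModel/FlagReranker constructors with their devices parameter and a machine with the torch_musa plugin installed (model names are placeholders):

from FlagEmbedding import FlagModel, FlagReranker

# Explicit device strings pass through the str branch unchanged.
model = FlagModel("BAAI/bge-base-en-v1.5", devices="musa:0")

# A bare integer now resolves to "musa:0" when torch.musa is available,
# falling back to "cuda:0" otherwise (the patched int branch).
reranker = FlagReranker("BAAI/bge-reranker-base", devices=0)

embeddings = model.encode(["MUSA is Moore Threads' GPU compute stack."])
score = reranker.compute_score([["query", "passage"]])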
||||