support musa backend in FlagEmbedding

This commit is contained in:
qiyulei-mt 2025-01-23 10:35:05 +08:00
parent f9f673e4ff
commit 4ffa194839
2 changed files with 30 additions and 4 deletions

View File

@ -13,6 +13,11 @@ import torch
import numpy as np
from transformers import is_torch_npu_available
try:
import torch_musa
except Exception:
pass
logger = logging.getLogger(__name__)
@ -125,6 +130,8 @@ class AbsEmbedder(ABC):
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
elif is_torch_npu_available():
return [f"npu:{i}" for i in range(torch.npu.device_count())]
elif torch.musa.is_available():
return [f"musa:{i}" for i in range(torch.musa.device_count())]
elif torch.backends.mps.is_available():
try:
return [f"mps:{i}" for i in range(torch.mps.device_count())]
@ -135,12 +142,18 @@ class AbsEmbedder(ABC):
elif isinstance(devices, str):
return [devices]
elif isinstance(devices, int):
return [f"cuda:{devices}"]
if torch.musa.is_available():
return [f"musa:{devices}"]
else:
return [f"cuda:{devices}"]
elif isinstance(devices, list):
if isinstance(devices[0], str):
return devices
elif isinstance(devices[0], int):
return [f"cuda:{device}" for device in devices]
if torch.musa.is_available():
return [f"musa:{device}" for device in devices]
else:
return [f"cuda:{device}" for device in devices]
else:
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
else:

View File

@ -12,6 +12,11 @@ import numpy as np
from tqdm import tqdm, trange
from transformers import is_torch_npu_available
try:
import torch_musa
except Exception:
pass
logger = logging.getLogger(__name__)
@ -107,6 +112,8 @@ class AbsReranker(ABC):
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
elif is_torch_npu_available():
return [f"npu:{i}" for i in range(torch.npu.device_count())]
elif torch.musa.is_available():
return [f"musa:{i}" for i in range(torch.musa.device_count())]
elif torch.backends.mps.is_available():
return ["mps"]
else:
@ -114,12 +121,18 @@ class AbsReranker(ABC):
elif isinstance(devices, str):
return [devices]
elif isinstance(devices, int):
return [f"cuda:{devices}"]
if torch.musa.is_available():
return [f"musa:{devices}"]
else:
return [f"cuda:{devices}"]
elif isinstance(devices, list):
if isinstance(devices[0], str):
return devices
elif isinstance(devices[0], int):
return [f"cuda:{device}" for device in devices]
if torch.musa.is_available():
return [f"musa:{device}" for device in devices]
else:
return [f"cuda:{device}" for device in devices]
else:
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
else: