update npu api (#9688)

This commit is contained in:
duanyanhui 2023-04-11 09:56:08 +08:00 committed by GitHub
parent 6d44c67848
commit 2d6f3a56a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -134,9 +134,18 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
if use_xpu and not paddle.device.is_compiled_with_xpu():
print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
sys.exit(1)
if use_npu and not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
if use_npu:
if int(paddle.version.major) != 0 and int(
paddle.version.major) <= 2 and int(
paddle.version.minor) <= 4:
if not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
# is_compiled_with_npu() has been updated after paddle-2.4
else:
if not paddle.device.is_compiled_with_custom_device("npu"):
print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1)
if use_mlu and not paddle.device.is_compiled_with_mlu():
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
sys.exit(1)