PIKA665 commited on
Commit
b457e16
·
verified ·
1 Parent(s): f228d5c

Update modular_openpangu_dense.py

Browse files
Files changed (1) hide show
  1. modular_openpangu_dense.py +7 -6
modular_openpangu_dense.py CHANGED
@@ -24,12 +24,13 @@ from typing import Callable, Optional, Tuple
24
  import torch
25
  from torch import nn
26
 
27
- import torch_npu
28
- from torch_npu.contrib import transfer_to_npu
29
- if "910" in torch.npu.get_device_name():
30
- NPU_ATTN_INFR = True
31
- print("[INFO] torch_npu detected. Using NPU fused infer attention.")
32
- else:
 
33
  NPU_ATTN_INFR = False
34
 
35
  from transformers.cache_utils import Cache
 
24
  import torch
25
  from torch import nn
26
 
27
+ try:
28
+ import torch_npu
29
+ from torch_npu.contrib import transfer_to_npu
30
+ if "910" in torch.npu.get_device_name():
31
+ NPU_ATTN_INFR = True
32
+ print("[INFO] torch_npu detected. Using NPU fused infer attention.")
33
+ except ImportError:
34
  NPU_ATTN_INFR = False
35
 
36
  from transformers.cache_utils import Cache