jbilcke-hf (HF Staff) committed f3c7aab · verified · 1 Parent(s): 8f50f73

Update wan/modules/attention.py

Files changed (1):
  1. wan/modules/attention.py +11 -2
wan/modules/attention.py CHANGED
@@ -2,22 +2,31 @@
 import torch
 
 try:
+    print("calling import flash_attn_interface")
     import flash_attn_interface
 
     def is_hopper_gpu():
+        print("is_hopper_gpu(): checking if not torch.cuda.is_available()")
         if not torch.cuda.is_available():
+            print("is_hopper_gpu(): turch.cuda is not available, so this is not Hopper GPU")
             return False
         device_name = torch.cuda.get_device_name(0).lower()
+        print(f"is_hopper_gpu(): device_name = {device_name}")
         return "h100" in device_name or "hopper" in device_name
     FLASH_ATTN_3_AVAILABLE = is_hopper_gpu()
-except ModuleNotFoundError:
+except ModuleNotFoundError as e:
+    print(f"Got a ModuleNotFoundError for Flash Attention 3: {e}")
     FLASH_ATTN_3_AVAILABLE = False
+print(f"FLASH_ATTN_3_AVAILABLE ? -> {FLASH_ATTN_3_AVAILABLE}")
 
 try:
+    print("calling import flash_attn")
     import flash_attn
     FLASH_ATTN_2_AVAILABLE = True
-except ModuleNotFoundError:
+except ModuleNotFoundError as e:
+    print(f"Got a ModuleNotFoundError for Flash Attention 2: {e}")
     FLASH_ATTN_2_AVAILABLE = False
+print(f"FLASH_ATTN_2_AVAILABLE ? -> {FLASH_ATTN_2_AVAILABLE}")
 
 # FLASH_ATTN_3_AVAILABLE = False
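
Side note (not part of this commit): the name-based check above only matches devices whose name contains "h100" or "hopper", so it can miss other Hopper parts such as H200 or H800. A minimal alternative sketch, using only standard PyTorch CUDA APIs, keys off the compute capability instead, since Hopper-generation GPUs report compute capability 9.x:

import torch

def is_hopper_gpu() -> bool:
    # Illustrative alternative to the substring match in the diff above:
    # Hopper-generation GPUs (H100, H200, H800) report CUDA compute capability 9.x.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(0)
    return major == 9

Either way, the resulting FLASH_ATTN_3_AVAILABLE / FLASH_ATTN_2_AVAILABLE flags are what the rest of attention.py (not shown in this diff) can branch on when picking an attention backend.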