Sofia Casadei commited on
Commit
aacc5eb
Β·
1 Parent(s): 5c44b80
Files changed (1) hide show
  1. main.py +2 -1
main.py CHANGED
@@ -44,7 +44,8 @@ LANGUAGE = os.getenv("LANGUAGE", "english").lower()
44
  device = get_device(force_cpu=False)
45
  use_device_map = True if device == "cuda" else False
46
  try_compile_model = True if device == "cuda" or (device == "mps" and torch.__version__ >= "2.7.0") else False
47
- try_use_flash_attention = True if device == "cuda" and is_flash_attn_2_available() else False
 
48
 
49
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
50
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
 
44
  device = get_device(force_cpu=False)
45
  use_device_map = True if device == "cuda" else False
46
  try_compile_model = True if device == "cuda" or (device == "mps" and torch.__version__ >= "2.7.0") else False
47
+ try_use_flash_attention = False
48
+ #try_use_flash_attention = True if device == "cuda" and is_flash_attn_2_available() else False
49
 
50
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
51
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")