pranjalchitale commited on
Commit
67ac308
1 Parent(s): 79e484a

Update modeling_indictrans.py

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +7 -5
modeling_indictrans.py CHANGED
@@ -54,11 +54,13 @@ logger = logging.get_logger(__name__)
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
- if is_flash_attn_2_available():
58
- from flash_attn import flash_attn_func, flash_attn_varlen_func
59
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
-
61
-
 
 
62
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
63
  def _get_unpad_data(attention_mask):
64
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
54
 
55
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
56
 
57
+ try:
58
+ if is_flash_attn_2_available():
59
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
60
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
61
+ except:
62
+ pass
63
+
64
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
65
  def _get_unpad_data(attention_mask):
66
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)