leejunhyeok commited on
Commit
bdd0329
·
verified ·
1 Parent(s): 9b40539

Update modeling_motif.py

Browse files
Files changed (1) hide show
  1. modeling_motif.py +23 -29
modeling_motif.py CHANGED
@@ -35,11 +35,7 @@ logger = logging.get_logger(__name__)
35
  if is_flash_attn_2_available():
36
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
 
38
- MorehRMSNorm = None
39
- ScaledDotProductAttention = None
40
- MorehFlashAttention = None
41
 
42
- #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
43
  _CONFIG_FOR_DOC = "MotifConfig"
44
 
45
  from transformers.activations import ACT2CLS as _ACT2CLS
@@ -538,28 +534,27 @@ class MotifFlashAttention2(MotifAttention):
538
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
539
 
540
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
541
- dropout_rate, sliding_window, is_moreh_attention, batch_num):
542
  """Flash Attention 2 implements"""
543
- if is_moreh_attention:
544
- scale_factor = 1.0 / math.sqrt(self.head_dim)
545
- # Copied from _flash_attention_forward
546
- if not self._flash_attn_uses_top_left_mask:
547
- causal = self.is_causal
548
- else:
549
- causal = self.is_causal and q_len != 1
550
 
551
- bsz = query_states.shape[0]
552
 
553
- return _flash_attention_forward(query_states,
554
- key_states,
555
- value_states,
556
- attention_mask,
557
- q_len,
558
- position_ids=position_ids,
559
- dropout=dropout_rate,
560
- sliding_window=sliding_window,
561
- is_causal=self.is_causal,
562
- use_top_left_mask=self._flash_attn_uses_top_left_mask)
563
 
564
  def forward(
565
  self,
@@ -660,13 +655,12 @@ class MotifFlashAttention2(MotifAttention):
660
  k1, k2 = k1.contiguous(), k2.contiguous()
661
  v1, v2 = v1.contiguous(), v2.contiguous()
662
 
663
- is_moreh_attention = MorehFlashAttention is not None
664
-
665
- attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
666
- self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
667
- attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
668
- self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
669
 
 
670
  attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
671
 
672
  lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
 
35
  if is_flash_attn_2_available():
36
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
 
 
 
 
38
 
 
39
  _CONFIG_FOR_DOC = "MotifConfig"
40
 
41
  from transformers.activations import ACT2CLS as _ACT2CLS
 
534
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
535
 
536
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
537
+ dropout_rate, sliding_window, batch_num):
538
  """Flash Attention 2 implements"""
539
+ scale_factor = 1.0 / math.sqrt(self.head_dim)
540
+ # Copied from _flash_attention_forward
541
+ if not self._flash_attn_uses_top_left_mask:
542
+ causal = self.is_causal
543
+ else:
544
+ causal = self.is_causal and q_len != 1
 
545
 
546
+ bsz = query_states.shape[0]
547
 
548
+ return _flash_attention_forward(query_states,
549
+ key_states,
550
+ value_states,
551
+ attention_mask,
552
+ q_len,
553
+ position_ids=position_ids,
554
+ dropout=dropout_rate,
555
+ sliding_window=sliding_window,
556
+ is_causal=self.is_causal,
557
+ use_top_left_mask=self._flash_attn_uses_top_left_mask)
558
 
559
  def forward(
560
  self,
 
655
  k1, k2 = k1.contiguous(), k2.contiguous()
656
  v1, v2 = v1.contiguous(), v2.contiguous()
657
 
658
+ attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
659
+ self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
660
+ attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
661
+ self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
 
 
662
 
663
+
664
  attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
665
 
666
  lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head