Update modeling_motif.py
Browse files- modeling_motif.py +23 -29
modeling_motif.py
CHANGED
|
@@ -35,11 +35,7 @@ logger = logging.get_logger(__name__)
|
|
| 35 |
if is_flash_attn_2_available():
|
| 36 |
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 37 |
|
| 38 |
-
MorehRMSNorm = None
|
| 39 |
-
ScaledDotProductAttention = None
|
| 40 |
-
MorehFlashAttention = None
|
| 41 |
|
| 42 |
-
#_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
|
| 43 |
_CONFIG_FOR_DOC = "MotifConfig"
|
| 44 |
|
| 45 |
from transformers.activations import ACT2CLS as _ACT2CLS
|
|
@@ -538,28 +534,27 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 538 |
return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 539 |
|
| 540 |
def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
|
| 541 |
-
dropout_rate, sliding_window,
|
| 542 |
"""Flash Attention 2 implements"""
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
causal = self.is_causal and q_len != 1
|
| 550 |
|
| 551 |
-
|
| 552 |
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
|
| 564 |
def forward(
|
| 565 |
self,
|
|
@@ -660,13 +655,12 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 660 |
k1, k2 = k1.contiguous(), k2.contiguous()
|
| 661 |
v1, v2 = v1.contiguous(), v2.contiguous()
|
| 662 |
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
self._compute_attention(
|
| 667 |
-
attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
|
| 668 |
-
self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
|
| 669 |
|
|
|
|
| 670 |
attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
|
| 671 |
|
| 672 |
lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
|
|
|
|
| 35 |
if is_flash_attn_2_available():
|
| 36 |
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
|
|
|
|
| 39 |
_CONFIG_FOR_DOC = "MotifConfig"
|
| 40 |
|
| 41 |
from transformers.activations import ACT2CLS as _ACT2CLS
|
|
|
|
| 534 |
return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 535 |
|
| 536 |
def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
                       dropout_rate, sliding_window, batch_num):
    """Run a single attention pass through the Flash Attention 2 forward helper.

    Args:
        query_states: Query tensor, forwarded unchanged to the helper.
        key_states: Key tensor, forwarded unchanged to the helper.
        value_states: Value tensor, forwarded unchanged to the helper.
        attention_mask: Attention mask, forwarded unchanged to the helper.
        q_len: Query length, forwarded as the helper's fifth positional argument.
        position_ids: Forwarded as the helper's ``position_ids`` keyword.
        dropout_rate: Forwarded as the helper's ``dropout`` keyword.
        sliding_window: Forwarded as the helper's ``sliding_window`` keyword.
        batch_num: Not used by this method; kept in the signature so that
            existing call sites keep working.

    Returns:
        Whatever ``_flash_attention_forward`` returns for these inputs.
    """
    # The previous body also computed a softmax scale, a local ``causal`` flag
    # and the batch size, but none of them reached the call below, so they are
    # dropped. The ``causal`` branch was marked as copied from
    # ``_flash_attention_forward``; that helper receives both ``is_causal`` and
    # ``use_top_left_mask`` here, so the top-left-mask adjustment is presumably
    # applied inside it -- NOTE(review): confirm against the installed
    # transformers version.
    return _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )
|
| 558 |
|
| 559 |
def forward(
|
| 560 |
self,
|
|
|
|
| 655 |
k1, k2 = k1.contiguous(), k2.contiguous()
|
| 656 |
v1, v2 = v1.contiguous(), v2.contiguous()
|
| 657 |
|
| 658 |
+
attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
|
| 659 |
+
self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
|
| 660 |
+
attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
|
| 661 |
+
self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
|
|
|
|
|
|
|
| 662 |
|
| 663 |
+
|
| 664 |
attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
|
| 665 |
|
| 666 |
lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
|