Update modeling_motif.py
Browse files- modeling_motif.py +2 -2
modeling_motif.py
CHANGED
|
@@ -493,7 +493,7 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 493 |
def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
|
| 494 |
dropout_rate, sliding_window):
|
| 495 |
"""Flash Attention 2 implements"""
|
| 496 |
-
|
| 497 |
scale_factor = 1.0 / math.sqrt(self.head_dim)
|
| 498 |
if not self._flash_attn_uses_top_left_mask:
|
| 499 |
causal = self.is_causal
|
|
@@ -511,7 +511,7 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 511 |
is_causal=True,
|
| 512 |
softmax_scale=scale_factor,
|
| 513 |
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
| 514 |
-
return attn_out.
|
| 515 |
|
| 516 |
def forward(
|
| 517 |
self,
|
|
|
|
| 493 |
def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
|
| 494 |
dropout_rate, sliding_window):
|
| 495 |
"""Flash Attention 2 implements"""
|
| 496 |
+
_input_type = query_states.dtype
|
| 497 |
scale_factor = 1.0 / math.sqrt(self.head_dim)
|
| 498 |
if not self._flash_attn_uses_top_left_mask:
|
| 499 |
causal = self.is_causal
|
|
|
|
| 511 |
is_causal=True,
|
| 512 |
softmax_scale=scale_factor,
|
| 513 |
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
| 514 |
+
return attn_out.to(_input_type)
|
| 515 |
|
| 516 |
def forward(
|
| 517 |
self,
|