Update modeling_motif.py
Browse files
- modeling_motif.py +4 -4
modeling_motif.py
CHANGED
|
@@ -1032,9 +1032,9 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 1032 |
causal=causal)
|
| 1033 |
return attn_out
|
| 1034 |
else:
|
| 1035 |
-
attn_out = _flash_attention_forward(query_states,
|
| 1036 |
-
key_states,
|
| 1037 |
-
value_states,
|
| 1038 |
attention_mask,
|
| 1039 |
q_len,
|
| 1040 |
position_ids=position_ids,
|
|
@@ -1044,7 +1044,7 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 1044 |
softmax_scale=scale_factor,
|
| 1045 |
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
| 1046 |
#logger.info(attn_out)
|
| 1047 |
-
return attn_out
|
| 1048 |
|
| 1049 |
def forward(
|
| 1050 |
self,
|
|
|
|
| 1032 |
causal=causal)
|
| 1033 |
return attn_out
|
| 1034 |
else:
|
| 1035 |
+
attn_out = _flash_attention_forward(query_states.bfloat16(),
|
| 1036 |
+
key_states.bfloat16(),
|
| 1037 |
+
value_states.bfloat16(),
|
| 1038 |
attention_mask,
|
| 1039 |
q_len,
|
| 1040 |
position_ids=position_ids,
|
|
|
|
| 1044 |
softmax_scale=scale_factor,
|
| 1045 |
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
| 1046 |
#logger.info(attn_out)
|
| 1047 |
+
return attn_out.float()
|
| 1048 |
|
| 1049 |
def forward(
|
| 1050 |
self,
|