leejunhyeok commited on
Commit
6d0fba5
·
verified ·
1 Parent(s): 9be8a4a

Update modeling_motif.py

Browse files
Files changed (1) hide show
  1. modeling_motif.py +4 -4
modeling_motif.py CHANGED
@@ -1032,9 +1032,9 @@ class MotifFlashAttention2(MotifAttention):
1032
  causal=causal)
1033
  return attn_out
1034
  else:
1035
- attn_out = _flash_attention_forward(query_states,
1036
- key_states,
1037
- value_states,
1038
  attention_mask,
1039
  q_len,
1040
  position_ids=position_ids,
@@ -1044,7 +1044,7 @@ class MotifFlashAttention2(MotifAttention):
1044
  softmax_scale=scale_factor,
1045
  use_top_left_mask=self._flash_attn_uses_top_left_mask)
1046
  #logger.info(attn_out)
1047
- return attn_out
1048
 
1049
  def forward(
1050
  self,
 
1032
  causal=causal)
1033
  return attn_out
1034
  else:
1035
+ attn_out = _flash_attention_forward(query_states.bfloat16(),
1036
+ key_states.bfloat16(),
1037
+ value_states.bfloat16(),
1038
  attention_mask,
1039
  q_len,
1040
  position_ids=position_ids,
 
1044
  softmax_scale=scale_factor,
1045
  use_top_left_mask=self._flash_attn_uses_top_left_mask)
1046
  #logger.info(attn_out)
1047
+ return attn_out.float()
1048
 
1049
  def forward(
1050
  self,