whoisjones committed on
Commit
afec47f
·
verified ·
1 Parent(s): d8aee93

Upload loss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. loss.py +94 -0
loss.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class BCELoss(nn.Module):
    """Binary cross-entropy on logits, reduced to a single scaled mean.

    Args:
        scale: Constant multiplied into the reduced loss. Defaults to 100.0,
            the factor that used to be hard-coded in ``forward``.
    """

    def __init__(self, scale: float = 100.0):
        super().__init__()
        self.scale = scale

    def forward(self, logits, labels, mask=None, pos_weight=None, **kwargs):
        """Return the scaled mean of the element-wise BCE-with-logits loss.

        Args:
            logits: Raw (pre-sigmoid) predictions.
            labels: Targets, same shape as ``logits``.
            mask: Optional multiplicative weights applied element-wise to the
                unreduced loss. Masked-out elements are zeroed but still
                counted in the denominator of the mean.
            pos_weight: Forwarded unchanged to
                ``F.binary_cross_entropy_with_logits``.
            **kwargs: Accepted and ignored.

        Returns:
            A scalar tensor: ``mean(loss * mask) * scale``.
        """
        loss = F.binary_cross_entropy_with_logits(
            logits,
            labels,
            reduction="none",
            pos_weight=pos_weight,
        )
        if mask is not None:
            loss = loss * mask
        return loss.mean() * self.scale
19
+
20
+
21
class FocalLoss(nn.Module):
    """Sigmoid focal loss on logits, reduced to a single scaled mean.

    Args:
        alpha: Weight given to positive targets (negatives get ``1 - alpha``).
            Must lie in ``[0, 1]``, or be ``-1`` to disable alpha weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t) ** gamma``.
        scale: Constant multiplied into the reduced loss. Defaults to 100.0,
            the factor that used to be hard-coded in ``forward``.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` and is not ``-1``.
    """

    def __init__(self, alpha=0.5, gamma=1.0, scale=100.0):
        super().__init__()
        # Fail fast on a bad configuration instead of at the first forward pass.
        self._check_alpha(alpha)
        self.alpha = alpha
        self.gamma = gamma
        self.scale = scale

    @staticmethod
    def _check_alpha(alpha):
        """Raise ``ValueError`` unless ``alpha`` is in ``[0, 1]`` or equals ``-1``."""
        if not (0 <= alpha <= 1) and alpha != -1:
            raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")

    def forward(self, logits, labels, mask=None, pos_weight=None, **kwargs):
        """Return the scaled mean of the element-wise focal loss.

        Args:
            logits: Raw (pre-sigmoid) predictions.
            labels: Targets, same shape as ``logits``.
            mask: Optional multiplicative weights applied element-wise to the
                unreduced loss. Masked-out elements are zeroed but still
                counted in the denominator of the mean.
            pos_weight: Forwarded unchanged to
                ``F.binary_cross_entropy_with_logits``.
            **kwargs: Accepted and ignored.

        Returns:
            A scalar tensor: ``mean(focal_loss * mask) * scale``.

        Raises:
            ValueError: If ``self.alpha`` is outside ``[0, 1]`` and not ``-1``.
        """
        # ``alpha`` is a plain attribute and may have been reassigned since
        # construction, so it is re-validated on every call.
        self._check_alpha(self.alpha)

        p = torch.sigmoid(logits)
        ce_loss = F.binary_cross_entropy_with_logits(
            logits, labels, reduction="none", pos_weight=pos_weight
        )
        # Probability the model assigns to the target class of each element.
        p_t = p * labels + (1 - p) * (1 - labels)
        loss = ce_loss * ((1 - p_t) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            loss = alpha_t * loss

        if mask is not None:
            loss = loss * mask
        return loss.mean() * self.scale
46
+
47
+
48
class ContrastiveLoss(nn.Module):
    """Negative mean log-probability of target positions under a masked softmax.

    Args:
        tau: Temperature; scores are divided by it before the softmax.
    """

    def __init__(self, tau: float = 1.0):
        super().__init__()
        self.tau = tau

    def forward(
        self,
        scores: torch.Tensor,
        positions: "torch.Tensor | list | tuple",
        mask: "torch.Tensor | None",
        prob_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            scores: Either a 2-D tensor, soft-maxed over its last dimension,
                or a 3-D tensor whose last two dimensions are flattened and
                soft-maxed jointly. Dimension 0 is the batch.
            positions: For 2-D scores, one target index per batch row. For
                3-D scores, a pair ``(first_indices, second_indices)`` that
                indexes dimensions 1 and 2 respectively.
            mask: Multiplicative mask over ``scores`` (zero marks excluded
                entries), or ``None`` for no masking. For 3-D scores it is
                flattened the same way as ``scores``.
            prob_mask: Optional weights multiplied into the selected
                log-probabilities before averaging.

        Returns:
            A scalar tensor: minus the mean of the selected log-probabilities.
        """
        batch_size = scores.size(0)
        scores = scores / self.tau
        batch_indices = list(range(batch_size))

        if scores.dim() == 3:
            grid_shape = scores.shape
            # ``reshape`` rather than ``view``: the scaled scores can inherit
            # a non-contiguous layout from the caller's tensor.
            flat_scores = scores.reshape(batch_size, -1)
            flat_mask = None if mask is None else mask.reshape(batch_size, -1)
            log_probs = self.masked_log_softmax(flat_scores, flat_mask)
            # Restore the original grid; it does not have to be square.
            log_probs = log_probs.reshape(grid_shape)
            start_positions, end_positions = positions
            log_probs = log_probs[batch_indices, start_positions, end_positions]
        else:
            log_probs = self.masked_log_softmax(scores, mask)
            log_probs = log_probs[batch_indices, positions]

        if prob_mask is not None:
            log_probs = log_probs * prob_mask
        return -log_probs.mean()

    def masked_log_softmax(
        self,
        vector: torch.Tensor,
        mask: "torch.Tensor | None",
        dim: int = -1,
    ) -> torch.Tensor:
        """Log-softmax of ``vector`` along ``dim`` with masked entries suppressed.

        Where ``mask`` is zero, the log of a tiny constant is added to
        ``vector``, giving those entries a very large negative score instead
        of ``-inf``. A ``None`` mask gives a plain log-softmax.
        """
        if mask is not None:
            # Insert singleton dimensions at index 1 until the ranks match.
            while mask.dim() < vector.dim():
                mask = mask.unsqueeze(1)
            vector = vector + (mask + self.tiny_value_of_dtype(vector.dtype)).log()
        return F.log_softmax(vector, dim=dim)

    def tiny_value_of_dtype(self, dtype: torch.dtype) -> float:
        """Return a small positive constant suited to the floating ``dtype``.

        Raises:
            TypeError: If ``dtype`` is not floating point or is not one of
                float32, float64, bfloat16 or float16.
        """
        if not dtype.is_floating_point:
            raise TypeError("Only supports floating point dtypes.")
        if dtype in (torch.float, torch.double, torch.bfloat16):
            return 1e-13
        if dtype == torch.half:
            return 1e-4
        raise TypeError("Does not support dtype " + str(dtype))