pushkin05 committed on
Commit 5490ded · verified · 1 Parent(s): 8b739b8
Files changed (12)
  1. .gitattributes +10 -0
  2. losses.py +103 -0
  3. step_144783 +3 -0
  4. step_217174 +3 -0
  5. step_289565 +3 -0
  6. step_361957 +3 -0
  7. step_434348 +3 -0
  8. step_506739 +3 -0
  9. step_579130 +3 -0
  10. step_651522 +3 -0
  11. step_72391 +3 -0
  12. step_723914 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ step_144783 filter=lfs diff=lfs merge=lfs -text
+ step_217174 filter=lfs diff=lfs merge=lfs -text
+ step_289565 filter=lfs diff=lfs merge=lfs -text
+ step_361957 filter=lfs diff=lfs merge=lfs -text
+ step_434348 filter=lfs diff=lfs merge=lfs -text
+ step_506739 filter=lfs diff=lfs merge=lfs -text
+ step_579130 filter=lfs diff=lfs merge=lfs -text
+ step_651522 filter=lfs diff=lfs merge=lfs -text
+ step_72391 filter=lfs diff=lfs merge=lfs -text
+ step_723914 filter=lfs diff=lfs merge=lfs -text
losses.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Any, Tuple, Dict, Sequence, Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import math
+
+ IGNORE_LABEL_ID = -100
+
+
+ def s(x, epsilon=1e-30):
+     return torch.where(
+         x < 0,
+         1 / (1 - x + epsilon),
+         x + 1
+     )
+
+
+ def log_stablemax(x, dim=-1):
+     s_x = s(x)
+     return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
+
+
+ def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
+     logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+     if valid_mask is None:
+         valid_mask = (labels != ignore_index)
+     transformed_labels = torch.where(valid_mask, labels, 0)
+     prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+     return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+ def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+     # Cast logits to f32
+     # Flatten logits
+     return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+ class ACTLossHead(nn.Module):
+     def __init__(self, model: nn.Module, loss_type: str):
+         super().__init__()
+         self.model = model
+         self.loss_fn = globals()[loss_type]
+
+     def initial_carry(self, *args, **kwargs):
+         return self.model.initial_carry(*args, **kwargs)  # type: ignore
+
+     def forward(
+         self,
+         return_keys: Sequence[str],
+         # Model args
+         **model_kwargs,
+     ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+         # Model logits
+         # B x SeqLen x D
+         new_carry, outputs = self.model(**model_kwargs)
+         labels = new_carry.current_data["labels"]
+
+         with torch.no_grad():
+             # Preds
+             outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)
+
+             # Correctness
+             mask = (labels != IGNORE_LABEL_ID)
+             loss_counts = mask.sum(-1)
+             loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division
+
+             is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+             seq_is_correct = is_correct.sum(-1) == loss_counts
+
+             # Metrics (halted)
+             valid_metrics = new_carry.halted & (loss_counts > 0)
+             metrics = {
+                 "count": valid_metrics.sum(),
+
+                 "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+                 "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+                 "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+                 "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+             }
+
+         # Losses
+
+         lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
+         q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+         metrics.update({
+             "lm_loss": lm_loss.detach(),
+             "q_halt_loss": q_halt_loss.detach(),
+         })
+         # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
+         q_continue_loss = 0
+         if "target_q_continue" in outputs:
+             q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+             metrics["q_continue_loss"] = q_continue_loss.detach()
+         # Filter outputs for return
+         detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+         return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
+
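For orientation, here is a minimal usage sketch of the two loss functions added in losses.py. The toy shapes and label values are illustrative only (not from the commit); it assumes nothing beyond the definitions above.

import torch

from losses import IGNORE_LABEL_ID, softmax_cross_entropy, stablemax_cross_entropy

# Toy batch: 2 sequences of length 4 over a 5-token vocabulary.
logits = torch.randn(2, 4, 5)
labels = torch.tensor([[1, 3, 0, IGNORE_LABEL_ID],
                       [2, 2, IGNORE_LABEL_ID, IGNORE_LABEL_ID]])

# Both functions return per-token losses with the same shape as `labels`;
# positions equal to IGNORE_LABEL_ID contribute zero loss.
stable = stablemax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)
soft = softmax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)

# Average over valid tokens only, mirroring the per-sequence loss_divisor in ACTLossHead.
mask = labels != IGNORE_LABEL_ID
print(stable[mask].mean().item(), soft[mask].mean().item())

ACTLossHead selects one of these functions by name (via globals()[loss_type]) and combines the resulting LM loss with the Q-halt and Q-continue BCE terms, weighted by 0.5, as shown in the return statement above.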
step_144783 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9dcc0b0bf3fb7d64c567b59df6ac196706d224b6498f379df945ac3ed6905a9b
+ size 2467988405
step_217174 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af2831ec79ccef44854b775ec65adc92eaa05c58a41fa2aabf9f6c5117c05a41
+ size 2467988405
step_289565 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26dc32b7d05c8da50de0eb489d4e5c72e6348f78bfbedbfad347c80f5f95507f
+ size 2467988405
step_361957 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:04083d75970bbdaf4f6649ba7fa3d09136eeebe955149a4e371e24f4e453be7f
+ size 2467988405
step_434348 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d371fc16816ea73db48eaade70290375feec060e98f211280da69d3b29792ad9
+ size 2467988405
step_506739 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:efa317c0d741402096f58d0833ab42f619630929f7aebb7cc3e8bbda3899e586
+ size 2467988405
step_579130 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c9b2be79abe8123de922fc34266e52324220dda4ab8addfc88652fa2188a04e
+ size 2467988405
step_651522 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0cf3dfcbaa96d69867db31ef68c7d5015fdf40cad956c973ab34725b10a3fb4b
+ size 2467988405
step_72391 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27025e2f399c310ba22bc3711396c0d67e664e245ebd515737aca7e110fbdc40
+ size 2467988386
step_723914 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15a35887067ad27145bf8e8551b0c0deb30c12b19c7087082556e5358d5d0f78
+ size 2467988405
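The step_* entries above are Git LFS pointer stubs; the real binaries are roughly 2.5 GB each. Assuming they are PyTorch checkpoints written with torch.save (an assumption; the commit itself does not show how they were produced), a minimal loading sketch might look like the following, where the repo_id is a hypothetical placeholder:

import torch
from huggingface_hub import hf_hub_download

# Hypothetical repo_id; substitute the actual repository this commit belongs to.
ckpt_path = hf_hub_download(repo_id="pushkin05/<repo-name>", filename="step_72391")

# Assumption: the file is a torch.save'd object (e.g. a state dict or a dict with
# model/optimizer state); inspect its top-level keys to confirm the layout.
state = torch.load(ckpt_path, map_location="cpu")
if isinstance(state, dict):
    print(list(state.keys())[:10])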