1st
Browse files- .gitattributes +10 -0
- losses.py +103 -0
- step_144783 +3 -0
- step_217174 +3 -0
- step_289565 +3 -0
- step_361957 +3 -0
- step_434348 +3 -0
- step_506739 +3 -0
- step_579130 +3 -0
- step_651522 +3 -0
- step_72391 +3 -0
- step_723914 +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
step_144783 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
step_217174 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
step_289565 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
step_361957 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
step_434348 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
step_506739 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
step_579130 filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
step_651522 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
step_72391 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
step_723914 filter=lfs diff=lfs merge=lfs -text
|
losses.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Tuple, Dict, Sequence, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
IGNORE_LABEL_ID = -100
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def s(x, epsilon=1e-30):
    """Element-wise stablemax transform.

    Computes ``1 / (1 - x + epsilon)`` where ``x < 0`` and ``x + 1`` elsewhere.

    Args:
        x: Input tensor.
        epsilon: Small constant added to the denominator of the negative branch.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # torch.where evaluates both branches for every element. Without the clamp,
    # the reciprocal is also computed for non-negative inputs, where its
    # denominator can be ~epsilon (x == 1); the squared reciprocal in the
    # backward pass then overflows and 0 * inf poisons the gradient with NaN.
    # Clamping leaves every x < 0 untouched, so selected values are unchanged.
    negative_part = torch.clamp(x, max=0)
    return torch.where(
        x < 0,
        1 / (1 - negative_part + epsilon),
        x + 1,
    )
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def log_stablemax(x, dim=-1):
    """Return the log of ``s(x)`` normalised to sum to one along ``dim``.

    Args:
        x: Input tensor.
        dim: Dimension over which the transformed values are normalised.

    Returns:
        A tensor shaped like ``x`` holding ``log(s(x) / sum(s(x), dim))``.
    """
    transformed = s(x)
    normaliser = transformed.sum(dim=dim, keepdim=True)
    return (transformed / normaliser).log()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    """Per-position cross entropy using stablemax log-probabilities.

    Args:
        logits: Class scores; the last dimension is gathered by label, so it is
            presumably the class dimension (confirm against the caller).
        labels: Target class indices, one fewer dimension than ``logits``.
        ignore_index: Label value treated as "no target" when ``valid_mask``
            is not supplied.
        valid_mask: Optional boolean tensor marking positions that contribute
            to the loss. Defaults to ``labels != ignore_index``.

    Returns:
        Unreduced negative log-probabilities (computed in float64), with zeros
        at positions that are not valid.
    """
    log_probs = log_stablemax(logits.to(torch.float64), dim=-1)

    keep = (labels != ignore_index) if valid_mask is None else valid_mask

    # Invalid positions are pointed at class 0 so the gather index stays in
    # range; their contribution is zeroed out again below.
    gather_index = torch.where(keep, labels, 0).to(torch.long).unsqueeze(-1)
    picked = log_probs.gather(dim=-1, index=gather_index).squeeze(-1)

    return -torch.where(keep, picked, 0)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    """Per-position softmax cross entropy, computed in float32.

    Args:
        logits: Class scores; the last dimension is the class dimension.
        labels: Target class indices, shaped like ``logits`` without its last
            dimension.
        ignore_index: Label value whose positions yield a loss of zero.
        valid_mask: Optional boolean tensor (broadcastable to ``labels``).
            Positions where it is False are treated as ignored. Accepted so
            this function can be called with the same keyword arguments as
            ``stablemax_cross_entropy``.

    Returns:
        Unreduced losses with the same shape as ``labels``.
    """
    targets = labels.to(torch.long)
    if valid_mask is not None:
        # Fold the mask into ignore_index so F.cross_entropy zeroes those positions.
        targets = targets.masked_fill(~valid_mask.to(torch.bool), ignore_index)

    # Cast logits to f32 and flatten to (N, num_classes); reshape (rather than
    # view) also accepts non-contiguous inputs.
    flat_logits = logits.to(torch.float32).reshape(-1, logits.shape[-1])
    flat_loss = F.cross_entropy(
        flat_logits,
        targets.reshape(-1),
        ignore_index=ignore_index,
        reduction="none",
    )
    return flat_loss.reshape(labels.shape)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ACTLossHead(nn.Module):
    """Wraps a model and computes its training loss and summed batch metrics.

    The wrapped model is called with the keyword arguments given to
    ``forward`` and must return ``(new_carry, outputs)``. ``forward`` reads
    ``new_carry.current_data["labels"]``, ``new_carry.halted``,
    ``new_carry.steps``, ``outputs["logits"]`` and ``outputs["q_halt_logits"]``
    (plus ``outputs["q_continue_logits"]`` / ``outputs["target_q_continue"]``
    when the latter is present).
    """

    def __init__(self, model: nn.Module, loss_type: str):
        """Store the model and resolve the loss function by name.

        Args:
            model: The wrapped model.
            loss_type: Name of a module-level loss function in this file
                (for example ``"stablemax_cross_entropy"``).

        Raises:
            KeyError: If ``loss_type`` is not a name defined in this module.
        """
        super().__init__()
        self.model = model
        try:
            self.loss_fn = globals()[loss_type]
        except KeyError:
            raise KeyError(f"Unknown loss_type: {loss_type!r}") from None

    def initial_carry(self, *args, **kwargs):
        """Forward to the wrapped model's ``initial_carry``."""
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        """Run the model and compute the loss and metrics.

        Args:
            return_keys: Keys of ``outputs`` to return (detached); keys missing
                from ``outputs`` are skipped.
            **model_kwargs: Passed through to the wrapped model.

        Returns:
            A tuple ``(new_carry, loss, metrics, detached_outputs, all_halted)``
            where ``loss`` is ``lm_loss + 0.5 * (q_halt_loss + q_continue_loss)``,
            ``metrics`` holds sums (not means) over the batch, and
            ``all_halted`` is ``new_carry.halted.all()``.
        """
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            preds = torch.argmax(outputs["logits"], dim=-1)
            outputs["preds"] = preds

            # Correctness
            mask = (labels != IGNORE_LABEL_ID)
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            # Reuse the predictions computed above instead of a second argmax.
            is_correct = mask & (preds == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted): only sequences that have halted and carry at
            # least one labelled position are counted.
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses

        # Each sequence's per-position loss is divided by its number of
        # labelled positions before summing over the whole batch.
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })
        # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            # Recorded only when the loss exists; the plain int 0 default has no .detach().
            metrics["q_continue_loss"] = q_continue_loss.detach()
        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
|
| 103 |
+
|
step_144783
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9dcc0b0bf3fb7d64c567b59df6ac196706d224b6498f379df945ac3ed6905a9b
|
| 3 |
+
size 2467988405
|
step_217174
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af2831ec79ccef44854b775ec65adc92eaa05c58a41fa2aabf9f6c5117c05a41
|
| 3 |
+
size 2467988405
|
step_289565
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:26dc32b7d05c8da50de0eb489d4e5c72e6348f78bfbedbfad347c80f5f95507f
|
| 3 |
+
size 2467988405
|
step_361957
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:04083d75970bbdaf4f6649ba7fa3d09136eeebe955149a4e371e24f4e453be7f
|
| 3 |
+
size 2467988405
|
step_434348
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d371fc16816ea73db48eaade70290375feec060e98f211280da69d3b29792ad9
|
| 3 |
+
size 2467988405
|
step_506739
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efa317c0d741402096f58d0833ab42f619630929f7aebb7cc3e8bbda3899e586
|
| 3 |
+
size 2467988405
|
step_579130
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c9b2be79abe8123de922fc34266e52324220dda4ab8addfc88652fa2188a04e
|
| 3 |
+
size 2467988405
|
step_651522
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cf3dfcbaa96d69867db31ef68c7d5015fdf40cad956c973ab34725b10a3fb4b
|
| 3 |
+
size 2467988405
|
step_72391
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:27025e2f399c310ba22bc3711396c0d67e664e245ebd515737aca7e110fbdc40
|
| 3 |
+
size 2467988386
|
step_723914
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:15a35887067ad27145bf8e8551b0c0deb30c12b19c7087082556e5358d5d0f78
|
| 3 |
+
size 2467988405
|