from typing import Optional, Tuple
import copy

import torch
import torch.nn as nn

from open_flamingo.src.helpers import PerceiverResampler
from robot_flamingo.models.normalizer import LinearNormalizer
from robot_flamingo.models.trajectory_gpt2 import get_gpt_model
# assumed to provide ConditionalUnet1D, cosine_beta_schedule, extract,
# Progress, and Silent, all used by DiffusionDecoder below
from .unets import *


def lstm_decoder(
    in_features: int, hidden_size: int, num_layers: int, policy_rnn_dropout_p: float
) -> torch.nn.Module:
    return nn.LSTM(
        input_size=in_features,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bidirectional=False,
        batch_first=True,
        dropout=policy_rnn_dropout_p,
    )
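
# Illustrative usage (the sizes here are assumptions, not repo defaults):
#   rnn = lstm_decoder(1024, 1024, num_layers=4, policy_rnn_dropout_p=0.1)
#   out, (h_n, c_n) = rnn(torch.randn(8, 12, 1024))  # out: (8, 12, 1024)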


class MLPTanhHead(torch.nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_size),
            torch.nn.Tanh(),
        )

    def forward(self, x):
        return self.mlp(x)


class MLPNohHead(torch.nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_size),
        )

    def forward(self, x):
        return self.mlp(x)


class MLPSigmoidHead(torch.nn.Module):
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, output_size),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.mlp(x)
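
# Shape note: each head maps (..., hidden_size) -> (..., output_size).
# Tanh bounds the pose outputs to [-1, 1]; Sigmoid bounds the gripper
# output to [0, 1]; MLPNohHead leaves its output unbounded.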


class MLPActionHead(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # separate MLPs for the continuous pose action and the binary gripper action
        self.num_head = nn.Sequential(
            nn.Linear(hidden_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 6),
        )
        self.bin_head = nn.Sequential(
            nn.Linear(hidden_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        x = x[:, -1]  # pick the last frame's output
        x1 = self.num_head(x)
        x2 = self.bin_head(x).sigmoid()
        return x1, x2
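

# ActionDecoder defines the interface shared by the policy heads below:
# act() for rollout, loss() for training, loss_and_act() for both, and
# clear_hidden_state() to reset any recurrent state between episodes.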
class ActionDecoder(nn.Module):
    def act(
        self,
        latent_plan: torch.Tensor,
        perceptual_emb: torch.Tensor,
        latent_goal: torch.Tensor,
        robot_obs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def loss(
        self,
        latent_plan: torch.Tensor,
        perceptual_emb: torch.Tensor,
        latent_goal: torch.Tensor,
        actions: torch.Tensor,
        robot_obs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def loss_and_act(
        self,
        latent_plan: torch.Tensor,
        perceptual_emb: torch.Tensor,
        latent_goal: torch.Tensor,
        actions: torch.Tensor,
        robot_obs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def _sample(self, *args, **kwargs):
        raise NotImplementedError

    def forward(
        self,
        latent_plan: torch.Tensor,
        perceptual_emb: torch.Tensor,
        latent_goal: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def clear_hidden_state(self) -> None:
        pass


class FCDecoder(ActionDecoder):
    def __init__(
        self,
        in_features: int,
        window_size: int,
        history_len=None,
        out_features: int = 6,
        hidden_size: int = 1024,
        num_layers: int = 4,
        policy_rnn_dropout_p: float = 0.1,
        use_diff=False,
        last_action=False,
        fusion_mode='',
        use_state=False,
        return_feature=False,
        multi_step_action=1,
    ):
        super().__init__()
        self.return_feature = return_feature
        if use_state:
            state_in_dim = 7
            state_out_dim = 128
            self.fc_state = MLPNohHead(state_in_dim, state_out_dim)
            in_features += state_out_dim
        if fusion_mode == 'two_way':
            in_features *= 2
        self.in_features = in_features
        self.out_features = out_features
        self.window_size = window_size
        self.multi_step_action = multi_step_action
        if history_len is None:
            history_len = window_size
        self.history_len = history_len
        self.history_memory = []
        self.use_diff = use_diff
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features // 2, hidden_size),
        )
        if not use_diff:
            self.actions = MLPTanhHead(hidden_size, out_features)
            self.gripper = MLPSigmoidHead(hidden_size, 1)
        self.hidden_state = None
        self.hidden_size = hidden_size * history_len
        self.rnn_out = None
        self.last_action = last_action
        if self.use_diff:
            self.last_action = True
        # self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        self.global_1d_pool = nn.AdaptiveMaxPool1d(1)

    def forward(  # type: ignore
        self,
        input_feature: torch.Tensor,
        h_0: Optional[torch.Tensor] = None,
        state_tensor=None,
    ):
        if self.return_feature:
            # clone rather than deepcopy: deepcopy fails on non-leaf tensors
            org_feat = input_feature.clone()
            org_feat = org_feat.view(self.window_size, *org_feat.shape[1:])
        # project token features, then pool over the token dimension
        input_feature = self.mlp(input_feature)
        input_feature = self.global_1d_pool(input_feature.permute(0, 2, 1)).squeeze(-1)
        if self.use_diff:
            input_feature = input_feature.reshape(-1, self.window_size * input_feature.shape[1])
            return input_feature
        input_feature = input_feature.reshape(-1, self.window_size, input_feature.shape[1])
        if state_tensor is not None:
            state_tensor = self.fc_state(state_tensor)
            state_tensor = state_tensor.reshape(-1, self.window_size, state_tensor.shape[-1])
            input_feature = torch.cat([input_feature, state_tensor], dim=-1)
        actions = self.actions(input_feature)
        gripper = self.gripper(input_feature)
        if self.return_feature:
            return actions, gripper, org_feat
        else:
            return actions, gripper
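
# Illustrative usage of FCDecoder (shapes are assumptions): 2 windows of 12
# frames, each frame carrying 64 visual tokens of dim 1024, flattened batch-first.
#   decoder = FCDecoder(in_features=1024, window_size=12)
#   actions, gripper = decoder(torch.randn(2 * 12, 64, 1024))
#   # actions: (2, 12, 6), gripper: (2, 12, 1)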


class DeterministicDecoder(ActionDecoder):
    def __init__(
        self,
        in_features: int,
        window_size: int,
        history_len=None,
        out_features: int = 6,
        hidden_size: int = 1024,
        num_layers: int = 4,
        policy_rnn_dropout_p: float = 0.1,
        use_diff=False,
        last_action=False,
        fusion_mode='',
        use_state=False,
        multi_step_action=1,
        return_feature=False,
        pooling='max',
    ):
        super().__init__()
        self.fc_state = None
        self.use_state = use_state
        if use_state:
            print('Using state in decoder')
            state_in_dim = 7
            self.embed_arm_state = nn.Sequential(torch.nn.Linear(state_in_dim - 1, in_features), nn.ReLU())
            self.embed_gripper_state = nn.Sequential(torch.nn.Embedding(2, in_features), nn.ReLU())  # one-hot gripper state
            self.embed_state = torch.nn.Linear(2 * in_features, in_features)
        if fusion_mode == 'two_way':
            in_features *= 2
        self.return_feature = return_feature
        self.in_features = in_features
        self.out_features = out_features
        self.window_size = window_size
        self.multi_step_action = multi_step_action
        if history_len is None:
            history_len = window_size
        self.history_len = history_len
        self.history_memory = []
        self.rnn = lstm_decoder(in_features, hidden_size, num_layers, policy_rnn_dropout_p)
        self.use_diff = use_diff
        self.fusion_mode = fusion_mode
        if not use_diff:
            self.actions = MLPTanhHead(hidden_size, out_features * multi_step_action)
            self.gripper = MLPSigmoidHead(hidden_size, 1 * multi_step_action)
        self.hidden_state = None
        self.hidden_size = hidden_size
        self.rnn_out = None
        self.last_action = last_action
        if self.use_diff:
            self.last_action = True
        if pooling == 'max':
            self.global_1d_pool = nn.AdaptiveMaxPool1d(1)
        else:
            self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        if self.fusion_mode == 'two_way':
            if pooling == 'max':
                self.gripper_1d_max_pool = nn.AdaptiveMaxPool1d(1)
            else:
                self.gripper_1d_max_pool = nn.AdaptiveAvgPool1d(1)

    def clear_hidden_state(self) -> None:
        self.hidden_state = None

    def forward(  # type: ignore
        self,
        input_feature: torch.Tensor,
        h_0: Optional[torch.Tensor] = None,
        state_tensor=None,
        return_feature=False,
    ):
        # pool token features into one vector per frame
        if input_feature.dim() == 3:
            if self.fusion_mode == 'two_way':
                input_feature = input_feature.reshape(-1, self.window_size, *input_feature.shape[1:])
                bs = int(input_feature.shape[0] // 2)
                rgb_feat = input_feature[:bs].view(bs * self.window_size, *input_feature.shape[2:])
                rgb_feat = self.global_1d_pool(rgb_feat.permute(0, 2, 1)).squeeze(-1)
                gripper_feat = input_feature[bs:].view(bs * self.window_size, *input_feature.shape[2:])
                gripper_feat = self.global_1d_pool(gripper_feat.permute(0, 2, 1)).squeeze(-1)
                input_feature = torch.cat([rgb_feat, gripper_feat], dim=-1)
            else:
                input_feature = self.global_1d_pool(input_feature.permute(0, 2, 1)).squeeze(-1)
        input_feature = input_feature.reshape(-1, self.window_size, input_feature.shape[1])
        if self.return_feature:
            # clone rather than deepcopy: deepcopy fails on non-leaf tensors
            org_feat = input_feature.clone()
            org_feat = org_feat.view(self.window_size, org_feat.shape[-1])  # assumes batch size 1
        if state_tensor is not None and self.use_state:
            arm_state = state_tensor[..., :6]  # b,len,state_dim-1
            arm_state_embeddings = self.embed_arm_state(arm_state)
            arm_state_embeddings = arm_state_embeddings.view(-1, self.window_size, arm_state_embeddings.shape[-1])  # b,len,h
            gripper_state = ((state_tensor[..., -1] + 1.0) / 2).long()  # b,len
            gripper_state_embeddings = self.embed_gripper_state(gripper_state)
            gripper_state_embeddings = gripper_state_embeddings.view(-1, self.window_size, gripper_state_embeddings.shape[-1])  # b,len,h
            state_embeddings = torch.cat((arm_state_embeddings, gripper_state_embeddings), dim=2)  # b,len,2h
            state_embeddings = self.embed_state(state_embeddings)  # b,len,h
            input_feature = input_feature + state_embeddings
        if not isinstance(self.rnn, nn.Sequential) and isinstance(self.rnn, nn.RNNBase):
            if input_feature.shape[1] == 1:
                # single-frame streaming: keep a rolling history of features
                self.history_memory.append(input_feature)
                if len(self.history_memory) <= self.history_len:
                    x, h_n = self.rnn(input_feature, self.hidden_state)
                    self.hidden_state = h_n
                    x = x[:, -1].unsqueeze(1)
                    self.rnn_out = x.squeeze(1)
                else:
                    # the hidden state needs to be refreshed based on the history window
                    cur_len = len(self.history_memory)
                    for _ in range(cur_len - self.history_len):
                        self.history_memory.pop(0)
                    assert len(self.history_memory) == self.history_len
                    hist_feature = torch.cat(self.history_memory, dim=1)
                    self.hidden_state = None
                    x, h_n = self.rnn(hist_feature, self.hidden_state)
                    x = x[:, -1].unsqueeze(1)
                    self.rnn_out = x.squeeze(1)
            else:
                # input feature length > 1: run the whole window through the LSTM
                self.hidden_state = h_0
                x, h_n = self.rnn(input_feature, self.hidden_state)
                self.hidden_state = h_n
                if self.last_action:
                    x = x[:, -1].unsqueeze(1)
                self.rnn_out = x.squeeze(1)
        else:
            raise NotImplementedError
        if self.use_diff:
            return self.rnn_out
        actions = self.actions(x)
        gripper = self.gripper(x)
        if self.return_feature:
            return actions, gripper, org_feat
        else:
            return actions, gripper

    def act(
        self,
        input_feature: torch.Tensor,
    ) -> torch.Tensor:
        # forward() maintains self.hidden_state internally and returns
        # (actions, gripper); only the predicted actions are exposed here
        pred_actions, _ = self(input_feature, self.hidden_state)
        return pred_actions
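
# Illustrative usage of DeterministicDecoder (shapes are assumptions): the LSTM
# consumes one pooled feature per frame and predicts per-frame actions.
#   decoder = DeterministicDecoder(in_features=1024, window_size=12)
#   actions, gripper = decoder(torch.randn(2 * 12, 64, 1024))
#   # actions: (2, 12, 6), gripper: (2, 12, 1)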


class GPTDecoder(ActionDecoder):
    def __init__(
        self,
        in_features: int,
        window_size: int,
        history_len=None,
        out_features: int = 6,
        hidden_size=None,
        num_layers: int = 4,
        policy_rnn_dropout_p: float = 0.1,
        last_action=False,
        use_diff=False,
        fusion_mode='',
        use_state=False,
        multi_step_action=1,
        return_feature=False,
        pooling='max',
        **kwargs
    ):
        super().__init__()
        if use_state:
            state_in_dim = 7
            state_out_dim = 128
            self.fc_state = MLPNohHead(state_in_dim, state_out_dim)
            in_features += state_out_dim
        if fusion_mode == 'two_way':
            in_features *= 2
        self.return_feature = return_feature
        self.in_features = in_features
        self.out_features = out_features
        self.window_size = window_size
        self.multi_step_action = multi_step_action
        if history_len is None:
            history_len = window_size
        self.history_len = history_len
        self.history_memory = []
        if hidden_size is None:
            hidden_size = in_features
        self.gpt = get_gpt_model(hidden_size, history_len)
        self.use_diff = use_diff
        self.fusion_mode = fusion_mode
        self.hidden_size = hidden_size
        if hidden_size != in_features:
            self.fc = nn.Linear(in_features, hidden_size)
        else:
            self.fc = nn.Identity()
        if not use_diff:
            self.actions = MLPTanhHead(hidden_size, out_features * multi_step_action)
            self.gripper = MLPSigmoidHead(hidden_size, 1 * multi_step_action)
        self.hidden_state = None
        self.rnn_out = None
        self.last_action = last_action
        if self.use_diff:
            self.last_action = True
        if pooling == 'max':
            self.global_1d_pool = nn.AdaptiveMaxPool1d(1)
        else:
            self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        if self.fusion_mode == 'two_way':
            if pooling == 'max':
                self.gripper_1d_max_pool = nn.AdaptiveMaxPool1d(1)
            else:
                self.gripper_1d_max_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, input_feature: torch.Tensor):
        time_step = None
        attention_mask = None
        if input_feature.dim() == 3:
            input_feature = self.global_1d_pool(input_feature.permute(0, 2, 1)).squeeze(-1)
        input_feature = input_feature.reshape(-1, self.window_size, input_feature.shape[1])  # bs, seq_len, feat_dim
        input_feature = self.fc(input_feature)
        if input_feature.shape[1] == 1:
            self.history_memory.append(input_feature)
            if len(self.history_memory) <= self.history_len:
                hist_feature = torch.cat(self.history_memory, dim=1)
                x = self.gpt(hist_feature, time_step, attention_mask)
                x = x[:, -1].unsqueeze(1)
            else:
                # drop the oldest frames so the history matches the history window
                cur_len = len(self.history_memory)
                for _ in range(cur_len - self.history_len):
                    self.history_memory.pop(0)
                assert len(self.history_memory) == self.history_len
                hist_feature = torch.cat(self.history_memory, dim=1)
                x = self.gpt(hist_feature, time_step, attention_mask)
                x = x[:, -1].unsqueeze(1)
        else:
            x = self.gpt(input_feature, time_step, attention_mask)
            if self.last_action:
                x = x[:, -1].unsqueeze(1)
        actions = self.actions(x)
        gripper = self.gripper(x)
        return actions, gripper

    def get_pattern_name(self):
        return 'gpt_{}_'.format(self.hidden_size)
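
# Note: with window_size == 1 (step-by-step inference), GPTDecoder appends each
# frame to history_memory and re-runs the transformer over the most recent
# history_len frames, rather than carrying a recurrent hidden state.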


class GPTDecoderActPad(ActionDecoder):
    def __init__(
        self,
        in_features: int,
        window_size: int,
        use_vision=False,
        history_len=None,
        out_features: int = 6,
        hidden_size=None,
        last_action=False,
        use_diff=False,
        fusion_mode='',
        use_state=False,
        multi_step_action=1,
        return_feature=False,
        pooling='sampler',
        global_latent=10,
        **kwargs
    ):
        super().__init__()
        if use_state:
            state_in_dim = 7
            state_out_dim = 128
            self.fc_state = MLPNohHead(state_in_dim, state_out_dim)
            in_features += state_out_dim
        if fusion_mode == 'two_way':
            in_features *= 2
        self.return_feature = return_feature
        self.in_features = in_features
        self.out_features = out_features
        self.window_size = window_size
        self.multi_step_action = multi_step_action
        if history_len is None:
            history_len = window_size
        self.history_len = history_len
        self.history_memory = []
        if hidden_size is None:
            hidden_size = in_features
        self.gpt = get_gpt_model(hidden_size, history_len, use_pe=False)
        self.use_diff = use_diff
        self.fusion_mode = fusion_mode
        self.hidden_size = hidden_size
        if hidden_size != in_features:
            self.fc = nn.Linear(in_features, hidden_size)
        else:
            self.fc = nn.Identity()
        if not use_diff:
            self.actions = MLPTanhHead(hidden_size, out_features * multi_step_action)
            self.gripper = MLPSigmoidHead(hidden_size, 1 * multi_step_action)
        self.hidden_state = None
        self.rnn_out = None
        self.last_action = last_action
        if self.use_diff:
            self.last_action = True
        self.global_latent = global_latent
        self.use_vision = use_vision
        if self.use_vision:
            self.vision_resampler = PerceiverResampler(dim=hidden_size)
        # elif chain so that pooling='sampler' is not overwritten by the avg-pool default
        if pooling == 'sampler':
            self.global_1d_pool = PerceiverResampler(dim=hidden_size, depth=2, num_latents=global_latent)
        elif pooling == 'max':
            self.global_1d_pool = nn.AdaptiveMaxPool1d(1)
        else:
            self.global_1d_pool = nn.AdaptiveAvgPool1d(1)
        if self.fusion_mode == 'two_way':
            if pooling == 'max':
                self.gripper_1d_max_pool = nn.AdaptiveMaxPool1d(1)
            else:
                self.gripper_1d_max_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, input_feature: torch.Tensor, rgb=None):
        time_step = None
        attention_mask = None
        input_feature = self.global_1d_pool(input_feature.unsqueeze(1)).squeeze(1)
        input_feature = input_feature.view(-1, self.window_size, self.global_latent, input_feature.shape[-1])  # bs, seq_len, n_tok, feat_dim
        bs, seq_len, n_tok = input_feature.shape[:3]
        input_feature = self.fc(input_feature)  # bs, seq_len, n_tok, feat_dim
        attention_mask = torch.ones((bs, n_tok, seq_len), dtype=torch.long).to(input_feature.device)
        if input_feature.shape[1] == 1:
            self.history_memory.append(input_feature)
            if len(self.history_memory) <= self.history_len:
                hist_feature = torch.cat(self.history_memory, dim=1)
                x = self.gpt(hist_feature, time_step, attention_mask)
                x = x[:, -1].unsqueeze(1)
            else:
                # drop the oldest frames so the history matches the history window
                cur_len = len(self.history_memory)
                for _ in range(cur_len - self.history_len):
                    self.history_memory.pop(0)
                assert len(self.history_memory) == self.history_len
                hist_feature = torch.cat(self.history_memory, dim=1)
                x = self.gpt(hist_feature, time_step, attention_mask)
                x = x[:, -1].unsqueeze(1)
        else:
            x = self.gpt(input_feature, time_step, attention_mask)
            if self.last_action:
                x = x[:, -1].unsqueeze(1)
        actions = self.actions(x)
        gripper = self.gripper(x)  # MLPSigmoidHead already ends in Sigmoid; a second sigmoid would double-squash
        return actions, gripper

    def get_pattern_name(self):
        return 'gpt_{}_'.format(self.hidden_size)


class DiffusionDecoder(ActionDecoder):
    def __init__(
        self,
        feature_dim: int,
        window_size: int,
        history_len=None,
        horizon=32,
        input_dim: int = 7,  # dim of the vectors to be diffused
        diffusion_step_embed_dim=256,
        down_dims=[256, 512, 1024],
        kernel_size=3,
        n_groups=8,
        cond_predict_scale=False,
        n_timesteps=150,
        clip_denoised=False,
        predict_epsilon=True,
        normalizer=LinearNormalizer(),
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.horizon = horizon
        self.window_size = window_size
        if history_len is None:
            history_len = window_size
        self.history_len = history_len
        self.history_memory = []
        self.normalizer = normalizer
        self.data_dim = input_dim
        self.model = ConditionalUnet1D(
            input_dim,
            global_cond_dim=feature_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale,
        )
        betas = cosine_beta_schedule(n_timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
        self.n_timesteps = int(n_timesteps)
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon
        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.register_buffer("posterior_variance", posterior_variance)
        # log calculation clipped because the posterior variance
        # is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            torch.log(torch.clamp(posterior_variance, min=1e-20)),
        )
        # use torch.sqrt (not np.sqrt) so the registered buffers stay torch tensors
        self.register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )
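
    # The buffers above are the standard DDPM quantities: the forward marginal is
    #   q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)
    # and the posterior used for sampling is
    #   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance * I)
    # with coef1 = beta_t * sqrt(abar_{t-1}) / (1 - abar_t) and
    #      coef2 = (1 - abar_{t-1}) * sqrt(alpha_t) / (1 - abar_t),
    # where abar_t = prod_{s<=t} alpha_s (alphas_cumprod above).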

    def get_loss_weights(self, action_weight, discount, weights_dict):
        """
        Sets loss coefficients for the trajectory.

        action_weight : float
            coefficient on the first action loss
        discount : float
            multiplies the t-th timestep of the trajectory loss by discount**t
        weights_dict : dict
            { i: c } multiplies dimension i of the observation loss by c
        """
        self.action_weight = action_weight
        # note: assumes self.action_dim has been set by the caller
        dim_weights = torch.ones(self.action_dim, dtype=torch.float32)
        # set loss coefficients for dimensions of the observation
        if weights_dict is None:
            weights_dict = {}
        for ind, w in weights_dict.items():
            dim_weights[self.action_dim + ind] *= w
        # decay loss with trajectory timestep: discount**t
        discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
        discounts = discounts / discounts.mean()
        loss_weights = torch.einsum("h,t->ht", discounts, dim_weights)
        loss_weights = loss_weights.unsqueeze(1).clone()
        return loss_weights

    # ------------------------------------------ sampling ------------------------------------------#

    def predict_start_from_noise(self, x_t, t, noise):
        """
        If self.predict_epsilon, the model output is (scaled) noise;
        otherwise, the model predicts x0 directly.
        """
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, local_cond=None, global_cond=None, returns=None):
        if returns is not None:
            # classifier-free guidance; "epsilon" could be noise or x0 itself,
            # depending on self.predict_epsilon
            epsilon_cond = self.model(x, t, local_cond, global_cond, returns, use_dropout=False)
            epsilon_uncond = self.model(x, t, local_cond, global_cond, returns, force_dropout=True)
            epsilon = epsilon_uncond + self.condition_guidance_w * (
                epsilon_cond - epsilon_uncond
            )
        else:
            epsilon = self.model(x, t, local_cond, global_cond)
        t = t.detach().to(torch.int64)
        x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)
        if self.clip_denoised:
            x_recon.clamp_(-1.0, 1.0)
        # (the original `assert RuntimeError()` in an else branch was a no-op)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, x, t, local_cond=None, global_cond=None, returns=None):
        b = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, local_cond=local_cond, global_cond=global_cond, returns=returns
        )
        noise = 0.5 * torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
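
    # Each reverse step draws x_{t-1} = mu_theta(x_t, t) + sigma_t * z with
    # z ~ N(0, I) and sigma_t = exp(0.5 * posterior_log_variance); note that the
    # noise above is additionally scaled by 0.5, following the source.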

    def p_sample_loop(
        self, cond_data, cond_mask, local_cond=None, global_cond=None, returns=None, verbose=False, return_diffusion=False, **kwargs
    ):
        device = self.betas.device
        batch_size = cond_data.shape[0]
        x = torch.randn(
            size=cond_data.shape,
            dtype=cond_data.dtype,
            device=cond_data.device,
        )
        if return_diffusion:
            diffusion = [x]
        x[cond_mask] = cond_data[cond_mask]
        progress = Progress(self.n_timesteps) if verbose else Silent()
        for i in reversed(range(0, self.n_timesteps)):
            timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
            # 1. predict model output and replace sample
            x = self.p_sample(x, timesteps, local_cond, global_cond, returns)
            # 2. apply conditioning
            x[cond_mask] = cond_data[cond_mask]
            progress.update({"t": i})
            if return_diffusion:
                diffusion.append(x)
        progress.close()
        if return_diffusion:
            return x, torch.stack(diffusion, dim=1)
        else:
            return x

    def conditional_sample(self, cond_data, cond_mask, local_cond=None, global_cond=None, returns=None, action_seq_len=None, *args, **kwargs):
        """
        conditions : [ (time, state), ... ]
        """
        return self.p_sample_loop(cond_data, cond_mask, local_cond, global_cond, returns, *args, **kwargs)

    # ------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)
        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return sample
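
    # q_sample draws from the closed-form forward process:
    #   x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,  eps ~ N(0, I)
    # which is what the denoising objective inverts when predict_epsilon is True.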

    def forward(
        self,
        x,
        t,
        local_cond=None,
        global_cond=None,
        **kwargs
    ):
        return self.model(x, t, local_cond, global_cond)

    def act(
        self,
        input_feature: torch.Tensor,
    ) -> torch.Tensor:
        # sampling-based action selection is not implemented for this decoder;
        # the original body called self() with a nonexistent hidden state and
        # then raised anyway
        raise NotImplementedError
| if __name__ == "__main__": | |
| model = GPTDecoder(128, 24) | |
| in_feat = torch.randn((4*24, 12, 128)) | |
| out = model(in_feat) | |
| print(out[0].shape, out[1].shape) | |
| pass |