import copy
from collections import namedtuple

import torch
from einops import rearrange, repeat
from torch import nn

from open_flamingo.src.helpers import PerceiverResampler
from robot_flamingo.models.action_head import (
    DeterministicDecoder,
    DiffusionDecoder,
    FCDecoder,
    GPTDecoder,
)


class BCFlamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        use_media_placement_augmentation: bool = False,
        window_size: int = 8,  # window size sampled from the episode
        use_gripper=False,
        fusion_mode='',
        sep_resampler=False,
        use_state=False,
        use_diff=False,
        diff_horizon=32,
        last_action=False,
        n_timesteps=150,
        state_dim=15,
        use_hist=False,
        debug=False,
        predict_epsilon=True,
        pad_length=-1,
        multi_step_action=1,
        sep_lm_head=False,
        return_feature=False,
        llm='llama_9b',
        pooling='max',
        residual=False,
        tcp_rel=False,
        replan=-1,
        decoder_type='lstm',
        hidden_size=None,
        fwd_pred=False,
        fwd_pred_hand=False,
        refresh=-1,
    ):
| """ | |
| Args: | |
| vision_encoder (nn.Module): HF CLIPModel | |
| lang_encoder (nn.Module): HF causal language model | |
| eoc_token_id (int): Token id for <|endofchunk|> | |
| media_token_id (int): Token id for <image> | |
| vis_dim (int): Dimension of the visual features. | |
| Visual features are projected to match this shape along the last dimension. | |
| cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. | |
| use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False. | |
| """ | |
        super().__init__()
        self.use_gripper = use_gripper
        self.use_state = use_state
        self.fusion_mode = fusion_mode
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.vis_dim = vis_dim
        self.window_size = window_size
        self.tcp_rel = tcp_rel
        self.act_step = multi_step_action
        print('window size: {}'.format(window_size))
        self.vision_encoder = vision_encoder
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.sep_resampler = sep_resampler
        self.use_hist = use_hist
        self.lang_encoder = lang_encoder
        self.pad_length = pad_length
        self.replan = replan
        if self.replan != -1:
            self.replan = min(int(replan * self.window_size), 180)
        self.refresh = refresh
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size
        # print(self.vis_dim, self.lang_dim)
        self.residual = residual

        if not debug:
            if 'llama' in llm:
                self.lang_encoder.init_flamingo(
                    media_token_id=media_token_id,
                    vis_hidden_size=self.vis_dim,
                    cross_attn_every_n_layers=cross_attn_every_n_layers,
                    use_media_placement_augmentation=self.use_media_placement_augmentation,
                    residual=residual,
                )
            else:
                self.lang_encoder.init_flamingo(
                    media_token_id=media_token_id,
                    lang_hidden_size=self.lang_dim,
                    vis_hidden_size=self.vis_dim,
                    cross_attn_every_n_layers=cross_attn_every_n_layers,
                    gradient_checkpointing=False,
                )

        if sep_resampler:
            self.perceiver_gripper = PerceiverResampler(dim=self.vis_dim)
            self.perceiver_gripper.load_state_dict(copy.deepcopy(self.perceiver.state_dict()))
        if use_state:
            self.state_fc = nn.Linear(state_dim, self.vis_dim)
        if use_hist:
            self.frame_embs = nn.Parameter(torch.randn(self.window_size, self.vis_dim))
        # TODO: nn architecture for the actor
        self.llm = llm
        if llm == 'llama':
            in_features = lang_encoder.lm_head.in_features
        else:
            in_features = self.lang_dim
        self.use_diff = use_diff
        self.decoder_type = decoder_type
        if decoder_type == 'lstm':
            lm_head = DeterministicDecoder(
                in_features, self.window_size,
                use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode,
                use_state=use_state, return_feature=return_feature,
                multi_step_action=multi_step_action, pooling=pooling)
            self.lang_encoder.lm_head = lm_head
        elif decoder_type == 'fc':
            if use_hist:
                self.lang_encoder.lm_head = self.action_head = FCDecoder(
                    in_features, self.window_size,
                    use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode,
                    use_state=use_state, return_feature=return_feature,
                    multi_step_action=multi_step_action)
            elif 'vit_concat' in fusion_mode:
                self.lang_encoder.lm_head = self.action_head = FCDecoder(
                    in_features, self.window_size,
                    use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode,
                    use_state=use_state, return_feature=return_feature,
                    multi_step_action=multi_step_action)
            else:
                raise NotImplementedError
        elif decoder_type == 'diffusion':
            if use_diff:
                # NOTE: this branch assumes self.action_head has already been set;
                # as written it is only reachable with a compatible decoder config.
                self.diffusion_model = DiffusionDecoder(
                    self.action_head.hidden_size,
                    self.window_size,
                    input_dim=self.action_head.out_features + 1,
                    n_timesteps=n_timesteps,
                    horizon=diff_horizon,
                    predict_epsilon=predict_epsilon,
                )
            else:
                raise NotImplementedError
        elif decoder_type == 'gpt':
            lm_head = GPTDecoder(
                in_features, self.window_size, use_diff=use_diff, last_action=last_action,
                fusion_mode=fusion_mode, multi_step_action=multi_step_action,
                pooling=pooling, hidden_size=hidden_size)
            self.lang_encoder.lm_head = self.action_head = lm_head
        else:
            raise NotImplementedError

        self.sep_lm_head = sep_lm_head
        if sep_lm_head:
            self.lm_head = self.lang_encoder.lm_head
            self.lang_encoder.lm_head = nn.Identity()
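
    # A hedged construction sketch (the encoder arguments are placeholders;
    # in practice they come from open_flamingo's model factory):
    #
    #   model = BCFlamingo(vision_encoder, lang_encoder,
    #                      eoc_token_id=eoc_id, media_token_id=media_id,
    #                      vis_dim=1024, use_gripper=True, fusion_mode='post',
    #                      sep_lm_head=True, decoder_type='lstm')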

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
        vision_gripper=None,
        state_tensor=None,
        return_feature=False,
        policy_mask=None,
    ):
| """ | |
| Forward pass of Flamingo. | |
| Args: | |
| vision_x (torch.Tensor): Vision input | |
| shape (B, T_img, F, C, H, W) with F=1 | |
| lang_x (torch.Tensor): Language input ids | |
| shape (B, T_txt) | |
| attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
| labels (torch.Tensor, optional): Labels. Defaults to None. | |
| clear_conditioned_layers: if True, clear the conditioned layers | |
| once the foward pass is completed. Set this to false if the | |
| same set of images will be reused in another subsequent | |
| forward pass. | |
| past_key_values: pre-computed values to pass to language model. | |
| See past_key_values documentation in Hugging Face | |
| CausalLM models. | |
| use_cache: whether to use cached key values. See use_cache | |
| documentation in Hugging Face CausalLM models. | |
| """ | |
        assert (
            vision_x is not None
        ) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
        # Clone defensively; guard against None so cached-vision calls don't crash.
        raw_rgb = vision_x.clone() if vision_x is not None else None
        raw_gripper = vision_gripper.clone() if vision_gripper is not None else None

        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass)
            if self.use_hist:
                self._encode_history_vision_post_fusion(vision_x, vision_gripper, state_tensor)
            else:
                if not self.use_gripper or self.fusion_mode == 'two_way':
                    vision_x = self._encode_vision_x(vision_x=vision_x)
                else:
                    if self.fusion_mode == 'pre':
                        self._encode_multi_vision_pre_fusion(vision_x, vision_gripper, state_tensor)
                    elif self.fusion_mode == 'post':
                        self._encode_multi_vision_post_fusion(vision_x, vision_gripper, state_tensor)
                    elif self.fusion_mode == 'vit_concat':
                        self._encode_history_vision_fc_post(vision_x, vision_gripper, state_tensor)

        if 'llama' in self.llm:
            output = self.lang_encoder(
                input_ids=lang_x,
                attention_mask=attention_mask,
                # labels=labels,  # omit labels so the model does not compute a loss
                past_key_values=past_key_values,
                use_cache=use_cache,
            )
        else:
            output = self.lang_encoder(
                input_ids=lang_x,
                attention_mask=attention_mask,
            )

        if self.sep_lm_head:
            output_llm = output.logits
            output_lm_head = self.lm_head(output_llm, state_tensor=state_tensor, return_feature=return_feature)
            output.logits = output_lm_head

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()
        # action_seq = self.action_head(vision_x)
        return output
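
    # A hedged usage sketch for forward() (all names below are placeholders,
    # not part of this module):
    #
    #   vision_x:       (B, T_img, 1, C, H, W) static-camera frames
    #   vision_gripper: (B, T_img, 1, C, H, W) gripper-camera frames
    #   lang_x:         (B, T_txt) token ids that include <image> tokens
    #
    #   out = model(vision_x, lang_x, attention_mask=mask,
    #               vision_gripper=vision_gripper, state_tensor=state)
    #   actions = out.logits  # action-head output when sep_lm_head is set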

    # Generate function with actor for test-time adaptation
    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        num_beams=1,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
    ):
| """ | |
| Generate text conditioned on vision and language inputs. | |
| Args: | |
| vision_x (torch.Tensor): Vision input | |
| shape (B, T_img, F, C, H, W) | |
| images in the same chunk are collated along T_img, and frames are collated along F | |
| currently only F=1 is supported (single-frame videos) | |
| lang_x (torch.Tensor): Language input | |
| shape (B, T_txt) | |
| max_length (int, optional): Maximum length of the output. Defaults to None. | |
| attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
| num_beams (int, optional): Number of beams. Defaults to 1. | |
| max_new_tokens (int, optional): Maximum new tokens. Defaults to None. | |
| temperature (float, optional): Temperature. Defaults to 1.0. | |
| top_k (int, optional): Top k. Defaults to 0. | |
| top_p (float, optional): Top p. Defaults to 1.0. | |
| no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. | |
| length_penalty (float, optional): Length penalty. Defaults to 1.0. | |
| num_return_sequences (int, optional): Number of return sequences. Defaults to 1. | |
| do_sample (bool, optional): Do sample. Defaults to False. | |
| early_stopping (bool, optional): Early stopping. Defaults to False. | |
| Returns: | |
| torch.Tensor: lang_x with generated tokens appended to it | |
| """ | |
        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        self._encode_vision_x(vision_x=vision_x)
        output = self.lang_encoder.generate(
            lang_x,
            attention_mask=attention_mask,
            eos_token_id=self.eoc_token_id,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
            early_stopping=early_stopping,
        )

        self.lang_encoder.clear_conditioned_layers()
        return output
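
    # A hedged sketch of a generate() call (tokenization of the prompt is an
    # assumption, not shown in this module):
    #
    #   tokens = model.generate(vision_x, lang_x, attention_mask=mask,
    #                           num_beams=3, max_new_tokens=20)
    #
    # Generation stops at self.eoc_token_id (<|endofchunk|>).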

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            vision_x = self.vision_encoder.visual(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

        vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x
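
    # Shape flow through _encode_vision_x (v = #visual tokens, d = vis_dim,
    # n = #latents produced by the PerceiverResampler):
    #   (b, T, 1, c, h, w) -> vision encoder -> (b, T, 1, v, d)
    #                      -> perceiver      -> (b, T, n, d)
    # The result conditions every cross-attention layer via condition_vis_x.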

    def _encode_vision(self, vision_x: torch.Tensor, state_tensor=None):
        """
        Pass vision input through the vision encoder without conditioning the language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            vision_x = self.vision_encoder.visual(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
        return vision_x

    def _encode_multi_vision_pre_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_rgb (torch.Tensor): Vision RGB input
                shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
                shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        vision_rgb = self._encode_vision(vision_rgb)
        vision_gripper = self._encode_vision(vision_gripper)
        vision_x = torch.cat([vision_rgb, vision_gripper], dim=3)  # concat along the visual-token axis
        vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x

    def _encode_multi_vision_post_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_rgb (torch.Tensor): Vision RGB input
                shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
                shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        vision_rgb = self._encode_vision(vision_rgb)
        vision_gripper = self._encode_vision(vision_gripper)

        vision_rgb = self.perceiver(vision_rgb)
        if self.sep_resampler:
            vision_gripper = self.perceiver_gripper(vision_gripper)
        else:
            vision_gripper = self.perceiver(vision_gripper)

        vision_x = torch.cat([vision_rgb, vision_gripper], dim=2)  # reshapes to (b, T, 2*n, d)
        if self.use_state and state_tensor is not None:
            # state_tensor = state_tensor.double()
            state_tensor = self.state_fc(state_tensor)
            vision_x = torch.cat([vision_x, state_tensor], dim=2)  # reshapes to (b, T, 2*n+1, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x
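
    # In post fusion each camera view is resampled separately and the latents
    # are concatenated along the token axis:
    #   rgb     -> perceiver           -> (b, T, n, d)
    #   gripper -> perceiver(_gripper) -> (b, T, n, d)
    #   concat along dim=2             -> (b, T, 2n, d),
    # plus one extra state token when use_state is set.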

    def _encode_multi_vision_two_way(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_rgb (torch.Tensor): Vision RGB input
                shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
                shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        vision_rgb = self._encode_vision(vision_rgb)
        vision_gripper = self._encode_vision(vision_gripper)

        vision_rgb = self.perceiver(vision_rgb)
        if self.sep_resampler:
            vision_gripper = self.perceiver_gripper(vision_gripper)
        else:
            vision_gripper = self.perceiver(vision_gripper)

        vision_x = torch.cat([vision_rgb, vision_gripper], dim=0)  # stacks along batch to (2*b, T, n, d)
        if self.use_state and state_tensor is not None:
            state_tensor = self.state_fc(state_tensor)
            vision_x = torch.cat([vision_x, state_tensor], dim=0)  # appends the state tokens along batch

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x

    def _encode_history_vision_post_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_rgb (torch.Tensor): Vision RGB input
                shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
                shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        vision_rgb = self._encode_vision(vision_rgb)
        vision_gripper = self._encode_vision(vision_gripper)

        bs = int(vision_rgb.shape[0] // self.window_size)
        vision_rgb = vision_rgb.view(bs, self.window_size, *vision_rgb.shape[1:])
        _, _, T, p, v_tok, dim = vision_rgb.shape[:6]
        frame_embs = repeat(self.frame_embs, 'F d -> b F T p v d', b=bs, T=T, p=p, v=v_tok)
        vision_rgb = vision_rgb + frame_embs
        vision_rgb = rearrange(vision_rgb, 'b F T p v d -> (b F) T p v d')
        vision_rgb = self.perceiver(vision_rgb)

        vision_gripper = vision_gripper.view(
            vision_gripper.shape[0] // self.window_size, self.window_size, *vision_gripper.shape[1:])
        frame_embs = repeat(self.frame_embs, 'F d -> b F T p v d', b=bs, T=T, p=p, v=v_tok)
        vision_gripper = vision_gripper + frame_embs
        vision_gripper = rearrange(vision_gripper, 'b F T p v d -> (b F) T p v d')
        if self.sep_resampler:
            vision_gripper = self.perceiver_gripper(vision_gripper)
        else:
            vision_gripper = self.perceiver(vision_gripper)

        vision_x = torch.cat([vision_rgb, vision_gripper], dim=2)  # reshapes to (b, T, 2*n, d)
        if self.use_state and state_tensor is not None:
            state_tensor = self.state_fc(state_tensor)
            vision_x = torch.cat([vision_x, state_tensor], dim=2)  # reshapes to (b, T, 2*n+1, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x
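
    # With use_hist, a learned per-frame embedding (self.frame_embs) is added
    # to every visual token of the corresponding history frame before
    # resampling, so the perceiver can distinguish positions in the window.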

    def _encode_history_vision_fc_post(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_rgb (torch.Tensor): Vision RGB input
                shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
                shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        bs = int(vision_rgb.shape[0] // self.window_size)
        vision_rgb = self._encode_vision(vision_rgb)
        vision_rgb = self.perceiver(vision_rgb)  # (B*L, T, n, d)
        vision_rgb = vision_rgb.view(-1, self.window_size, *vision_rgb.shape[1:])  # (B, L, T, n, d)
        vision_rgb = rearrange(vision_rgb, 'b L T n d -> b T (n L) d')

        vision_gripper = self._encode_vision(vision_gripper)
        if self.sep_resampler:
            vision_gripper = self.perceiver_gripper(vision_gripper)
        else:
            vision_gripper = self.perceiver(vision_gripper)
        vision_gripper = vision_gripper.view(-1, self.window_size, *vision_gripper.shape[1:])  # (B, L, T, n, d)
        vision_gripper = rearrange(vision_gripper, 'b L T n d -> b T (n L) d')

        vision_x = torch.cat([vision_rgb, vision_gripper], dim=2)
        if self.use_state and state_tensor is not None:
            state_tensor = self.state_fc(state_tensor)
            vision_x = torch.cat([vision_x, state_tensor], dim=2)  # reshapes to (b, T, 2*n+1, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

        return vision_x
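

# ---------------------------------------------------------------------------
# Minimal shape sketch (a hedged illustration, not part of the model): it
# demonstrates the (B, T_img, F, C, H, W) rearrange pattern used by
# _encode_vision_x, with a dummy stand-in for the CLIP vision encoder.
# All names below (dummy_encode, v_tok, d, ...) are illustrative assumptions.
if __name__ == "__main__":
    B, T, F, C, H, W = 2, 4, 1, 3, 224, 224
    v_tok, d = 16, 1024  # hypothetical patch-token count and feature dim

    vision_x = torch.randn(B, T, F, C, H, W)

    def dummy_encode(x):
        # stands in for self.vision_encoder.visual(x)[1]: (N, C, H, W) -> (N, v, d)
        return torch.randn(x.shape[0], v_tok, d)

    flat = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
    feats = dummy_encode(flat)
    feats = rearrange(feats, "(b T F) v d -> b T F v d", b=B, T=T, F=F)
    print(feats.shape)  # torch.Size([2, 4, 1, 16, 1024])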