import torch
from einops import rearrange, repeat
from torch import nn
import copy
from open_flamingo.src.helpers import PerceiverResampler
from robot_flamingo.models.action_head import DeterministicDecoder, DiffusionDecoder, FCDecoder, GPTDecoder
from collections import namedtuple
class BCFlamingo(nn.Module):
def __init__(
self,
vision_encoder: nn.Module,
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
vis_dim: int,
cross_attn_every_n_layers: int = 1,
use_media_placement_augmentation: bool = False,
# this is the window size sampled from the episode
window_size: int = 8,
use_gripper=False,
fusion_mode='',
sep_resampler=False,
use_state=False,
use_diff=False,
diff_horizon=32,
last_action=False,
n_timesteps=150,
state_dim=15,
use_hist=False,
debug=False,
predict_epsilon=True,
pad_length=-1,
multi_step_action=1,
sep_lm_head=False,
return_feature = False,
llm='llama_9b',
pooling='max',
residual=False,
tcp_rel=False,
replan=-1,
decoder_type='lstm',
hidden_size=None,
fwd_pred=False,
fwd_pred_hand=False,
refresh=-1
):
"""
Args:
vision_encoder (nn.Module): HF CLIPModel
lang_encoder (nn.Module): HF causal language model
eoc_token_id (int): Token id for <|endofchunk|>
media_token_id (int): Token id for <image>
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
"""
super().__init__()
self.use_gripper = use_gripper
self.use_state = use_state
self.fusion_mode = fusion_mode
self.eoc_token_id = eoc_token_id
self.media_token_id = media_token_id
self.use_media_placement_augmentation = use_media_placement_augmentation
self.vis_dim = vis_dim
self.window_size = window_size
self.tcp_rel = tcp_rel
self.act_step = multi_step_action
print('window size: {}'.format(window_size))
self.vision_encoder = vision_encoder
self.perceiver = PerceiverResampler(dim=self.vis_dim)
self.sep_resampler = sep_resampler
self.use_hist = use_hist
self.lang_encoder = lang_encoder
self.pad_length = pad_length
self.replan = replan
if self.replan != -1:
self.replan = min(int(replan * self.window_size), 180)
self.refresh = refresh
if hasattr(lang_encoder.config, "d_model"):
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
else:
self.lang_dim = lang_encoder.config.hidden_size
# print(self.vis_dim, self.lang_dim)
self.residual = residual
if not debug:
if 'llama' in llm:
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
use_media_placement_augmentation=self.use_media_placement_augmentation,
residual=residual,
)
else:
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
lang_hidden_size=self.lang_dim,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=False,
)
if sep_resampler:
self.perceiver_gripper = PerceiverResampler(dim=self.vis_dim)
self.perceiver_gripper.load_state_dict(copy.deepcopy(self.perceiver.state_dict()))
if use_state:
self.state_fc = nn.Linear(state_dim, self.vis_dim)
if use_hist:
self.frame_embs = nn.Parameter(torch.randn(self.window_size, self.vis_dim))
        # TODO: nn architecture for actor
self.llm = llm
if llm=='llama':
in_features = lang_encoder.lm_head.in_features
else:
in_features = self.lang_dim
self.use_diff = use_diff
self.decoder_type = decoder_type
if decoder_type == 'lstm':
lm_head = DeterministicDecoder(in_features, self.window_size,
use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode, use_state=use_state, return_feature=return_feature, multi_step_action=multi_step_action, pooling=pooling)
self.lang_encoder.lm_head = lm_head
elif decoder_type == 'fc':
if use_hist:
self.lang_encoder.lm_head = self.action_head = FCDecoder(in_features, self.window_size,
use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode, use_state=use_state, return_feature=return_feature, multi_step_action=multi_step_action)
elif 'vit_concat' in fusion_mode:
self.lang_encoder.lm_head = self.action_head = FCDecoder(in_features, self.window_size,
use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode, use_state=use_state, return_feature=return_feature, multi_step_action=multi_step_action)
else:
raise NotImplementedError
elif decoder_type == 'diffusion':
if use_diff:
self.diffusion_model = DiffusionDecoder(
self.action_head.hidden_size,
self.window_size,
input_dim=self.action_head.out_features+1,
n_timesteps=n_timesteps,
horizon=diff_horizon,
predict_epsilon=predict_epsilon,
)
else:
raise NotImplementedError
elif decoder_type=='gpt':
lm_head = GPTDecoder(in_features, self.window_size, use_diff=use_diff, last_action=last_action, fusion_mode=fusion_mode, multi_step_action=multi_step_action, pooling=pooling, hidden_size=hidden_size)
self.lang_encoder.lm_head = self.action_head = lm_head
else:
raise NotImplementedError
self.sep_lm_head = sep_lm_head
if sep_lm_head:
self.lm_head = self.lang_encoder.lm_head
self.lang_encoder.lm_head = nn.Identity()
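    # Head wiring sketch (illustrative, assuming decoder_type='lstm' and
    # sep_lm_head=True): the causal LM's original lm_head is replaced by
    # nn.Identity(), and the DeterministicDecoder (the LSTM policy head) is kept
    # in self.lm_head, so forward() maps LM hidden states to actions explicitly.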
def forward(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
use_cached_vision_x: bool = False,
clear_conditioned_layers: bool = True,
past_key_values=None,
use_cache: bool = False,
vision_gripper = None,
state_tensor = None,
return_feature = False,
policy_mask=None
):
"""
Forward pass of Flamingo.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W) with F=1
lang_x (torch.Tensor): Language input ids
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
labels (torch.Tensor, optional): Labels. Defaults to None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in another subsequent
                forward pass.
past_key_values: pre-computed values to pass to language model.
See past_key_values documentation in Hugging Face
CausalLM models.
use_cache: whether to use cached key values. See use_cache
documentation in Hugging Face CausalLM models.
"""
raw_rgb = vision_x.clone()
raw_gripper = vision_gripper.clone()
assert (
vision_x is not None
) or use_cached_vision_x, (
"Must provide either vision_x or use_cached_vision_x to True."
)
if use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert (
vision_x is None
), "Expect vision_x to be None when use_cached_vision_x is True."
assert self.lang_encoder.is_conditioned()
else:
# Case: do not use caching (i.e. this is a standard forward pass);
if self.use_hist:
self._encode_history_vision_post_fusion(vision_x, vision_gripper, state_tensor)
else:
if not self.use_gripper or self.fusion_mode == 'two_way':
vision_x = self._encode_vision_x(vision_x=vision_x)
else:
if self.fusion_mode == 'pre':
self._encode_multi_vision_pre_fusion(vision_x, vision_gripper, state_tensor)
elif self.fusion_mode == 'post':
self._encode_multi_vision_post_fusion(vision_x, vision_gripper, state_tensor)
elif self.fusion_mode == 'vit_concat':
self._encode_history_vision_fc_post(vision_x, vision_gripper, state_tensor)
if 'llama' in self.llm:
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
                # labels=labels,  # omit labels so the language model does not compute its own loss here
past_key_values=past_key_values,
use_cache=use_cache,
)
else:
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask
)
if self.sep_lm_head:
output_llm = output.logits
output_lm_head = self.lm_head(output_llm, state_tensor=state_tensor, return_feature=return_feature)
output.logits = output_lm_head
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
# action_seq = self.action_head(vision_x)
return output
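    # Hedged usage sketch for forward() (shapes follow the docstring above; the
    # variable names mask/state are illustrative):
    #   vision_x       : (B, T_img, 1, 3, H, W) RGB frames
    #   vision_gripper : (B, T_img, 1, 3, H, W) gripper-camera frames
    #   lang_x         : (B, T_txt) token ids containing the <image> media token
    #   out = model(vision_x, lang_x, attention_mask=mask,
    #               vision_gripper=vision_gripper, state_tensor=state)
    #   With sep_lm_head=True, out.logits holds the policy head's action output.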
    # Generate function with actor for test-time adaptation
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
num_beams=1,
max_new_tokens=None,
temperature=1.0,
top_k=0,
top_p=1.0,
no_repeat_ngram_size=0,
prefix_allowed_tokens_fn=None,
length_penalty=1.0,
num_return_sequences=1,
do_sample=False,
early_stopping=False,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
images in the same chunk are collated along T_img, and frames are collated along F
currently only F=1 is supported (single-frame videos)
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
num_beams (int, optional): Number of beams. Defaults to 1.
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
temperature (float, optional): Temperature. Defaults to 1.0.
top_k (int, optional): Top k. Defaults to 0.
top_p (float, optional): Top p. Defaults to 1.0.
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
length_penalty (float, optional): Length penalty. Defaults to 1.0.
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
do_sample (bool, optional): Do sample. Defaults to False.
early_stopping (bool, optional): Early stopping. Defaults to False.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
if num_beams > 1:
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
self._encode_vision_x(vision_x=vision_x)
output = self.lang_encoder.generate(
lang_x,
attention_mask=attention_mask,
eos_token_id=self.eoc_token_id,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
no_repeat_ngram_size=no_repeat_ngram_size,
length_penalty=length_penalty,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
early_stopping=early_stopping,
)
self.lang_encoder.clear_conditioned_layers()
return output
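    # Hedged generate() sketch (argument values are illustrative; beam search
    # simply repeats vision_x along the batch dim before encoding):
    #   out_ids = model.generate(vision_x, lang_x, attention_mask=mask,
    #                            num_beams=3, max_new_tokens=20)
    # Generation stops at the <|endofchunk|> token (eos_token_id=self.eoc_token_id).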
def _encode_vision_x(self, vision_x: torch.Tensor):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]
assert F == 1, "Only single frame supported"
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
with torch.no_grad():
vision_x = self.vision_encoder.visual(vision_x)[1]
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
def _encode_vision(self, vision_x: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]
assert F == 1, "Only single frame supported"
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
with torch.no_grad():
vision_x = self.vision_encoder.visual(vision_x)[1]
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
return vision_x
def _encode_multi_vision_pre_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_rgb (torch.Tensor): Vision rgb input
shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
vision_rgb = self._encode_vision(vision_rgb)
vision_gripper = self._encode_vision(vision_gripper)
vision_x = torch.cat([vision_rgb, vision_gripper], dim=3)
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
def _encode_multi_vision_post_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_rgb (torch.Tensor): Vision rgb input
shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
vision_rgb = self._encode_vision(vision_rgb)
vision_gripper = self._encode_vision(vision_gripper)
vision_rgb = self.perceiver(vision_rgb)
if self.sep_resampler:
vision_gripper = self.perceiver_gripper(vision_gripper)
else:
vision_gripper = self.perceiver(vision_gripper)
vision_x = torch.cat([vision_rgb, vision_gripper], dim=2) # reshapes to (b, T, 2*n, d)
if self.use_state and state_tensor is not None:
# state_tensor = state_tensor.double()
state_tensor = self.state_fc(state_tensor)
vision_x = torch.cat([vision_x, state_tensor], dim=2) # reshapes to (b, T, 2*n+1, d)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
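    # Token layout sketch for 'post' fusion (illustrative, assuming the default
    # 64 Perceiver latents per view): the conditioning sequence per frame is
    #   [64 RGB tokens | 64 gripper tokens | optional 1 state token]
    # giving vision_x of shape (b, T, 2*64(+1), vis_dim), which is attached to
    # every decoder layer via condition_vis_x for the gated cross-attention.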
def _encode_multi_vision_two_way(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_rgb (torch.Tensor): Vision rgb input
shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
vision_rgb = self._encode_vision(vision_rgb)
vision_gripper = self._encode_vision(vision_gripper)
vision_rgb = self.perceiver(vision_rgb)
if self.sep_resampler:
vision_gripper = self.perceiver_gripper(vision_gripper)
else:
vision_gripper = self.perceiver(vision_gripper)
        vision_x = torch.cat([vision_rgb, vision_gripper], dim=0)  # two-way: concatenated along the batch dim -> (2*b, T, n, d)
if self.use_state and state_tensor is not None:
state_tensor = self.state_fc(state_tensor)
            vision_x = torch.cat([vision_x, state_tensor], dim=0)  # concatenated along the batch dim
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
def _encode_history_vision_post_fusion(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_rgb (torch.Tensor): Vision rgb input
shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
vision_rgb = self._encode_vision(vision_rgb)
vision_gripper = self._encode_vision(vision_gripper)
bs = int(vision_rgb.shape[0] // self.window_size)
vision_rgb = vision_rgb.view(bs, self.window_size, *vision_rgb.shape[1:])
_, _, T, p, v_tok, dim = vision_rgb.shape[:6]
frame_embs = repeat(self.frame_embs, 'F d -> b F T p v d', b=bs, T=T, p=p, v=v_tok)
vision_rgb = vision_rgb + frame_embs
vision_rgb = rearrange(vision_rgb, 'b F T p v d -> (b F) T p v d')
vision_rgb = self.perceiver(vision_rgb)
vision_gripper = vision_gripper.view(vision_gripper.shape[0] // self.window_size, self.window_size,
*vision_gripper.shape[1:])
frame_embs = repeat(self.frame_embs, 'F d -> b F T p v d', b=bs, T=T, p=p, v=v_tok)
vision_gripper = vision_gripper + frame_embs
vision_gripper = rearrange(vision_gripper, 'b F T p v d -> (b F) T p v d')
if self.sep_resampler:
vision_gripper = self.perceiver_gripper(vision_gripper)
else:
vision_gripper = self.perceiver(vision_gripper)
vision_x = torch.cat([vision_rgb, vision_gripper], dim=2) # reshapes to (b, T, 2*n, d)
if self.use_state and state_tensor is not None:
state_tensor = self.state_fc(state_tensor)
vision_x = torch.cat([vision_x, state_tensor], dim=2) # reshapes to (b, T, 2*n+1, d)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
def _encode_history_vision_fc_post(self, vision_rgb: torch.Tensor, vision_gripper: torch.Tensor, state_tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_rgb (torch.Tensor): Vision rgb input
shape (B, T_img, F, C, H, W)
            vision_gripper (torch.Tensor): Vision gripper input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
bs = int(vision_rgb.shape[0] // self.window_size)
vision_rgb = self._encode_vision(vision_rgb)
vision_rgb = self.perceiver(vision_rgb) # BxL, T, n, d
vision_rgb = vision_rgb.view(-1, self.window_size, *vision_rgb.shape[1:]) # B, L, T, n, d
vision_rgb = rearrange(vision_rgb, 'b L T n d -> b T (n L) d')
vision_gripper = self._encode_vision(vision_gripper)
if self.sep_resampler:
vision_gripper = self.perceiver_gripper(vision_gripper)
else:
vision_gripper = self.perceiver(vision_gripper)
vision_gripper = vision_gripper.view(-1, self.window_size, *vision_gripper.shape[1:]) # B, L, T, n, d
vision_gripper = rearrange(vision_gripper, 'b L T n d -> b T (n L) d')
vision_x = torch.cat([vision_rgb, vision_gripper], dim=2)
if self.use_state and state_tensor is not None:
state_tensor = self.state_fc(state_tensor)
vision_x = torch.cat([vision_x, state_tensor], dim=2) # reshapes to (b, T, 2*n+1, d)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
return vision_x
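if __name__ == "__main__":
    # Minimal, hedged shape-flow sketch (not part of the original pipeline): it
    # mirrors _encode_vision_x's reshaping, substituting a random tensor for the
    # CLIP patch features returned by vision_encoder.visual(...)[1]. The sizes
    # below (v=256 patch tokens, d=1024 channels) are illustrative assumptions.
    b, T, F, v, d = 2, 8, 1, 256, 1024
    dummy_patch_feats = torch.randn(b * T * F, v, d)
    feats = rearrange(dummy_patch_feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
    resampler = PerceiverResampler(dim=d)
    media_tokens = resampler(feats)  # (b, T, num_latents, d) tokens for cross-attention
    print("media token shape:", media_tokens.shape)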