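"""Factory with forced hidden states.

Builds an OpenFlamingo base model plus an LSTM policy head and, optionally,
loads checkpoint weights into the combined model.
"""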
import logging
import os
import sys

import torch
import torch.nn as nn

# Make the Space's app directory importable before loading project modules.
sys.path.insert(0, '/home/user/app')

from open_flamingo.src.factory import create_model_and_transforms as create_base  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from policy_head import LSTMPolicyHead  # noqa: E402
_logger = logging.getLogger(__name__)


class RoboFlamingoWithPolicy(nn.Module):
    """OpenFlamingo backbone whose last language-model hidden layer feeds a policy head.

    Args:
        base_model: The OpenFlamingo model returned by the base factory. It
            must expose ``vision_encoder``, ``lang_encoder`` and
            ``_encode_vision_x``.
        policy_head: Module called on the final hidden states. It must return
            a 3-tuple whose first two items are the action and gripper outputs.
    """

    def __init__(self, base_model, policy_head):
        super().__init__()
        self.base_model = base_model
        self.policy_head = policy_head
        # Convenience aliases to the same sub-modules. Because they are
        # registered as children, the shared parameters also appear under
        # ``vision_encoder.*`` / ``lang_encoder.*`` in ``state_dict()``.
        self.vision_encoder = base_model.vision_encoder
        self.lang_encoder = base_model.lang_encoder

    def forward(self, vision_x, lang_x, attention_mask=None):
        """Run the backbone, then the policy head on the last hidden layer.

        Args:
            vision_x: Visual input passed to ``base_model._encode_vision_x``,
                or ``None`` to skip visual encoding on this call.
            lang_x: Token ids passed to the language model as ``input_ids``.
            attention_mask: Optional attention mask for ``lang_x``.

        Returns:
            ``{'actions': ..., 'gripper': ...}`` as produced by ``policy_head``.

        Raises:
            RuntimeError: If the language model returns no hidden states.
        """
        if vision_x is not None:
            # Called only for its side effects; its return value is not used.
            # NOTE(review): in OpenFlamingo this presumably conditions the
            # language model's cross-attention layers on the images. Nothing
            # here clears that conditioning afterwards, so a later call with
            # vision_x=None may reuse it -- confirm this is intended.
            self.base_model._encode_vision_x(vision_x=vision_x)

        # Call the language model directly so hidden states can be requested
        # explicitly.
        lang_output = self.base_model.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        hidden_states = getattr(lang_output, 'hidden_states', None)
        # Also reject an empty sequence, which would otherwise surface as an
        # opaque IndexError on the [-1] below.
        if hidden_states is None or len(hidden_states) == 0:
            raise RuntimeError("Cannot get hidden states from language model")
        embeddings = hidden_states[-1]  # final layer
        # Debug level: forward runs on every step, so printing here floods logs.
        _logger.debug("Got hidden states: %s", tuple(embeddings.shape))

        actions, gripper, _ = self.policy_head(embeddings)
        return {'actions': actions, 'gripper': gripper}
# Checkpoints saved from a DistributedDataParallel/DataParallel-wrapped model
# carry this prefix on every key.
_DDP_PREFIX = 'module.'

# Policy-head sub-module names in the checkpoint (after the DDP prefix is
# stripped) -> the names used by RoboFlamingoWithPolicy / LSTMPolicyHead.
_POLICY_HEAD_KEY_MAP = {
    'action_head.rnn': 'policy_head.lstm',
    'action_head.actions.mlp': 'policy_head.action_head',
    'action_head.gripper.mlp': 'policy_head.gripper_head',
}

# Top-level names of RoboFlamingoWithPolicy that merely alias sub-modules of
# ``base_model`` (see its __init__); they are excluded from the missing-key report.
_ALIAS_PREFIXES = ('vision_encoder.', 'lang_encoder.')


def _remap_checkpoint_key(key):
    """Translate one checkpoint key into the RoboFlamingoWithPolicy namespace."""
    # Strip only a leading DDP prefix; 'module.' elsewhere in a name is left alone.
    if key.startswith(_DDP_PREFIX):
        key = key[len(_DDP_PREFIX):]
    for src, dst in _POLICY_HEAD_KEY_MAP.items():
        if key == src or key.startswith(src + '.'):
            return dst + key[len(src):]
    # Keys already in the wrapper's namespace (e.g. saved from this class) pass through.
    if key.startswith(('base_model.', 'policy_head.')):
        return key
    return 'base_model.' + key


def _pad_embedding_rows(value, target):
    """Zero-pad a 2-D embedding matrix with extra rows so it matches ``target``.

    Presumably the model's vocabulary is larger than the checkpoint's because
    tokens were added after training; the extra rows start at zero. Returns
    ``value`` unchanged when ``target`` is unknown, no padding is needed, or
    the shapes are otherwise incompatible.
    """
    if target is None or value.dim() != 2 or target.dim() != 2:
        return value
    missing_rows = target.shape[0] - value.shape[0]
    if missing_rows <= 0 or value.shape[1] != target.shape[1]:
        return value
    # new_zeros keeps the checkpoint tensor's dtype and device.
    return torch.cat([value, value.new_zeros(missing_rows, value.shape[1])], dim=0)


def _remap_state_dict(state_dict, target_state):
    """Rename checkpoint entries for this model and pad the token embedding.

    Args:
        state_dict: Raw checkpoint state dict.
        target_state: ``model.state_dict()`` of the freshly built model; used
            only to look up the expected token-embedding shape.

    Returns:
        A new dict suitable for ``model.load_state_dict``.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = _remap_checkpoint_key(key)
        if 'transformer.wte.weight' in key:
            value = _pad_embedding_rows(value, target_state.get(new_key))
        remapped[new_key] = value
    return remapped


def create_model_and_transforms(
    checkpoint_path=None,
    *,
    repo_id="robovlms/RoboFlamingo",
    checkpoint_filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
):
    """Build the OpenFlamingo + LSTM policy model and optionally load weights.

    Args:
        checkpoint_path: If falsy, no checkpoint is loaded. If it names an
            existing local file, that file is loaded. Any other truthy value
            downloads ``checkpoint_filename`` from the Hugging Face Hub repo
            ``repo_id``.
        repo_id: Hub repository used when downloading.
        checkpoint_filename: Checkpoint file inside ``repo_id``.

    Returns:
        Tuple ``(model, image_processor, tokenizer)``. The last two come
        straight from the base OpenFlamingo factory.
    """
    print("π¦ Creating base...")
    base_model, image_processor, tokenizer = create_base(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
        tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
        # NOTE(review): the default checkpoint name mentions "mpt_3b" while
        # this builds mpt-1b with a cross-attention layer every 4 blocks.
        # Confirm this matches the checkpoint's architecture; the
        # missing/unexpected key report below will show any mismatch.
        cross_attn_every_n_layers=4,
    )
    print("π¨ Creating policy head...")
    policy_head = LSTMPolicyHead(
        input_dim=2048,
        hidden_dim=1024,
        num_layers=4,
    )
    model = RoboFlamingoWithPolicy(base_model, policy_head)
    print("β Model ready")

    if checkpoint_path:
        print("π₯ Loading checkpoint...")
        if isinstance(checkpoint_path, (str, os.PathLike)) and os.path.isfile(checkpoint_path):
            ckpt_file = checkpoint_path
        else:
            ckpt_file = hf_hub_download(
                repo_id=repo_id,
                filename=checkpoint_filename,
                repo_type="model",
            )
        # SECURITY: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_file, map_location='cpu')
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        new_state_dict = _remap_state_dict(state_dict, model.state_dict())
        # strict=False silently skips mismatched names, so report them.
        incompatible = model.load_state_dict(new_state_dict, strict=False)
        missing = [k for k in incompatible.missing_keys if not k.startswith(_ALIAS_PREFIXES)]
        if missing or incompatible.unexpected_keys:
            print(
                f"  missing keys: {len(missing)}, "
                f"unexpected keys: {len(incompatible.unexpected_keys)}"
            )
        print("β Checkpoint loaded!")
    return model, image_processor, tokenizer