# roboflamingo-demo/patched_factory.py
"""Factory with forced hidden states"""
import sys
import torch
import torch.nn as nn
sys.path.insert(0, '/home/user/app')
from open_flamingo.src.factory import create_model_and_transforms as create_base
from huggingface_hub import hf_hub_download
from policy_head import LSTMPolicyHead
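
# NOTE (assumption): the code below relies only on the LSTMPolicyHead interface
# that can be inferred from how it is used in this file:
#   head = LSTMPolicyHead(input_dim, hidden_dim, num_layers)
#   actions, gripper, hidden = head(embeddings)  # embeddings: (batch, seq, input_dim)
# with submodules named `lstm`, `action_head`, and `gripper_head`, which the
# checkpoint remapping in create_model_and_transforms() targets.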


class RoboFlamingoWithPolicy(nn.Module):
    """OpenFlamingo base + LSTM policy head over the last language hidden states."""

    def __init__(self, base_model, policy_head):
        super().__init__()
        self.base_model = base_model
        self.policy_head = policy_head
        # Expose the encoders directly for convenience.
        self.vision_encoder = base_model.vision_encoder
        self.lang_encoder = base_model.lang_encoder
    def forward(self, vision_x, lang_x, attention_mask=None):
        # Encode vision first: in OpenFlamingo, _encode_vision_x runs the vision
        # encoder/perceiver and conditions the gated cross-attention layers as a
        # side effect, so its return value is not needed here.
        if vision_x is not None:
            self.base_model._encode_vision_x(vision_x=vision_x)

        # Call the language model directly so output_hidden_states can be forced;
        # the Flamingo wrapper itself does not expose this parameter.
        lang_output = self.base_model.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            output_hidden_states=True,  # FORCE hidden states output!
            return_dict=True,
        )

        if hasattr(lang_output, 'hidden_states') and lang_output.hidden_states is not None:
            # Last layer hidden states: (batch, seq_len, hidden_dim)
            embeddings = lang_output.hidden_states[-1]
            print(f"  βœ… Got hidden states: {embeddings.shape}")
        else:
            print("  ❌ Still no hidden states!")
            raise RuntimeError("Cannot get hidden states from language model")

        # Map token embeddings to arm actions and a gripper signal.
        actions, gripper, _ = self.policy_head(embeddings)
        return {'actions': actions, 'gripper': gripper}


def create_model_and_transforms(checkpoint_path=None):
    print("πŸ“¦ Creating base...")
    base_model, image_processor, tokenizer = create_base(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
        tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
        cross_attn_every_n_layers=4,
    )

    print("πŸ”¨ Creating policy head...")
    # input_dim matches the MPT-1B hidden size (2048).
    policy_head = LSTMPolicyHead(
        input_dim=2048,
        hidden_dim=1024,
        num_layers=4,
    )

    model = RoboFlamingoWithPolicy(base_model, policy_head)
    print("βœ… Model ready")
    # checkpoint_path only toggles loading; the file itself is fetched from the Hub.
    if checkpoint_path:
        print("πŸ“₯ Loading checkpoint...")
        ckpt_file = hf_hub_download(
            repo_id="robovlms/RoboFlamingo",
            filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
            repo_type="model"
        )
        checkpoint = torch.load(ckpt_file, map_location='cpu')
        state_dict = checkpoint.get('model_state_dict', checkpoint)

        # Remap RoboFlamingo's key names onto this wrapper:
        #   module.action_head.*  -> policy_head.*
        #   module.*              -> base_model.*
        new_state_dict = {}
        for key, value in state_dict.items():
            if 'action_head.rnn' in key:
                new_key = key.replace('module.action_head.rnn', 'policy_head.lstm')
                new_state_dict[new_key] = value
            elif 'action_head.actions.mlp' in key:
                new_key = key.replace('module.action_head.actions.mlp', 'policy_head.action_head')
                new_state_dict[new_key] = value
            elif 'action_head.gripper.mlp' in key:
                new_key = key.replace('module.action_head.gripper.mlp', 'policy_head.gripper_head')
                new_state_dict[new_key] = value
            elif 'transformer.wte.weight' in key:
                # Pad the token embedding by one row if the checkpoint was saved
                # with a 50280-entry vocabulary (one short of the current table).
                if value.shape[0] == 50280:
                    value = torch.cat([value, torch.zeros(1, value.shape[1])], dim=0)
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value
            else:
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value

        # strict=False: the checkpoint covers only part of the wrapper's parameters.
        model.load_state_dict(new_state_dict, strict=False)
        print("βœ… Checkpoint loaded!")

    return model, image_processor, tokenizer
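

# --- Optional smoke test (sketch, not part of the original demo) --------------
# A minimal check, under the assumption that the patched forward above runs
# end-to-end in this environment: push dummy inputs through the encoders and
# policy head and print the output shapes. Vision input uses OpenFlamingo's
# (batch, n_media, n_frames, C, H, W) layout with 224x224 images for ViT-L/14.
if __name__ == "__main__":
    model, image_processor, tokenizer = create_model_and_transforms()
    model.eval()

    vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # one dummy image
    text = tokenizer("<image>pick up the red block", return_tensors="pt")

    with torch.no_grad():
        out = model(
            vision_x=vision_x,
            lang_x=text["input_ids"],
            attention_mask=text["attention_mask"],
        )
    # Output shapes depend on LSTMPolicyHead's action and gripper heads.
    print("actions:", out["actions"].shape, "gripper:", out["gripper"].shape)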