"""Factory with forced hidden states"""
import sys
import torch
import torch.nn as nn
sys.path.insert(0, '/home/user/app')
from open_flamingo.src.factory import create_model_and_transforms as create_base
from huggingface_hub import hf_hub_download
from policy_head import LSTMPolicyHead
class RoboFlamingoWithPolicy(nn.Module):
    def __init__(self, base_model, policy_head):
        super().__init__()
        self.base_model = base_model
        self.policy_head = policy_head
        self.vision_encoder = base_model.vision_encoder
        self.lang_encoder = base_model.lang_encoder
    def forward(self, vision_x, lang_x, attention_mask=None):
        # OpenFlamingo's Flamingo class wraps the language model, so we bypass
        # its forward() and call the pieces ourselves in order to pass
        # output_hidden_states=True through to the language encoder.
        if vision_x is not None:
            # Encode the frames and condition the cross-attention layers.
            # (_encode_vision_x stores the visual features on the layers
            # internally; its return value is unused here.)
            self.base_model._encode_vision_x(vision_x=vision_x)

        # Call the language model directly with output_hidden_states=True.
        lang_output = self.base_model.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            output_hidden_states=True,  # FORCE hidden states output!
            return_dict=True,
        )

        if hasattr(lang_output, 'hidden_states') and lang_output.hidden_states is not None:
            # Last entry is the final layer's output: [batch, seq_len, d_model].
            embeddings = lang_output.hidden_states[-1]
            print(f"  ✅ Got hidden states: {embeddings.shape}")
        else:
            print("  ❌ Still no hidden states!")
            raise RuntimeError("Cannot get hidden states from language model")

        # Map token embeddings to action and gripper predictions.
        actions, gripper, _ = self.policy_head(embeddings)
        return {'actions': actions, 'gripper': gripper}
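
# Usage note (an assumption based on OpenFlamingo's conventions, not stated in
# this file): vision_x is expected as a (batch, T_img, frames, C, H, W) tensor
# and lang_x as (batch, seq_len) token ids, so the policy head produces an
# action/gripper prediction per token position. See the __main__ sketch below.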
def create_model_and_transforms(checkpoint_path=None):
    print("📦 Creating base...")
    base_model, image_processor, tokenizer = create_base(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
        tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
        cross_attn_every_n_layers=4,
    )

    print("🔨 Creating policy head...")
    policy_head = LSTMPolicyHead(
        input_dim=2048,  # matches MPT-1b's hidden size (d_model = 2048)
        hidden_dim=1024,
        num_layers=4,
    )

    model = RoboFlamingoWithPolicy(base_model, policy_head)
    print("✅ Model ready")
    if checkpoint_path:
        print("📥 Loading checkpoint...")
        ckpt_file = hf_hub_download(
            repo_id="robovlms/RoboFlamingo",
            filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
            repo_type="model",
        )
        checkpoint = torch.load(ckpt_file, map_location='cpu')
        state_dict = checkpoint.get('model_state_dict', checkpoint)
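
        # Remap the checkpoint's DDP-prefixed keys ('module.*') onto this
        # wrapper's module names. Illustrative examples (hypothetical key names,
        # assuming the usual PyTorch LSTM/embedding parameter naming):
        #   module.action_head.rnn.weight_ih_l0        -> policy_head.lstm.weight_ih_l0
        #   module.lang_encoder.transformer.wte.weight -> base_model.lang_encoder.transformer.wte.weight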
        new_state_dict = {}
        for key, value in state_dict.items():
            if 'action_head.rnn' in key:
                new_key = key.replace('module.action_head.rnn', 'policy_head.lstm')
                new_state_dict[new_key] = value
            elif 'action_head.actions.mlp' in key:
                new_key = key.replace('module.action_head.actions.mlp', 'policy_head.action_head')
                new_state_dict[new_key] = value
            elif 'action_head.gripper.mlp' in key:
                new_key = key.replace('module.action_head.gripper.mlp', 'policy_head.gripper_head')
                new_state_dict[new_key] = value
            elif 'transformer.wte.weight' in key:
                if value.shape[0] == 50280:
                    # The checkpoint's token embedding table is one row short of
                    # this model's vocabulary; pad it with a zero row (keeping
                    # the original dtype).
                    value = torch.cat([value, torch.zeros(1, value.shape[1], dtype=value.dtype)], dim=0)
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value
            else:
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value
        # strict=False: the checkpoint comes from a different configuration, so
        # unmatched keys are tolerated rather than raising.
        model.load_state_dict(new_state_dict, strict=False)
        print("✅ Checkpoint loaded!")

    return model, image_processor, tokenizer
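
# Minimal smoke test: a sketch, not part of the original factory. It builds the
# model without the checkpoint and runs a dummy forward pass. The "<image>"
# special token and the vision_x layout follow OpenFlamingo's conventions; the
# instruction text is a made-up example, and a real caller would build vision_x
# from image_processor outputs rather than random tensors.
if __name__ == "__main__":
    model, image_processor, tokenizer = create_model_and_transforms()
    model.eval()

    lang = tokenizer(["<image>pick up the red block"], return_tensors="pt")
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # one image, one frame

    with torch.no_grad():
        out = model(vision_x, lang.input_ids, attention_mask=lang.attention_mask)
    print({k: v.shape for k, v in out.items()})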