"""Factory with forced hidden states"""
import sys
import torch
import torch.nn as nn
sys.path.insert(0, '/home/user/app')

from open_flamingo.src.factory import create_model_and_transforms as create_base
from huggingface_hub import hf_hub_download
from policy_head import LSTMPolicyHead

class RoboFlamingoWithPolicy(nn.Module):
    def __init__(self, base_model, policy_head):
        super().__init__()
        self.base_model = base_model
        self.policy_head = policy_head
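        # Re-expose the base model's submodules so callers can reach them
        # directly on the wrapper.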
        self.vision_encoder = base_model.vision_encoder
        self.lang_encoder = base_model.lang_encoder
        
    def forward(self, vision_x, lang_x, attention_mask=None):
        # OpenFlamingo's Flamingo class wraps the language model, so we call
        # its lang_encoder directly and force it to return hidden states.

        # Encode vision first. _encode_vision_x does not return features: it
        # conditions the gated cross-attention layers of lang_encoder in place,
        # so there is nothing to keep here.
        if vision_x is not None:
            self.base_model._encode_vision_x(vision_x=vision_x)
        
        # Now call language model with output_hidden_states=True
        # The lang_encoder should support this parameter
        lang_output = self.base_model.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            output_hidden_states=True,  # FORCE hidden states output!
            return_dict=True
        )
        
        # The final hidden state provides per-token embeddings for the policy head
        if hasattr(lang_output, 'hidden_states') and lang_output.hidden_states is not None:
            embeddings = lang_output.hidden_states[-1]
            print(f"   ✅ Got hidden states: {embeddings.shape}")
        else:
            print("   ❌ Still no hidden states!")
            raise RuntimeError("Cannot get hidden states from language model")
        
        # Apply policy head
        actions, gripper, _ = self.policy_head(embeddings)
        
        return {'actions': actions, 'gripper': gripper}

def create_model_and_transforms(checkpoint_path=None):
    print("πŸ“¦ Creating base...")
    base_model, image_processor, tokenizer = create_base(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path="mosaicml/mpt-1b-redpajama-200b",
        tokenizer_path="mosaicml/mpt-1b-redpajama-200b",
        cross_attn_every_n_layers=4,
    )
    
    print("πŸ”¨ Creating policy head...")
    policy_head = LSTMPolicyHead(
        input_dim=2048,
        hidden_dim=1024,
        num_layers=4
    )
    
    model = RoboFlamingoWithPolicy(base_model, policy_head)
    print("βœ… Model ready")
    
    if checkpoint_path:
        print("πŸ“₯ Loading checkpoint...")
        ckpt_file = hf_hub_download(
            repo_id="robovlms/RoboFlamingo",
            filename="checkpoint_gripper_post_hist_1_aug_10_4_traj_cons_ws_12_mpt_3b_4.pth",
            repo_type="model"
        )
        
        checkpoint = torch.load(ckpt_file, map_location='cpu')
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        
        new_state_dict = {}
        
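        # Remap the checkpoint's DataParallel-style keys ('module.*') onto this
        # wrapper's names: action-head weights map to 'policy_head.*', the rest
        # to 'base_model.*'.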
        for key, value in state_dict.items():
            if 'action_head.rnn' in key:
                new_key = key.replace('module.action_head.rnn', 'policy_head.lstm')
                new_state_dict[new_key] = value
            elif 'action_head.actions.mlp' in key:
                new_key = key.replace('module.action_head.actions.mlp', 'policy_head.action_head')
                new_state_dict[new_key] = value
            elif 'action_head.gripper.mlp' in key:
                new_key = key.replace('module.action_head.gripper.mlp', 'policy_head.gripper_head')
                new_state_dict[new_key] = value
            elif 'transformer.wte.weight' in key:
                # The checkpoint's embedding matrix is one row short of the
                # model's vocab (an added special token); pad it with a zero
                # row of matching dtype so the shapes line up.
                if value.shape[0] == 50280:
                    value = torch.cat([value, torch.zeros(1, value.shape[1], dtype=value.dtype)], dim=0)
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value
            else:
                new_key = key.replace('module.', 'base_model.')
                new_state_dict[new_key] = value
        
        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
        print(f"✅ Checkpoint loaded! ({len(missing)} missing / {len(unexpected)} unexpected keys)")
    
    return model, image_processor, tokenizer
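

# Minimal smoke test -- a sketch, not part of the original factory. The input
# shapes are assumptions: OpenFlamingo expects vision_x shaped
# (batch, num_images, num_frames, channels, height, width) with 224x224 crops
# for the ViT-L/14 encoder, and the policy-head output shapes depend on
# LSTMPolicyHead's definition.
if __name__ == "__main__":
    model, image_processor, tokenizer = create_model_and_transforms()
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # one dummy RGB frame
    text = tokenizer("Pick up the red block.", return_tensors="pt")
    with torch.no_grad():
        out = model(
            vision_x=vision_x,
            lang_x=text["input_ids"],
            attention_mask=text["attention_mask"],
        )
    print(out["actions"].shape, out["gripper"].shape)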