Nirav-Madhani committed
Commit fdb2753 · verified · 1 Parent(s): e943c73

Add BridgeAttention adapter + config + policy definition

Files changed (4)
  1. adapter.pt +3 -0
  2. adapter.safetensors +3 -0
  3. config.json +27 -0
  4. policy_definition.py +37 -0
adapter.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c15b7174ab82dc4590b3e0a35a63c2d833353496a7239ef54feaada727e00e1
+ size 55237642
adapter.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3749947322390daf4ca9093633518845885f2bd289bea321ff4ab8c691fa47af
+ size 55220344
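
Both adapter files are Git LFS pointers; the ~55 MB weight payloads resolve on clone or `git lfs pull`. A minimal loading sketch, assuming each file holds a plain state dict (the key layout is not documented in this commit):

import torch
from safetensors.torch import load_file

# Either container should hold the same adapter weights.
state_dict = load_file("adapter.safetensors")                  # safetensors: no pickle, zero-copy
# state_dict = torch.load("adapter.pt", map_location="cpu")    # legacy torch pickle format

print(sorted(state_dict)[:5])  # inspect a few parameter names before wiring them up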
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "vision_model_id": "google/siglip-base-patch16-224",
+   "text_model_id": "Qwen/Qwen2.5-0.5B-Instruct",
+   "image_size": 224,
+   "num_action_queries": 64,
+   "policy_dim": 512,
+   "policy_layers": 4,
+   "n_heads": 8,
+   "dropout": 0.1,
+   "action_dim": 43,
+   "state_dim": 43,
+   "per_device_batch": 16,
+   "max_steps": 400,
+   "lr": 0.0003,
+   "warmup_ratio": 0.03,
+   "weight_decay": 0.01,
+   "log_every": 20,
+   "hf_repo": "nvidia/PhysicalAI-Robotics-GR00T-Teleop-G1",
+   "split": "train[:2%]",
+   "num_workers": 4,
+   "enable_vla_download": false,
+   "vla_patterns": [
+     "g1-pick-apple/**",
+     "g1-pick-pear/**"
+   ],
+   "local_data_dir": "/kaggle/working/gr00t_g1_subset"
+ }
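
A hedged sketch of how these fields could wire into the policy defined below. The encoder widths are read from the upstream model configs (768 for SigLIP-base, 896 for Qwen2.5-0.5B); the training script that actually consumes config.json is not part of this commit, so treat the mapping as an assumption:

import json
from transformers import AutoConfig
from policy_definition import BridgeAttentionPolicy

with open("config.json") as f:
    cfg = json.load(f)

# Encoder widths come from the upstream model configs, not from config.json.
v_hidden = AutoConfig.from_pretrained(cfg["vision_model_id"]).vision_config.hidden_size  # 768 for SigLIP-base
t_hidden = AutoConfig.from_pretrained(cfg["text_model_id"]).hidden_size                  # 896 for Qwen2.5-0.5B

policy = BridgeAttentionPolicy(
    v_hidden=v_hidden, t_hidden=t_hidden,
    state_dim=cfg["state_dim"], policy_dim=cfg["policy_dim"],
    n_heads=cfg["n_heads"], n_layers=cfg["policy_layers"],
    n_queries=cfg["num_action_queries"], action_dim=cfg["action_dim"],
    dropout=cfg["dropout"],
)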
policy_definition.py ADDED
@@ -0,0 +1,44 @@
+ import math, torch
+ import torch.nn as nn
+ from einops import repeat
+ 
+ class BridgeAttentionPolicy(nn.Module):
+     # Fuses multi-layer vision and text features with a proprioceptive state
+     # vector via learned query tokens, then regresses a continuous action.
+     def __init__(self, v_hidden, t_hidden, state_dim, policy_dim, n_heads, n_layers, n_queries, action_dim, dropout=0.1):
+         super().__init__()
+         self.n_queries = n_queries
+         # Learned action queries, scaled down at init to roughly unit norm.
+         self.query = nn.Parameter(torch.randn(n_queries, policy_dim) / math.sqrt(policy_dim))
+         # Per-modality projections into the shared policy width.
+         self.v_proj = nn.Linear(v_hidden, policy_dim)
+         self.t_proj = nn.Linear(t_hidden, policy_dim)
+         self.s_proj = nn.Linear(state_dim, policy_dim)
+         # Learnable per-modality gates, squashed through sigmoid in forward().
+         self.alpha_v = nn.Parameter(torch.tensor(0.7))
+         self.alpha_t = nn.Parameter(torch.tensor(0.7))
+         self.alpha_s = nn.Parameter(torch.tensor(0.7))
+         enc = nn.TransformerEncoderLayer(d_model=policy_dim, nhead=n_heads, dim_feedforward=policy_dim*4,
+                                          dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
+         self.blocks = nn.TransformerEncoder(enc, num_layers=n_layers)
+         self.norm = nn.LayerNorm(policy_dim)
+         self.head = nn.Sequential(nn.Linear(policy_dim, policy_dim), nn.GELU(), nn.Linear(policy_dim, action_dim))
+ 
+     def forward(self, v_feats_layers, t_feats_layers, state_vec):
+         B = state_vec.size(0)
+         # Concatenate the selected encoder layers along the sequence axis.
+         v_cat = torch.cat(v_feats_layers, dim=1) if v_feats_layers else None
+         t_cat = torch.cat(t_feats_layers, dim=1)
+         # State becomes a single token, gated like the other modalities.
+         s_tok = self.s_proj(state_vec).unsqueeze(1) * torch.sigmoid(self.alpha_s)
+         toks = [s_tok]
+         if v_cat is not None:
+             toks.append(self.v_proj(v_cat) * torch.sigmoid(self.alpha_v))
+         toks.append(self.t_proj(t_cat) * torch.sigmoid(self.alpha_t))
+         ctx = torch.cat(toks, dim=1)
+         # Prepend queries, run the encoder, then pool only the query slots.
+         q = repeat(self.query, 'Q D -> B Q D', B=B)
+         tokens = torch.cat([q, ctx], dim=1)
+         tokens = self.blocks(tokens)
+         pooled = self.norm(tokens[:, :self.n_queries].mean(dim=1))
+         return self.head(pooled)
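
A quick shape sanity-check with dummy inputs, using the dimensions from config.json; the vision/text feature shapes (196 patch tokens of width 768; width-896 text hidden states) are assumptions taken from the upstream SigLIP and Qwen configs, not values stored in this repo:

import torch
from policy_definition import BridgeAttentionPolicy

policy = BridgeAttentionPolicy(v_hidden=768, t_hidden=896, state_dim=43,
                               policy_dim=512, n_heads=8, n_layers=4,
                               n_queries=64, action_dim=43)

B = 2
v_feats = [torch.randn(B, 196, 768)]   # one (or more) vision encoder layers
t_feats = [torch.randn(B, 32, 896)]    # one (or more) text encoder layers
state = torch.randn(B, 43)

actions = policy(v_feats, t_feats, state)
print(actions.shape)  # torch.Size([2, 43])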