Add BridgeAttention adapter + config + policy definition
- adapter.pt +3 -0
- adapter.safetensors +3 -0
- config.json +27 -0
- policy_definition.py +37 -0
adapter.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4c15b7174ab82dc4590b3e0a35a63c2d833353496a7239ef54feaada727e00e1
+size 55237642
adapter.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3749947322390daf4ca9093633518845885f2bd289bea321ff4ab8c691fa47af
+size 55220344
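Both adapter files are Git LFS pointers; the ~55 MB payloads live in LFS, not in this diff. A minimal sketch for inspecting the safetensors payload after fetching the LFS objects (git lfs pull); the assumption that the file stores a flat tensor dict is mine, not stated by this commit:

# Sketch: list tensor names and shapes in the LFS-tracked adapter file.
from safetensors import safe_open

with safe_open("adapter.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))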
config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "vision_model_id": "google/siglip-base-patch16-224",
+  "text_model_id": "Qwen/Qwen2.5-0.5B-Instruct",
+  "image_size": 224,
+  "num_action_queries": 64,
+  "policy_dim": 512,
+  "policy_layers": 4,
+  "n_heads": 8,
+  "dropout": 0.1,
+  "action_dim": 43,
+  "state_dim": 43,
+  "per_device_batch": 16,
+  "max_steps": 400,
+  "lr": 0.0003,
+  "warmup_ratio": 0.03,
+  "weight_decay": 0.01,
+  "log_every": 20,
+  "hf_repo": "nvidia/PhysicalAI-Robotics-GR00T-Teleop-G1",
+  "split": "train[:2%]",
+  "num_workers": 4,
+  "enable_vla_download": false,
+  "vla_patterns": [
+    "g1-pick-apple/**",
+    "g1-pick-pear/**"
+  ],
+  "local_data_dir": "/kaggle/working/gr00t_g1_subset"
+}
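For reference, a sketch of how this config maps onto the policy constructor below; the key-to-argument mapping and the use of transformers AutoConfig to recover the encoder hidden sizes are my assumptions, not part of this commit:

# Sketch: build BridgeAttentionPolicy from config.json.
import json
from transformers import AutoConfig
from policy_definition import BridgeAttentionPolicy

with open("config.json") as f:
    cfg = json.load(f)

# Hidden sizes come from the referenced checkpoints (768 for SigLIP-base, 896 for Qwen2.5-0.5B).
v_hidden = AutoConfig.from_pretrained(cfg["vision_model_id"]).vision_config.hidden_size
t_hidden = AutoConfig.from_pretrained(cfg["text_model_id"]).hidden_size

policy = BridgeAttentionPolicy(
    v_hidden=v_hidden,
    t_hidden=t_hidden,
    state_dim=cfg["state_dim"],
    policy_dim=cfg["policy_dim"],
    n_heads=cfg["n_heads"],
    n_layers=cfg["policy_layers"],
    n_queries=cfg["num_action_queries"],
    action_dim=cfg["action_dim"],
    dropout=cfg["dropout"],
)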
policy_definition.py ADDED
@@ -0,0 +1,37 @@
+
+import math, torch
+import torch.nn as nn
+from einops import repeat
+
+class BridgeAttentionPolicy(nn.Module):
+    def __init__(self, v_hidden, t_hidden, state_dim, policy_dim, n_heads, n_layers, n_queries, action_dim, dropout=0.1):
+        super().__init__()
+        self.n_queries = n_queries
+        self.query = nn.Parameter(torch.randn(n_queries, policy_dim) / math.sqrt(policy_dim))
+        self.v_proj = nn.Linear(v_hidden, policy_dim)
+        self.t_proj = nn.Linear(t_hidden, policy_dim)
+        self.s_proj = nn.Linear(state_dim, policy_dim)
+        self.alpha_v = nn.Parameter(torch.tensor(0.7))
+        self.alpha_t = nn.Parameter(torch.tensor(0.7))
+        self.alpha_s = nn.Parameter(torch.tensor(0.7))
+        enc = nn.TransformerEncoderLayer(d_model=policy_dim, nhead=n_heads, dim_feedforward=policy_dim*4,
+                                         dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
+        self.blocks = nn.TransformerEncoder(enc, num_layers=n_layers)
+        self.norm = nn.LayerNorm(policy_dim)
+        self.head = nn.Sequential(nn.Linear(policy_dim, policy_dim), nn.GELU(), nn.Linear(policy_dim, action_dim))
+
+    def forward(self, v_feats_layers, t_feats_layers, state_vec):
+        B = state_vec.size(0)
+        v_cat = torch.cat(v_feats_layers, dim=1) if v_feats_layers else None
+        t_cat = torch.cat(t_feats_layers, dim=1)
+        s_tok = self.s_proj(state_vec).unsqueeze(1)
+        toks = [s_tok]
+        if v_cat is not None:
+            toks.append(self.v_proj(v_cat) * torch.sigmoid(self.alpha_v))
+        toks.append(self.t_proj(t_cat) * torch.sigmoid(self.alpha_t))
+        ctx = torch.cat(toks, dim=1)
+        q = repeat(self.query, 'Q D -> B Q D', B=B)
+        tokens = torch.cat([q, ctx], dim=1)
+        tokens = self.blocks(tokens)
+        pooled = self.norm(tokens[:, :self.n_queries].mean(dim=1))
+        return self.head(pooled)
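A quick smoke test of the module with dummy inputs; the batch size, sequence lengths, and the 768/896 feature widths are illustrative assumptions matching the checkpoints named in config.json, not values fixed by this commit:

# Sketch: forward pass with random features; output should be (batch, action_dim).
import torch
from policy_definition import BridgeAttentionPolicy

policy = BridgeAttentionPolicy(
    v_hidden=768, t_hidden=896, state_dim=43, policy_dim=512,
    n_heads=8, n_layers=4, n_queries=64, action_dim=43, dropout=0.1,
)

B = 2
v_feats = [torch.randn(B, 196, 768)]   # one layer of SigLIP patch features (14x14 patches at 224px)
t_feats = [torch.randn(B, 32, 896)]    # one layer of Qwen2.5 token features
state = torch.randn(B, 43)

with torch.no_grad():
    actions = policy(v_feats, t_feats, state)
print(actions.shape)  # torch.Size([2, 43])

The learnable query tokens attend jointly over the projected state, vision, and text tokens through the shared encoder, and the action head reads the mean over the query positions.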