AbstractPhil committed on
Commit e8d2a79 · verified · 1 Parent(s): 5cf7984

Create robust_velocity_adapter.py

Files changed (1):
  1. robust_velocity_adapter.py +149 -0
robust_velocity_adapter.py ADDED
@@ -0,0 +1,149 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class RobustVelocityAdapter(nn.Module):
    """
    Fixed version: manual multi-head cross-attention emits [B, heads, Q, K] scores
    so that _add_rel_pos_bias can unpack them correctly.
    """
    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        hidden_dim: int = 1024,
        out_tokens: int = 64,      # now aligned with your T5 finetune
        self_attn_layers: int = 2,
        cross_heads: int = 8,
        max_rel_pos: int = 128,
    ):
        super().__init__()
        self.out_tokens = out_tokens
        self.cross_heads = cross_heads
        self.head_dim = t5_dim // cross_heads
        self.max_rel_pos = max_rel_pos

        # 1) Self-attention stack
        self.self_attn = nn.ModuleList()
        self.self_norm = nn.ModuleList()
        for _ in range(self_attn_layers):
            self.self_attn.append(nn.MultiheadAttention(t5_dim, cross_heads, batch_first=True))
            self.self_norm.append(nn.LayerNorm(t5_dim))

        # 2) Residual blocks
        def resblock():
            return nn.Sequential(
                nn.LayerNorm(t5_dim),
                nn.Linear(t5_dim, t5_dim),
                nn.GELU(),
                nn.Linear(t5_dim, t5_dim),
            )
        self.res1 = resblock()
        self.res2 = resblock()

        # 3) Learned queries for cross-attn
        self.query_pos = nn.Parameter(torch.randn(out_tokens, t5_dim))
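        # (these learned queries let the cross-attention below pool a variable-length
        #  T5 sequence into a fixed set of out_tokens output tokens)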

        # 4) Projection heads
        self.anchor_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.var_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(t5_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, clip_dim), nn.Sigmoid()
        )

        # 5) Relative-position bias table
        self.rel_bias = nn.Parameter(torch.zeros(2 * max_rel_pos - 1, cross_heads))

        # 6) Norm after cross-attn
        self.cross_norm = nn.LayerNorm(t5_dim)

    def _add_rel_pos_bias(self, attn_scores: torch.Tensor) -> torch.Tensor:
        """
        attn_scores: [B, heads, Q, K]
        returns: attn_scores + bias, where bias is [B, heads, Q, K]
        """
        B, H, Q, K = attn_scores.shape
        device = attn_scores.device

        # 1) Query & key position indices
        idx_q = torch.arange(Q, device=device)  # [Q]
        idx_k = torch.arange(K, device=device)  # [K]

        # 2) Compute relative distances for every (q, k) pair:
        #    rel[i, j] = idx_q[i] - idx_k[j]
        rel = idx_q.unsqueeze(1) - idx_k.unsqueeze(0)  # [Q, K]

        # 3) Clamp & shift into bias-table range [0, 2*max_rel - 2]
        max_rel = self.max_rel_pos
        rel = rel.clamp(-max_rel + 1, max_rel - 1) + (max_rel - 1)

        # 4) Look up per-head biases
        #    self.rel_bias has shape [2*max_rel - 1, H]
        bias = self.rel_bias[rel]        # [Q, K, H]
        bias = bias.permute(2, 0, 1)     # [H, Q, K]

        # 5) Broadcast to [B, H, Q, K] and add
        bias = bias.unsqueeze(0).expand(B, -1, -1, -1)
        return attn_scores + bias

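    # A worked example of the relative-position lookup in _add_rel_pos_bias
    # (illustrative values): with max_rel_pos = 3, Q = 2, K = 3 the raw offsets are
    #     rel = [[0, -1, -2],
    #            [1,  0, -1]]
    # and after clamp(-2, 2) + 2 they become
    #     rel = [[2, 1, 0],
    #            [3, 2, 1]]
    # where each entry indexes a row of the [2*3-1 = 5, heads] bias table.
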
    def forward(self, t5_seq: torch.Tensor):
        """
        t5_seq: [B, L, t5_dim]
        returns:
            anchor: [B, out_tokens, clip_dim]
            delta:  [B, out_tokens, clip_dim]
            sigma:  [B, out_tokens, clip_dim]
        """
        x = t5_seq
        B, L, D = x.shape

        # 1) Self-attention + residual
        for attn, norm in zip(self.self_attn, self.self_norm):
            res, _ = attn(x, x, x)
            x = norm(x + res)

        # 2) Residual blocks
        x = x + self.res1(x)
        x = x + self.res2(x)

        # 3) Prepare queries & split heads
        queries = self.query_pos.unsqueeze(0).expand(B, -1, -1)  # [B, Q, D]
        # reshape into heads
        q = queries.view(B, self.out_tokens, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        k = x.view(B, L, self.cross_heads, self.head_dim).permute(0, 2, 1, 3)
        v = k

        # 4) Scaled dot-product to get [B, heads, Q, K]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = self._add_rel_pos_bias(scores)
        probs = F.softmax(scores, dim=-1)  # [B, H, Q, K]

        # 5) Attend & merge heads → [B, Q, D]
        ctx = probs @ v  # [B, H, Q, head_dim]
        ctx = ctx.permute(0, 2, 1, 3).reshape(B, self.out_tokens, D)
        ctx = self.cross_norm(ctx)

        # 6) Project to anchor, delta_mean, delta_logvar, gate
        anchor = self.anchor_proj(ctx)
        delta_mean = self.delta_proj(ctx)
        delta_logvar = self.var_proj(ctx)
        gate = self.gate_proj(ctx)

        # 7) Compute sigma & gated delta
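        # exp(0.5 * delta_logvar) turns the predicted log-variance into a standard
        # deviation, and the sigmoid gate scales delta_mean element-wise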
        sigma = torch.exp(0.5 * delta_logvar)
        delta = delta_mean * gate

        return anchor, delta, sigma
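
A minimal usage sketch (assumptions: the batch size of 2 and T5 sequence length of 77 are arbitrary illustrative values, and the module file is importable from the working directory):

    import torch
    from robust_velocity_adapter import RobustVelocityAdapter

    adapter = RobustVelocityAdapter(t5_dim=512, clip_dim=768, out_tokens=64)
    t5_seq = torch.randn(2, 77, 512)               # [B, L, t5_dim] dummy T5 hidden states
    anchor, delta, sigma = adapter(t5_seq)
    print(anchor.shape, delta.shape, sigma.shape)  # each: torch.Size([2, 64, 768])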