Update starvector_arch.py
starvector_arch.py  CHANGED  +5 −6
@@ -159,23 +159,22 @@ class StarVectorForCausalLM(PreTrainedModel):
         if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
             self.model.svg_transformer.gradient_checkpointing_enable()
 
-    def forward(self,
+    def forward(self, vision_embeds, input_ids, num_generations, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Wrapper for the forward pass of the model.
         """
-        device =
+        device = vision_embeds.device
 
         completion_embeds = self.model._get_embeddings(input_ids)
-
-        attention_mask = torch.ones_like(
+        vision_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
+        attention_mask = torch.ones_like(vision_embeds[:, :, 0]).to(device)
 
         transformer_outputs = self.model.svg_transformer.transformer.transformer(
-            inputs_embeds=
+            inputs_embeds=vision_embeds,
             attention_mask=attention_mask,
         )
         hidden_states = transformer_outputs[0]
 
-        # If GRPO requested only the last tokens, slice accordingly.
         if num_logits_to_keep > 0:
             lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         else: