joanrodai committed on
Commit
5e6d1bb
·
verified ·
1 Parent(s): 27cff09

Update starvector_arch.py

Browse files
Files changed (1) hide show
  1. starvector_arch.py +5 -6
starvector_arch.py CHANGED
@@ -159,23 +159,22 @@ class StarVectorForCausalLM(PreTrainedModel):
159
  if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
160
  self.model.svg_transformer.gradient_checkpointing_enable()
161
 
162
- def forward(self, inputs_embeds, input_ids, num_generations, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
163
  r"""
164
  Wrapper for the forward pass of the model.
165
  """
166
- device = inputs_embeds.device
167
 
168
  completion_embeds = self.model._get_embeddings(input_ids)
169
- inputs_embeds = torch.cat([inputs_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
170
- attention_mask = torch.ones_like(inputs_embeds[:, :, 0]).to(device)
171
 
172
  transformer_outputs = self.model.svg_transformer.transformer.transformer(
173
- inputs_embeds=inputs_embeds,
174
  attention_mask=attention_mask,
175
  )
176
  hidden_states = transformer_outputs[0]
177
 
178
- # If GRPO requested only the last tokens, slice accordingly.
179
  if num_logits_to_keep > 0:
180
  lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
181
  else:
 
159
  if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
160
  self.model.svg_transformer.gradient_checkpointing_enable()
161
 
162
+ def forward(self, vision_embeds, input_ids, num_generations, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
163
  r"""
164
  Wrapper for the forward pass of the model.
165
  """
166
+ device = vision_embeds.device
167
 
168
  completion_embeds = self.model._get_embeddings(input_ids)
169
+ vision_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
170
+ attention_mask = torch.ones_like(vision_embeds[:, :, 0]).to(device)
171
 
172
  transformer_outputs = self.model.svg_transformer.transformer.transformer(
173
+ inputs_embeds=vision_embeds,
174
  attention_mask=attention_mask,
175
  )
176
  hidden_states = transformer_outputs[0]
177
 
 
178
  if num_logits_to_keep > 0:
179
  lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
180
  else: