momergul committed on
Commit 79df9ca · verified · 1 Parent(s): 946471f

Upload modeling_flamingo.py with huggingface_hub

Files changed (1):
  1. modeling_flamingo.py +133 -0
modeling_flamingo.py CHANGED
@@ -560,3 +560,136 @@ class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+class FlamingoModel(modeling_opt.OPTForCausalLM):
+    _keys_to_ignore_on_load_missing = [
+        r"lm_head.weight",
+    ]
+    config_class = FlamingoConfig
+
+    def __init__(self, config):
+        OPTPreTrainedModel.__init__(self, config)
+        config = setup_default_flamingo_configs(config)
+        self.model = OPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        self.model.decoder.img_encoder = None
+        self.loss_fct = CrossEntropyLoss()
+        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
+        self.setup_vis_encoder(dino_model)
+
+    def setup_vis_encoder(self, img_encoder):
+        self.model.decoder.img_encoder = img_encoder
+        freeze_all_layers_(img_encoder)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            *args, **kwargs)
+
+        return outputs
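
For context, a minimal usage sketch of the new class (not part of the commit). Everything specific here is an assumption: the repo id is a placeholder for whichever checkpoint ships this `modeling_flamingo.py`, loading through `AutoModelForCausalLM` presumes the checkpoint's config maps that auto class to `FlamingoModel` via `trust_remote_code`, and note that this `forward` returns the raw decoder outputs rather than `lm_head` logits.

```python
# Usage sketch only; assumptions are flagged inline.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# OPT checkpoints use the GPT-2 tokenizer, as in the docstring example above.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Placeholder repo id: any checkpoint that ships this modeling_flamingo.py.
# trust_remote_code=True is needed because FlamingoModel is custom code, and
# this assumes the config's auto_map points AutoModelForCausalLM at it.
model = AutoModelForCausalLM.from_pretrained(
    "<user>/<flamingo-checkpoint>", trust_remote_code=True
)

inputs = tokenizer("A photo of a dog playing in the snow", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, return_dict=True)

# forward returns the decoder outputs directly (no lm_head applied),
# so inspect hidden states rather than .logits.
print(outputs.last_hidden_state.shape)
```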