# coding=utf-8 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch DeepSeek model.""" import math import os import re import warnings from typing import List, Optional, Tuple, Union import requests import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from transformers.utils.import_utils 
import is_torch_fx_available from .configuration_deepseek import DeepseekV3Config import torch.distributed as dist import numpy as np if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): if not is_torch_greater_or_equal_than_1_13: import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV3Config" def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) ) return ( indices, cu_seqlens, max_seqlen_in_batch, ) class DeepseekV3RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ DeepseekV3RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm) class DeepseekV3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 
2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) self.max_seq_len_cached = None def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3 class DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): """DeepseekV3RotaryEmbedding extended with linear scaling.""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) 
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3 class DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding): """DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling.""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange( self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype ) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class DeepseekV3MLP(nn.Module): def 
__init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size self.intermediate_size = ( config.intermediate_size if intermediate_size is None else intermediate_size ) self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class MoEGate(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.scoring_func = config.scoring_func self.seq_aux = config.seq_aux self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob self.gating_dim = config.hidden_size self.weight = nn.Parameter( torch.empty((self.n_routed_experts, self.gating_dim)) ) if self.topk_method == "noaux_tc": self.e_score_correction_bias = nn.Parameter( torch.empty((self.n_routed_experts)) ) self.reset_parameters() def reset_parameters(self) -> None: import torch.nn.init as init init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape hidden_states = hidden_states.view(-1, h) logits = F.linear( hidden_states.type(torch.float32), self.weight.type(torch.float32), None ) if self.scoring_func == "sigmoid": scores = logits.sigmoid() else: raise NotImplementedError( f"insupportable scoring function for MoE gating: {self.scoring_func}" ) if self.topk_method == "noaux_tc": assert not self.training 
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) group_scores = ( scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) ) group_idx = torch.topk( group_scores, k=self.topk_group, dim=-1, sorted=False )[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) .expand( bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group ) .reshape(bsz * seq_len, -1) ) tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) _, topk_idx = torch.topk( tmp_scores, k=self.top_k, dim=-1, sorted=False ) topk_weight = scores.gather(1, topk_idx) else: raise NotImplementedError( f"insupportable TopK function for MoE gating: {self.topk_method}" ) if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator topk_weight = topk_weight * self.routed_scaling_factor return topk_idx, topk_weight class DeepseekV3MoE(nn.Module): """ A mixed expert module containing shared experts. 
""" def __init__(self, config): super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok if hasattr(config, "ep_size") and config.ep_size > 1: assert config.ep_size == dist.get_world_size() self.ep_size = config.ep_size self.experts_per_rank = config.n_routed_experts // config.ep_size self.ep_rank = dist.get_rank() self.experts = nn.ModuleList( [ ( DeepseekV3MLP( config, intermediate_size=config.moe_intermediate_size ) if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank else None ) for i in range(config.n_routed_experts) ] ) else: self.ep_size = 1 self.experts_per_rank = config.n_routed_experts self.ep_rank = 0 self.experts = nn.ModuleList( [ DeepseekV3MLP( config, intermediate_size=config.moe_intermediate_size ) for i in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV3MLP( config=config, intermediate_size=intermediate_size ) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape topk_idx, topk_weight = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) flat_topk_idx = topk_idx.view(-1) if not self.training: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) return y @torch.no_grad() def moe_infer(self, x, topk_ids, topk_weight): cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) cnts.scatter_(1, topk_ids, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_ids.view(-1).argsort() sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens_shape = sorted_tokens.shape if self.ep_size > 1: tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) tokens_per_expert_group = tokens_per_expert.new_empty( 
tokens_per_expert.shape[0] ) dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) output_splits = ( tokens_per_expert_group.view(self.ep_size, -1) .sum(1) .cpu() .numpy() .tolist() ) gathered_tokens = sorted_tokens.new_empty( tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] ) input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() dist.all_to_all( list(gathered_tokens.split(output_splits)), list(sorted_tokens.split(input_split_sizes)), ) tokens_per_expert_post_gather = tokens_per_expert_group.view( self.ep_size, self.experts_per_rank ).sum(dim=0) gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) s = 0 for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): gatherd_idxs[s : s + k] = i % self.experts_per_rank s += k gatherd_idxs = gatherd_idxs.argsort() sorted_tokens = gathered_tokens[gatherd_idxs] tokens_per_expert = tokens_per_expert_post_gather tokens_per_expert = tokens_per_expert.cpu().numpy() outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): end_idx = start_idx + num_tokens if num_tokens == 0: continue expert = self.experts[i + self.ep_rank * self.experts_per_rank] tokens_for_this_expert = sorted_tokens[start_idx:end_idx] expert_out = expert(tokens_for_this_expert) outputs.append(expert_out) start_idx = end_idx outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) if self.ep_size > 1: new_x = torch.empty_like(outs) new_x[gatherd_idxs] = outs gathered_tokens = new_x.new_empty(*sorted_tokens_shape) dist.all_to_all( list(gathered_tokens.split(input_split_sizes)), list(new_x.split(output_splits)), ) outs = gathered_tokens new_x = torch.empty_like(outs) new_x[idxs] = outs final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype) ) return final_out def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ repeat_kv used by grouped query 
attention (MQA/GQA). """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim self.is_causal = True if self.q_lora_rank is None: self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.q_head_dim, bias=False ) else: self.q_a_proj = nn.Linear( self.hidden_size, config.q_lora_rank, bias=config.attention_bias ) self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) self.q_b_proj = nn.Linear( config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False ) self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias, ) self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) self.kv_b_proj = nn.Linear( config.kv_lora_rank, 
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias, ) self._init_rope() self.softmax_scale = self.q_head_dim ** (-0.5) def _init_rope(self): # Minimal demonstration, ignoring dynamic/linear yarn scaling for brevity self.rotary_emb = DeepseekV3RotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.config.max_position_embeddings, base=self.config.rope_theta, ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return ( tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) .transpose(1, 2) .contiguous() ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] if past_key_value is not None and self.layer_idx is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, 
sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None and self.layer_idx is not None: cache_kwargs = {"sin": sin, "cos": cos} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) attn_weights = ( torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale ) if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 class DeepseekV3FlashAttention2(DeepseekV3Attention): """ DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. 
""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # DeepseekV3FlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" ) # overwrite attention_mask with padding_mask attention_mask = kwargs.pop("padding_mask") output_attentions = False bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split( kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 ) kv_seq_len = value_states.shape[-2] kv_seq_len = value_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to 
RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (DeepseekV3RMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype elif torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = ( self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype ) logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." 
) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, softmax_scale=self.softmax_scale, ) if self.q_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape( bsz, q_len, self.num_heads * self.v_head_dim ).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def _flash_attention_forward( self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. Args: query_states (`torch.Tensor`): Input query states to be passed to Flash Attention API key_states (`torch.Tensor`): Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ if not self._flash_attn_uses_top_left_mask: causal = self.is_causal else: # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__. 
causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] ( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens, ) = self._upad_input( query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens attn_output_unpad = flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, ) attn_output = pad_input( attn_output_unpad, indices_q, batch_size, query_length ) else: attn_output = flash_attn_func( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, ) return attn_output def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( batch_size + 1, dtype=torch.int32, device=query_layer.device ) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. 
attention_mask = attention_mask[:, -query_length:] query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( query_layer, attention_mask ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) ATTENTION_CLASSES = { "eager": DeepseekV3Attention, "flash_attention_2": DeepseekV3FlashAttention2, } class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( config=config, layer_idx=layer_idx ) self.mlp = ( DeepseekV3MoE(config) if ( config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0 ) else DeepseekV3MLP(config) ) self.input_layernorm = DeepseekV3RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) self.post_attention_layernorm = DeepseekV3RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = 
(hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`DeepseekV3Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """, ) class DeepseekV3PreTrainedModel(PreTrainedModel): config_class = DeepseekV3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class DeepseekV3Model(DeepseekV3PreTrainedModel): """ Transformer decoder with *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]. 
""" def __init__(self, config: DeepseekV3Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = nn.ModuleList( [ DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) ] ) self._use_flash_attention_2 = getattr(config, "_attn_implementation", "eager") == "flash_attention_2" self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): # standard forward from above... if input_ids is not None and inputs_embeds is not None: raise ValueError("Cannot specify both input_ids and inputs_embeds") if input_ids is not None: bsz, seq_len = input_ids.shape elif inputs_embeds is not None: bsz, seq_len = inputs_embeds.shape[:2] else: raise ValueError("Must provide input_ids or inputs_embeds") if use_cache is None: use_cache = self.config.use_cache # handle position_ids if needed, etc. if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if self._use_flash_attention_2: # 2d mask pass, etc. 
if attention_mask is not None and 0 not in attention_mask: attention_mask = None else: # 4d mask if normal eager attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (bsz, seq_len), inputs_embeds ) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if use_cache and len(layer_outputs) > 1: next_decoder_cache = layer_outputs[-1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: DeepseekV3Config): super().__init__(config) self.model = DeepseekV3Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, 
past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: # Normal forward: decode, get logits, optionally compute loss outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states).float() loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: out_tuple = (logits,) + outputs[1:] return ((loss,) + out_tuple) if loss is not None else out_tuple return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) ############################################################################## # NEW CLASS: DeepseekV3ForCausalLMWithAds # Overridden forward() to call the Ahilab API and inject ads. 
##############################################################################
from transformers import AutoTokenizer


# `DeepseekV3ForCausalLM` comes first so that `DeepseekV3ForCausalLMWithAds` below
# subclasses this (final) definition. The three classes in this section each
# appeared twice in the file; only one definition of each is kept.
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
    """Decoder with a bias-free linear language-modeling head and generation helpers."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DeepseekV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for one generation step.

        Trims `input_ids` to the tokens not yet covered by the cache, crops the attention mask
        when a bounded cache would overflow, and derives `position_ids` from the attention mask
        when none were passed in `kwargs`.

        Returns:
            A dict with `input_ids` or `inputs_embeds`, plus `position_ids`, `past_key_values`,
            `use_cache` and `attention_mask`.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 by `beam_idx` (per-layer tuples of tensors)."""
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )


class DeepseekV3ForCausalLMWithAds(DeepseekV3ForCausalLM):
    """
    Subclass that overrides the forward pass to call the Ahilab API directly and append the
    returned ad text to the prompt tokens.

    The API endpoint and key are read from the environment variables `AHILAB_BASE_API_URL`
    and `AHILAB_API_KEY_PROD`; when either is empty no request is made and the model behaves
    like [`DeepseekV3ForCausalLM`].

    NOTE(review): the decoded prompt text is sent to the external service on each injecting
    forward pass. Confirm this is acceptable for the deployment's privacy requirements.
    """

    def __init__(self, config, ip_country: str = "", **kwargs):
        """
        Args:
            config: Model configuration. `tokenizer_name_or_path`, if present and truthy,
                names the tokenizer to load.
            ip_country: Country value sent as `country` in the ad request payload.
            **kwargs: Accepted for call-site compatibility; not used.
        """
        super().__init__(config)
        # Country value forwarded to the ad service with every request.
        self.ip_country = ip_country

        # Credentials for the ad service; empty strings disable ad fetching.
        self.ahilab_api_url = os.getenv("AHILAB_BASE_API_URL", "")
        self.ahilab_api_key = os.getenv("AHILAB_API_KEY_PROD", "")

        # A tokenizer is needed to decode the prompt and to encode the ad suffix.
        # Falls back to a fixed name if config doesn't specify one.
        tokenizer_name = (
            getattr(config, "tokenizer_name_or_path", None)
            or "DeepSeekAI/deepseek-v3-r1-tokenizer"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def _extract_last_user_text(self, input_ids: torch.LongTensor, max_len=500) -> str:
        """
        Decode row 0 of `input_ids` (special tokens skipped) and keep the first `max_len`
        characters.

        Returns an empty string when `input_ids` is not 2-dimensional.
        """
        if input_ids.dim() != 2:
            return ""
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return text[:max_len]

    def _fetch_ahilab_ad(self, context_text: str) -> str:
        """
        POST `{ "query": context_text, "country": self.ip_country }` to the Ahilab API and
        return the response body with double quotes and backslashes removed.

        Returns an empty string when credentials are missing, the response is not OK, or the
        request fails; failures are logged, never raised.
        """
        if not self.ahilab_api_url or not self.ahilab_api_key:
            # No credentials -> skip
            return ""

        payload = {"query": context_text, "country": self.ip_country}
        headers = {
            "Authorization": f"Bearer {self.ahilab_api_key}",
            "Content-Type": "application/json",
        }
        try:
            resp = requests.post(
                self.ahilab_api_url, json=payload, headers=headers, timeout=5
            )
            if resp.ok:
                return re.sub(r'["\\]', "", resp.text).strip()
        except Exception as e:
            # Ad fetching is best-effort: a failing ad service must not break inference.
            logger.warning("[Ahilab Ad Error] %s", e)
        return ""

    def _should_inject_ad(self, input_ids, labels, past_key_values) -> bool:
        """Return True only when appending ad tokens cannot corrupt the rest of the call."""
        # Only a single 2-D prompt row is supported; larger batches would need per-row padding.
        if input_ids is None or input_ids.dim() != 2 or input_ids.shape[0] != 1:
            return False
        # Extra tokens would no longer line up with the caller's labels.
        if labels is not None:
            return False
        # A populated cache means the prompt was already processed on an earlier step.
        if past_key_values:
            return False
        return True

    def _inject_ad(self, input_ids, attention_mask, position_ids):
        """
        Fetch an ad for the decoded prompt and append its tokens to `input_ids`.

        The original prompt tokens are kept untouched; only the ad suffix is tokenized.
        Returns the inputs unchanged when no ad is available.
        """
        context_text = self._extract_last_user_text(input_ids, max_len=500)
        ad_text = self._fetch_ahilab_ad(context_text)
        if not ad_text:
            return input_ids, attention_mask, position_ids

        suffix = f"\n\n[Sponsor Ad]\n{ad_text}\n"
        suffix_ids = self.tokenizer(
            suffix, return_tensors="pt", add_special_tokens=False
        )["input_ids"].to(device=input_ids.device, dtype=input_ids.dtype)

        new_input_ids = torch.cat([input_ids, suffix_ids], dim=-1)
        if attention_mask is not None:
            suffix_mask = attention_mask.new_ones(
                (attention_mask.shape[0], suffix_ids.shape[1])
            )
            attention_mask = torch.cat([attention_mask, suffix_mask], dim=-1)

        # Caller-supplied positions no longer match the longer sequence; drop them
        # so they are not applied to the wrong length.
        return new_input_ids, attention_mask, None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.LongTensor = None,
        position_ids: torch.LongTensor = None,
        labels: torch.LongTensor = None,
        # custom flag: pass True to run the plain forward without ad injection
        ads_injected: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass with optional ad injection.

        When `ads_injected` is False and injection is safe (single prompt row, no `labels`,
        no populated `past_key_values`), the prompt is decoded, an ad is requested, and the
        ad tokens are appended before delegating to the parent forward. Otherwise the inputs
        are passed through unchanged. `ads_injected` is consumed here and never forwarded,
        because the parent forward does not accept it.

        NOTE(review): when an ad is appended the returned logits cover more positions than the
        caller's `input_ids`; callers that track sequence length or reuse the returned cache
        should be checked against this.
        """
        if not ads_injected and self._should_inject_ad(
            input_ids, labels, kwargs.get("past_key_values")
        ):
            input_ids, attention_mask, position_ids = self._inject_ad(
                input_ids, attention_mask, position_ids
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            **kwargs,
        )


@add_start_docstrings(
    """
    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).

    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    # DeepseekV3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = DeepseekV3Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError(
                "Cannot handle batch sizes > 1 if no padding token is defined."
            )

        if self.config.pad_token_id is None or input_ids is None:
            # Without a pad token (or without token ids) the last position is used.
            sequence_lengths = -1
        else:
            # Index of the first pad token minus one. A row with no pad token yields
            # argmax == 0, i.e. -1, which also selects the last position.
            sequence_lengths = (
                torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            ).to(logits.device)

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once from the label count and dtype.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )