|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
PyTorch RWKV079Qwen3 model. |
|
base code from SmerkyG @ recursal.ai, featherless.ai |
|
hxa079 implementation RWKV079 + NoPE Hybrid Attention |
|
|
|
""" |
|
|
|
import math |
|
import inspect |
|
from typing import List, Optional, Tuple, Union, Dict, Any |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin |
|
from transformers.generation import GenerationMixin |
|
from transformers.integrations import use_kernel_forward_from_hub |
|
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
from transformers.modeling_layers import ( |
|
GenericForQuestionAnswering, |
|
GenericForSequenceClassification, |
|
GenericForTokenClassification, |
|
GradientCheckpointingLayer, |
|
) |
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
from transformers.utils.generic import check_model_inputs |
|
|
|
from .configuration_rwkv079qwen3 import RWKV079Qwen3Config |
|
|
|
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3MLP, Qwen3RMSNorm, Qwen3Attention

logger = logging.get_logger(__name__)
|
|
|
class RWKV079State(): |
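    """
    Cache object for the hybrid model. Each layer slot holds either an RWKV
    (kv_state, shift_state) pair or, for softmax-attention layers, the growing
    key/value tensors, exposed through a Cache-like interface.
    """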
|
def __init__(self) -> None: |
|
|
|
self._seen_tokens = 0 |
|
self.layer_kv_states: List[torch.Tensor] = [] |
|
self.layer_shift_states: List[torch.Tensor] = [] |
|
self.cumulative_scores: List[torch.Tensor] = [] |
|
self.sin: List[torch.Tensor] = [] |
|
self.cos: List[torch.Tensor] = [] |
|
|
|
def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the |
|
sequence length. |
|
""" |
|
if layer_idx < len(self): |
|
return (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]) |
|
else: |
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") |
|
|
|
def __iter__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over |
|
keys and values |
|
""" |
|
for layer_idx in range(len(self)): |
|
yield (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]) |
|
|
|
def __len__(self): |
|
""" |
|
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds |
|
to the number of layers in the model. |
|
""" |
|
return len(self.layer_kv_states) |
|
|
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
|
"""Given the sequence length of the new inputs, returns the usable length of the cache.""" |
|
|
|
return new_seq_length |
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor): |
|
"""Reorders the cache for beam search, given the selected beam indices.""" |
|
raise NotImplementedError('Cannot reorder Linear Attention state') |
|
|
|
def get_seq_length(self, layer_idx: int = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
return self._seen_tokens |
|
|
|
def get_max_cache_shape(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" |
|
return None |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
""" |
|
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length. |
|
""" |
|
return None |
|
|
|
def crop(self, max_length: int): |
|
|
|
return |
|
|
|
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: |
|
"""Return the length and offset of the cache, used to generate the mask""" |
|
kv_offset = 0 |
|
query_length = cache_position.shape[0] |
|
past_seen_tokens = self.get_seq_length() |
|
kv_length = query_length + past_seen_tokens |
|
return kv_length, kv_offset |
|
|
|
@property |
|
def is_compileable(self) -> bool: |
|
"""Return whether the cache is compileable""" |
|
return True |
|
|
|
@torch.no_grad |
|
def update( |
|
self, |
|
kv_state: torch.Tensor, |
|
shift_state: torch.Tensor, |
|
layer_idx: int, |
|
token_count: int = 0, |
|
is_attention_layer: bool = True, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
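        """
        Store this layer's state in the cache and return the merged per-layer entry.

        For softmax-attention layers the new key/value tensors are appended along the
        sequence dimension; for RWKV layers the fixed-size (kv_state, shift_state)
        pair is overwritten in place.
        """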
|
|
|
if layer_idx == 0: |
|
if is_attention_layer: |
|
token_count = kv_state.size(-2) |
|
self._seen_tokens += token_count |
|
|
|
|
|
|
|
|
|
if kv_state is not None: |
|
|
|
if layer_idx >= len(self.layer_kv_states): |
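                # pad any skipped layers with empty / zero placeholders so indices stay aligned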
|
for _ in range(len(self.layer_kv_states), layer_idx): |
|
if is_attention_layer: |
|
self.layer_kv_states.append(torch.tensor([], dtype=kv_state.dtype, device=kv_state.device)) |
|
self.layer_shift_states.append(torch.tensor([], dtype=shift_state.dtype, device=shift_state.device)) |
|
else: |
|
self.layer_kv_states.append(torch.zeros_like(kv_state).requires_grad_(False)) |
|
self.layer_shift_states.append(torch.zeros_like(shift_state).requires_grad_(False)) |
|
self.layer_kv_states.append(kv_state) |
|
self.layer_shift_states.append(shift_state) |
|
else: |
|
if is_attention_layer: |
|
self.layer_kv_states[layer_idx] = torch.cat([self.layer_kv_states[layer_idx], kv_state], dim=-2) |
|
self.layer_shift_states[layer_idx] = torch.cat([self.layer_shift_states[layer_idx], shift_state], dim=-2) |
|
else: |
|
self.layer_kv_states[layer_idx].copy_(kv_state) |
|
self.layer_shift_states[layer_idx].copy_(shift_state) |
|
|
|
return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx] |
|
|
|
try: |
|
from fla.ops.rwkv7.chunk import chunk_rwkv7 |
|
from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7 |
|
except ImportError: |
|
print("Required module is not installed. Please install it using the following commands:") |
|
print("pip install --no-use-pep517 flash-linear-attention") |
|
print("Additionally, ensure you have at least version 2.2.0 of Triton installed:") |
|
print("pip install triton>=2.2.0") |
|
|
|
|
|
|
|
|
|
def is_layer_attention(config, layer_id): |
|
return layer_id in config.transformer_layers |
|
|
|
def repeat_kv_rwkv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
Repeat KV heads along the head dimension (GQA). |
|
Input: (B, T, H_kv, D) |
|
Output: (B, T, H_kv * n_rep, D) |
|
""" |
|
B, T, H_kv, D = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
|
|
hidden_states = hidden_states[:, :, :, None, :] |
|
hidden_states = hidden_states.expand(B, T, H_kv, n_rep, D) |
|
return hidden_states.reshape(B, T, H_kv * n_rep, D).contiguous() |
|
|
|
def T5RMSNorm(hidden_states, weight, variance_epsilon: float = 1e-6):
|
input_dtype = hidden_states.dtype |
|
hidden_states = hidden_states.to(torch.float32) |
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) |
|
return (weight * hidden_states).to(input_dtype) |
|
|
|
def compute_qwen3_rope_cache(seq_len, rotary_dim, device, dtype, rope_theta): |
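    """Precompute RoPE cos/sin tables of shape (1, seq_len, rotary_dim) together with the inverse frequencies."""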
|
half_dim = rotary_dim // 2 |
|
freq_seq = torch.arange(half_dim, dtype=dtype, device=device) |
|
inv_freq = 1.0 / (rope_theta ** (freq_seq / half_dim)) |
|
positions = torch.arange(seq_len, dtype=dtype, device=device) |
|
freqs = torch.einsum("i,j->ij", positions, inv_freq) |
|
emb = torch.cat([freqs, freqs], dim=-1) |
|
cos = emb.cos() |
|
sin = emb.sin() |
|
    return cos.unsqueeze(0), sin.unsqueeze(0), inv_freq


class Qwen3RotaryEmbedding(nn.Module): |
|
def __init__(self, config: RWKV079Qwen3Config, device=None): |
|
super().__init__() |
|
|
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None: |
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
|
else: |
|
self.rope_type = "default" |
|
self.max_seq_len_cached = config.max_position_embeddings |
|
self.original_max_seq_len = config.max_position_embeddings |
|
|
|
self.config = config |
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
self.original_inv_freq = self.inv_freq |
|
|
|
def _dynamic_frequency_update(self, position_ids, device): |
|
""" |
|
dynamic RoPE layers should recompute `inv_freq` in the following situations: |
|
1 - growing beyond the cached sequence length (allow scaling) |
|
2 - the current sequence length is in the original scale (avoid losing precision with small sequences) |
|
""" |
|
seq_len = torch.max(position_ids) + 1 |
|
if seq_len > self.max_seq_len_cached: |
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
self.max_seq_len_cached = seq_len |
|
|
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: |
|
|
|
|
|
self.original_inv_freq = self.original_inv_freq.to(device) |
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) |
|
self.max_seq_len_cached = self.original_max_seq_len |
|
|
|
@torch.no_grad() |
|
def forward(self, x, position_ids): |
|
if "dynamic" in self.rope_type: |
|
self._dynamic_frequency_update(position_ids, device=x.device) |
|
|
|
|
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) |
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
|
device_type = x.device.type |
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" |
|
with torch.autocast(device_type=device_type, enabled=False): |
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
cos = emb.cos() |
|
sin = emb.sin() |
|
|
|
|
|
cos = cos * self.attention_scaling |
|
sin = sin * self.attention_scaling |
|
|
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
|
def rms_norm(hidden_states, eps = 1e-6): |
|
|
|
input_dtype = hidden_states.dtype |
|
hidden_states = hidden_states.to(torch.float32) |
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + eps) |
|
return hidden_states.to(input_dtype) |
|
|
|
def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1): |
|
|
|
|
|
angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale |
|
angles = torch.outer(torch.arange(max_seqlen), angular_velocity) |
|
|
|
emb = torch.cat((angles, angles), dim=-1) |
|
return torch.stack([emb.cos(), emb.sin()], dim=0) |
|
|
|
|
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
|
"""Applies Rotary Position Embedding to the query and key tensors. |
|
|
|
Args: |
|
q (`torch.Tensor`): The query tensor. |
|
k (`torch.Tensor`): The key tensor. |
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
position_ids (`torch.Tensor`, *optional*): |
|
Deprecated and unused. |
|
unsqueeze_dim (`int`, *optional*, defaults to 1): |
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
|
Returns: |
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
|
""" |
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
return q_embed, k_embed |
|
|
|
def apply_rotary_pos_emb_single(x, cos, sin, unsqueeze_dim=1): |
|
return (x * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(x) * sin.unsqueeze(unsqueeze_dim)) |
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
def eager_attention_forward( |
|
module: nn.Module, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
value: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor], |
|
scaling: float, |
|
dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
key_states = repeat_kv(key, module.num_key_value_groups) |
|
value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
|
if attention_mask is not None: |
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
attn_weights = attn_weights + causal_mask |
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
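    # rows that are fully masked yield NaN after softmax; zero them out instead of propagating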
|
attn_weights = attn_weights.masked_fill(attn_weights.isnan(), 0) |
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
return attn_output, attn_weights |
|
|
|
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, create_mask |
|
from functools import lru_cache |
|
|
|
block_mask = None |
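# Placeholder for a cached flex-attention block mask; it is never populated in this file.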
|
|
|
|
|
|
|
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, |
|
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: |
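    """Float32 eager reference for scaled dot-product attention; NaN rows from fully masked queries are zeroed.
    The attention forwards below call nn.functional.scaled_dot_product_attention rather than this helper."""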
|
L, S = query.size(-2), key.size(-2) |
|
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale |
|
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) |
|
if is_causal: |
|
assert attn_mask is None |
|
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)

        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        attn_bias = attn_bias.to(query.dtype)
|
|
|
if attn_mask is not None: |
|
if attn_mask.dtype == torch.bool: |
|
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) |
|
else: |
|
attn_bias = attn_mask + attn_bias |
|
|
|
if enable_gqa: |
|
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) |
|
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) |
|
|
|
attn_weight = query.float() @ key.float().transpose(-2, -1) * scale_factor |
|
attn_weight += attn_bias.float() |
|
|
|
attn_weight = torch.softmax(attn_weight, dim=-1) |
|
attn_weight = attn_weight.masked_fill(attn_weight.isnan(), 0) |
|
|
|
return attn_weight @ value.float() |
|
|
|
|
|
|
|
class Qwen3AttentionNoPE_Causal(Qwen3Attention): |
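    """Qwen3 softmax attention used on the designated transformer layers, run without rotary position
    embeddings (NoPE). Note: the parent class's q_norm/k_norm are not applied in this forward path."""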
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
frozen_residual: torch.Tensor, |
|
v_first: Optional[torch.Tensor] = None, |
|
k_first: Optional[torch.Tensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: |
|
x = hidden_states |
|
|
|
B, L, D = x.size() |
|
|
|
input_shape = x.shape[:-1] |
|
hidden_shape = (*input_shape, -1, self.head_dim) |
|
|
|
q = self.q_proj(x).view(hidden_shape).transpose(1, 2) |
|
k = self.k_proj(x).view(hidden_shape).transpose(1, 2) |
|
v = self.v_proj(x).view(hidden_shape).transpose(1, 2) |
|
|
|
if past_key_values is not None: |
|
|
|
cache_kwargs = {"cache_position": cache_position} |
|
            k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs=cache_kwargs)
|
|
|
|
|
k = repeat_kv(k, self.num_key_value_groups) |
|
v = repeat_kv(v, self.num_key_value_groups) |
|
|
|
S = k.size(-2) |
|
|
|
y = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, attn_mask=attention_mask, is_causal=attention_mask is None and L==S) |
|
y = y.transpose(1,2) |
|
y = y.reshape(*input_shape, -1) |
|
y = self.o_proj(y) |
|
|
|
attn_weights = None |
|
|
|
return y, v_first, k_first |
|
|
|
|
|
class RWKV079Attention(nn.Module): |
|
def __init__(self, config, layer_idx: Optional[int] = None): |
|
super().__init__() |
|
self.config = config |
|
self.layer_idx = layer_idx |
|
C = self.hidden_size = config.hidden_size |
|
H = self.num_heads = config.num_attention_heads |
|
H_kv = config.num_key_value_heads |
|
N = self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads) |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.attention_dropout = config.attention_dropout |
|
|
|
if self.hidden_size % self.num_heads != 0: |
|
raise ValueError( |
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
self.receptance = nn.Linear( |
|
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias |
|
) |
|
self.key = nn.Linear( |
|
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
|
) |
|
self.value = nn.Linear( |
|
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias |
|
) |
|
self.output = nn.Linear( |
|
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias |
|
) |
|
|
|
|
|
|
|
|
|
lora_rank_decay = config.lora_rank_decay |
|
lora_rank_iclr = config.lora_rank_iclr |
|
lora_rank_value_residual_mix = config.lora_rank_value_residual_mix |
|
lora_rank_key_residual_mix = config.lora_rank_key_residual_mix |
|
lora_rank_gate = config.lora_rank_gate |
|
|
|
print(f"lora_rank_value_residual_mix = {lora_rank_value_residual_mix} lora_rank_key_residual_mix={lora_rank_key_residual_mix}") |
|
|
|
|
|
self.w0 = nn.Parameter(torch.empty(1,1,H*N)) |
|
self.w1 = nn.Parameter(torch.empty(C, lora_rank_decay)) |
|
self.w2 = nn.Parameter(torch.empty(lora_rank_decay, H*N)) |
|
|
|
self.a0 = nn.Parameter(torch.empty(1,1,H*N)) |
|
self.a1 = nn.Parameter(torch.empty(C, lora_rank_iclr)) |
|
self.a2 = nn.Parameter(torch.empty(lora_rank_iclr, H*N)) |
|
|
|
|
|
self.v0 = nn.Parameter(torch.empty(1,1,H_kv*N)) |
|
self.v1 = nn.Parameter(torch.empty(C, lora_rank_value_residual_mix)) |
|
self.v2 = nn.Parameter(torch.empty(lora_rank_value_residual_mix, H_kv*N)) |
|
|
|
self.k0 = nn.Parameter(torch.empty(1,1,H_kv*N)) |
|
self.k1 = nn.Parameter(torch.empty(C, lora_rank_key_residual_mix)) |
|
self.k2 = nn.Parameter(torch.empty(lora_rank_key_residual_mix, H_kv*N)) |
|
|
|
|
|
self.g1 = nn.Parameter(torch.empty(C, lora_rank_gate)) |
|
self.g2 = nn.Parameter(torch.empty(lora_rank_gate, H*N)) |
|
|
|
self.r_k = nn.Parameter(torch.empty(H,N)) |
|
|
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
frozen_residual: torch.Tensor, |
|
v_first: Optional[torch.Tensor] = None, |
|
k_first: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[RWKV079State] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
): |
|
if attention_mask is not None: |
|
assert len(attention_mask.shape) in (2, 4) |
|
|
|
output_shift_state = hidden_states[:, -1:].detach().clone() |
|
|
|
x = hidden_states |
|
|
|
B, T, C = hidden_states.shape |
|
H = self.num_heads |
|
N = self.head_dim |
|
|
|
q_len = T |
|
|
|
if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx: |
|
|
|
input_vk_state, input_shift_state = past_key_values[self.layer_idx] |
|
else: |
|
input_vk_state, input_shift_state = torch.zeros(B,H,N,N, dtype=torch.bfloat16,device=x.device), torch.zeros_like(x[:, -1:]) |
|
|
|
xr = xw = xk = xv = xa = xg = x |
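
        # NOTE: all projection inputs share the same hidden state here; no token-shift mixing is applied.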
|
|
|
r = self.receptance(xr).view(B,T,-1,N) |
|
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) -0.5 |
|
k = self.key(xk).view(B,T,-1,N) |
|
v = self.value(xv).view(B,T,-1,N) |
|
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) |
|
g = torch.sigmoid(xg @ self.g1) @ self.g2 |
|
|
|
if position_embeddings is not None: |
|
cos, sin = position_embeddings |
|
r, k = apply_rotary_pos_emb(r, k, cos, sin, unsqueeze_dim=2) |
|
|
|
|
|
|
|
if self.layer_idx == 0: |
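            # layer 0 is expected to be an RWKV layer; it publishes its value/key tensors
            # for the residual mixing performed by the deeper RWKV layers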
|
v_first = v |
|
k_first = k |
|
else: |
|
v = v + (v_first - v) * torch.sigmoid(self.v0 + (x @ self.v1) @ self.v2).view(B,T,self.num_key_value_heads,-1) |
|
k = k + (k_first - k) * torch.sigmoid(self.k0 + (x @ self.k1) @ self.k2).view(B,T,self.num_key_value_heads,-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # mask out values at padded positions so padding does not affect the recurrent state
        if attention_mask is not None:

            if attention_mask.ndim == 2:

                mask = attention_mask[:, -T:]

                v = v * mask[:, :, None, None]

            elif attention_mask.ndim == 4:

                mask = attention_mask[:, 0, -1, -T:]

                v = v * mask[:, :, None, None]
|
|
|
|
|
|
|
|
|
|
|
k = repeat_kv_rwkv(k, self.num_key_value_groups).view(B, T, -1) |
|
v = repeat_kv_rwkv(v, self.num_key_value_groups).view(B, T, -1) |
|
dropout_rate = 0.0 if not self.training else self.attention_dropout |
|
|
|
kk = F.normalize(k.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,-1) |
|
k = k * (1.0 - w + a) |
|
|
|
aa = -kk |
|
bb = kk * a |
|
w = -w.exp() |
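
        # w is negated/exponentiated into the log-decay form expected by the fla RWKV-7 kernels;
        # (aa, bb) is the rank-1 state-correction pair built from the normalized key direction kk.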
|
|
|
|
|
|
|
r_,w_,k_,v_,aa_,bb_ = [i.view(B,T,H,N) for i in [r,w,k,v,aa,bb]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x, output_vk_state = fused_recurrent_rwkv7(r_, w_, k_, v_, aa_, bb_, scale=1.0, initial_state=input_vk_state, output_final_state=True, head_first=False) |
|
|
|
|
|
|
|
|
|
x = x.view(B,T,-1) * (float(N) ** -0.5) |
|
|
|
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,-1) |
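
        # the term added above is the RWKV-7 "bonus": v weighted by the per-head score (r * k * r_k).sum(-1)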
|
|
|
|
|
|
|
|
|
x = x * g |
|
x = self.output(x) |
|
|
|
if past_key_values is not None: |
|
past_key_values.update(output_vk_state, output_shift_state, self.layer_idx, q_len, is_layer_attention(self.config, self.layer_idx)) |
|
|
|
return x, v_first, k_first |
|
|
|
class RWKV079Qwen3DecoderLayer(nn.Module): |
|
def __init__(self, config: RWKV079Qwen3Config, layer_idx: int): |
|
nn.Module.__init__(self) |
|
self.hidden_size = config.hidden_size |
|
self.layer_idx = layer_idx |
|
|
|
if is_layer_attention(config, layer_idx): |
|
print(f'layer {layer_idx} : attention') |
|
att_fn = Qwen3AttentionNoPE_Causal |
|
else: |
|
print(f'layer {layer_idx} : rwkv') |
|
att_fn = RWKV079Attention |
|
|
|
self.self_attn = att_fn(config, layer_idx) |
|
|
|
self.mlp = Qwen3MLP(config) |
|
self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.attention_type = config.layer_types[layer_idx] |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
frozen_residual: torch.Tensor, |
|
v_first: Optional[torch.Tensor], |
|
k_first: Optional[torch.Tensor], |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
|
residual = hidden_states |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, v_first, k_first = self.self_attn( |
|
hidden_states=hidden_states, |
|
frozen_residual=frozen_residual, |
|
v_first=v_first, |
|
k_first=k_first, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
|
|
) |
|
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states, v_first,k_first,) |
|
|
|
        if output_attentions:

            # attention weights are not materialized by either attention variant
            outputs += (None,)
|
|
|
return outputs |
|
|
|
|
|
@auto_docstring |
|
class RWKV079Qwen3PreTrainedModel(PreTrainedModel): |
|
config: RWKV079Qwen3Config |
|
config_class = RWKV079Qwen3Config |
|
base_model_prefix = "model" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["RWKV079Qwen3DecoderLayer"] |
|
_skip_keys_device_placement = "past_key_values" |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
_supports_flex_attn = True |
|
|
|
_supports_cache_class = True |
|
_supports_quantized_cache = True |
|
    _supports_static_cache = True


@auto_docstring |
|
class RWKV079Qwen3Model(RWKV079Qwen3PreTrainedModel): |
|
""" |
|
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`] |
|
|
|
Args: |
|
config: RWKV079Qwen3Config |
|
""" |
|
|
|
def __init__(self, config: RWKV079Qwen3Config): |
|
super().__init__(config) |
|
self.padding_idx = config.pad_token_id |
|
self.vocab_size = config.vocab_size |
|
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
self.layers = nn.ModuleList( |
|
[RWKV079Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
) |
|
self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
self.rotary_emb = Qwen3RotaryEmbedding(config=config) |
|
self.gradient_checkpointing = False |
|
self.has_sliding_layers = "sliding_attention" in self.config.layer_types |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
@auto_docstring |
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Cache] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
**kwargs: Unpack[TransformersKwargs], |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
|
if self.gradient_checkpointing and self.training and use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
|
) |
|
use_cache = False |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
if use_cache and not isinstance(past_key_values, RWKV079State): |
|
past_key_values = RWKV079State() |
|
|
|
if cache_position is None: |
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
cache_position = torch.arange( |
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
) |
|
|
|
if position_ids is None: |
|
position_ids = cache_position.unsqueeze(0) |
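
        # Build the causal-mask variants once; each decoder layer then selects the one
        # matching its configured attention type ("full_attention" or "sliding_attention").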
|
|
|
|
|
if not isinstance(causal_mask_mapping := attention_mask, dict): |
|
|
|
mask_kwargs = { |
|
"config": self.config, |
|
"input_embeds": inputs_embeds, |
|
"attention_mask": attention_mask, |
|
"cache_position": cache_position, |
|
"past_key_values": past_key_values, |
|
"position_ids": position_ids, |
|
} |
|
|
|
causal_mask_mapping = { |
|
"full_attention": create_causal_mask(**mask_kwargs), |
|
} |
|
|
|
if self.has_sliding_layers: |
|
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) |
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
if self.config.use_rope: |
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
else: |
|
position_embeddings = None |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = None |
|
v_first = None |
|
k_first = None |
|
frozen_residual = None |
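
        # v_first/k_first carry the first RWKV layer's value/key tensors so deeper RWKV layers can mix
        # against them; frozen_residual snapshots the hidden state entering the most recent RWKV layer
        # (it is forwarded to every block but currently unused by the attention forwards in this file).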
|
|
|
for decoder_layer in self.layers: |
|
if not is_layer_attention(self.config, decoder_layer.layer_idx): |
|
frozen_residual = hidden_states |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
attention_mask = causal_mask_mapping[decoder_layer.attention_type] |
|
if attention_mask is not None and attention_mask.ndim == 1: |
|
attention_mask = None |
|
|
|
|
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
frozen_residual=frozen_residual, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
v_first=v_first, |
|
k_first=k_first |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
v_first = layer_outputs[1] |
|
k_first = layer_outputs[2] |
|
|
|
            if output_attentions:

                all_self_attns += (layer_outputs[3],)
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
|
|
|
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=past_key_values if use_cache else None, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
class RWKV079Qwen3ForCausalLM(RWKV079Qwen3PreTrainedModel, GenerationMixin): |
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.model = RWKV079Qwen3Model(config) |
|
self.vocab_size = config.vocab_size |
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
|
|
|
self.post_init() |
|
|
|
@can_return_tuple |
|
@auto_docstring |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
logits_to_keep: Union[int, torch.Tensor] = 0, |
|
**loss_kwargs, |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
r""" |
|
Args: |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
|
|
|
            logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):

                Calculate logits for the last `logits_to_keep` tokens if an `int`, or for the token indices given by a

                tensor. If `0`, calculate logits for all `input_ids` (special case). Only the last token's logits are

                needed for generation, and computing them only for that token saves significant memory for long

                sequences or large vocabularies.
|
|
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, RWKV079Qwen3ForCausalLM |
|
|
|
>>> model = RWKV079Qwen3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) |
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) |
|
|
|
>>> prompt = "Hey, are you conscious? Can you talk to me?" |
|
>>> inputs = tokenizer(prompt, return_tensors="pt") |
|
|
|
>>> # Generate |
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
|
```""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
|
|
outputs = self.model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
cache_position=cache_position, |
|
) |
|
|
|
hidden_states = outputs.last_hidden_state |
|
|
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
|
logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs) |
|
|
|
return CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
@auto_docstring

class RWKV079Qwen3ForSequenceClassification(GenericForSequenceClassification, RWKV079Qwen3PreTrainedModel):

    pass


@auto_docstring

class RWKV079Qwen3ForTokenClassification(GenericForTokenClassification, RWKV079Qwen3PreTrainedModel):

    pass


@auto_docstring

class RWKV079Qwen3ForQuestionAnswering(GenericForQuestionAnswering, RWKV079Qwen3PreTrainedModel):

    base_model_prefix = "transformer"
|
|