# granite-3b-code-base-2k / modeling_granite.py
import math
import numbers
import warnings
from enum import Enum
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DynamicCache, PreTrainedModel
from transformers.activations import get_activation as get_base_activation
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.utils import is_flash_attn_2_available
from .configuration_granite import GraniteConfig
class PositionEmbeddingType(Enum):
learned_absolute = "learned_absolute"
alibi = "alibi"
rope = "rope"
class AttentionHeadType(Enum):
mha = "mha"
mqa = "mqa"
gqa = "gqa"
if is_flash_attn_2_available():
from flash_attn.bert_padding import IndexFirstAxis, pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
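# Worked example (added for clarity, not executed): for a right-padded mask
#   attention_mask = [[1, 1, 1, 0],
#                     [1, 1, 0, 0]]
# the flattened non-pad positions are indices = [0, 1, 2, 4, 5], the cumulative sequence lengths expected by
# flash_attn_varlen_func are cu_seqlens = [0, 3, 5], and max_seqlen_in_batch = 3.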
def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor:
num_groups = num_heads // num_key_value_heads
# mha
if num_groups == 1:
return x
# mqa
if num_key_value_heads == 1:
return x.expand(-1, num_heads, -1, -1)
# gqa
return x.repeat_interleave(num_groups, dim=1)
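# Shape note (added for clarity): `x` carries the key/value heads on dim 1, e.g. (batch_size, num_key_value_heads,
# ..., ...). For MQA the single KV head is broadcast with `expand` (a view, no copy); for GQA each KV head is
# repeated num_heads // num_key_value_heads times along dim 1, e.g. with num_heads=8 and num_key_value_heads=2 a
# (b, 2, s, d) tensor becomes (b, 8, s, d) with head order [0, 0, 0, 0, 1, 1, 1, 1].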
##################################################
# activation functions
_GLU_BASE_MAPPING = {
"geglu": "gelu",
"miglu": "mish",
"mishglu": "mish",
"swiglu": "swish",
}
class GLUActivation(nn.Module):
def __init__(self, base_activation: nn.Module) -> None:
super().__init__()
self.base_activation = base_activation
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.chunk(2, dim=-1)
return x[0] * self.base_activation(x[1])
def is_glu(name: str) -> bool:
return name.endswith("glu")
def get_activation_function(name: str) -> nn.Module:
if is_glu(name):
        # for "glu" and "sigmoid_glu", we directly return PyTorch's GLU
        if name in ["glu", "sigmoid_glu"]:
            activation_function = nn.GLU()
else:
if name in _GLU_BASE_MAPPING:
name = _GLU_BASE_MAPPING[name]
elif name.endswith("_glu"):
                name = name[: -len("_glu")]
else:
raise ValueError("invalid activation function")
base_activation = get_base_activation(name)
activation_function = GLUActivation(base_activation)
else:
activation_function = get_base_activation(name)
return activation_function
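# Resolution example (added for clarity): get_activation_function("swiglu") maps "swiglu" -> "swish" via
# _GLU_BASE_MAPPING and returns GLUActivation(swish), i.e. x1 * swish(x2) where x1 and x2 are the two halves of the
# last dimension. "glu" / "sigmoid_glu" use PyTorch's nn.GLU (x1 * sigmoid(x2)) directly. Since GLU variants halve
# the last dimension, the preceding linear layer must produce twice the features (see GraniteMLP.c_fc below).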
##################################################
# normalization functions
class RMSNorm(nn.Module):
def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.eps = eps
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = normalized_shape
def forward(self, input: torch.Tensor) -> torch.Tensor:
input_dtype = input.dtype
input = input.to(torch.float32)
variance = input.pow(2).mean(-1, keepdim=True)
input = input * torch.rsqrt(variance + self.eps)
return self.weight * input.to(input_dtype)
def extra_repr(self) -> str:
return f"{self.normalized_shape}, eps={self.eps}"
def reset_parameters(self) -> None:
nn.init.ones_(self.weight)
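# Note (added for clarity): RMSNorm computes weight * x / sqrt(mean(x**2, dim=-1) + eps), with the reduction done in
# fp32 for numerical stability. Unlike nn.LayerNorm there is no mean subtraction and no bias term.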
_NORMALIZATION_FUNCTIONS = {
"layernorm": nn.LayerNorm,
"rmsnorm": RMSNorm,
}
def get_normalization_function(name: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
if name in _NORMALIZATION_FUNCTIONS:
return _NORMALIZATION_FUNCTIONS[name](normalized_shape, eps=eps)
raise ValueError(f"unexpected `normalization_function` {name}")
##################################################
# attention modules
class GraniteAttention(nn.Module):
def __init__(self, config: GraniteConfig, causal: bool, layer_idx: Optional[int] = None) -> None:
super().__init__()
self.causal = causal
self.hidden_size = config.n_embd
self.num_heads = config.n_head
self.num_key_value_heads = config.num_key_value_heads
self.add_bias = config.add_bias
assert (
self.hidden_size % self.num_heads == 0
), f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})"
self.head_dim = self.hidden_size // self.num_heads
self.attention_head_type = AttentionHeadType(config.attention_head_type)
self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
self.scale_attn_weights = config.scale_attn_weights
self.attention_multiplier = config.attention_multiplier
self.layer_idx = layer_idx
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
self.scale_attention_softmax_in_fp32 = (
config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
)
if self.attention_head_type == AttentionHeadType.mha:
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_heads
assert (
self.num_heads == self.num_key_value_heads
), f"{self.__class__.__name__} should have same number of heads for query, keys and values"
elif self.attention_head_type == AttentionHeadType.gqa:
assert (
self.num_key_value_heads is not None
), "`num_key_value_heads` needs to be specified with GroupedQueryAttention"
assert self.num_heads % self.num_key_value_heads == 0, (
f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` "
f"({self.num_key_value_heads})"
)
elif self.attention_head_type == AttentionHeadType.mqa:
if self.num_key_value_heads is None:
self.num_key_value_heads = 1
assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values"
else:
raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
        # note that the actual layout of the output depends on whether we are using MHA, MQA or GQA
        # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the number of output features
self.c_attn = nn.Linear(
self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=self.add_bias
)
self.c_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.add_bias)
self.attn_pdrop = config.attn_pdrop
self.resid_pdrop = config.resid_pdrop
self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop)
self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop)
def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# ==========================================================================================
# hidden_states -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
        # the output of the following is a tuple if using MQA with tensor parallel
hidden_states = self.c_attn(hidden_states)
# ==========================================================================================
# hidden_states -> (batch_size, query_length, [num_heads + num_key_value_heads * 2] * head_dim)
# ==========================================================================================
# for MHA, we can get away with doing just 1 transpose which is not true for GQA
if self.attention_head_type == AttentionHeadType.mha:
query, key, value = self._prepare_qkv_for_forward_mha(hidden_states)
elif self.attention_head_type == AttentionHeadType.gqa:
query, key, value = self._prepare_qkv_for_forward_gqa(hidden_states)
elif self.attention_head_type == AttentionHeadType.mqa:
query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states)
else:
raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})")
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, query_length, head_dim)
# value -> (batch_size, num_key_value_heads, query_length, head_dim)
# ==========================================================================================
return query, key, value
def _prepare_qkv_for_forward_mha(
self, hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, query_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(batch_size, query_length, self.num_heads, -1)
hidden_states = hidden_states.transpose(1, 2)
query, key, value = hidden_states.chunk(3, dim=-1)
return query, key, value
def _prepare_qkv_for_forward_gqa(
self, hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, query_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1)
query, key, value = hidden_states.split(
((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1
)
# this needs to be a reshape instead of view sadly
query = query.reshape(batch_size, query_length, -1, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
return query, key, value
def _prepare_qkv_for_forward_mqa(
self, hidden_states: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, query_length = hidden_states.shape[:-1]
query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1)
query = query.view(batch_size, query_length, self.num_heads, -1)
query = query.transpose(1, 2)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
return query, key, value
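    # Layout of the fused c_attn output (note added for clarity):
    #   MHA: per-head blocks of [q | k | v], each head_dim wide (single view + transpose).
    #   GQA: per-KV-head blocks of [group queries | k | v] with group size num_heads // num_key_value_heads.
    #   MQA: [all num_heads * head_dim queries | one k | one v].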
def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
# ==========================================================================================
# hidden_states -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
query, key, value = self._prepare_qkv_for_forward(hidden_states)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, query_length, head_dim)
# value -> (batch_size, num_key_value_heads, query_length, head_dim)
# ==========================================================================================
if self.position_embedding_type == PositionEmbeddingType.rope:
query = apply_rotary_pos_emb(query, rope_cos_sin)
key = apply_rotary_pos_emb(key, rope_cos_sin)
if past_key_values is not None:
key, value = past_key_values.update(key, value, self.layer_idx)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, key_length, head_dim)
# value -> (batch_size, num_key_value_heads, key_length, head_dim)
# ==========================================================================================
key = key.transpose(-1, -2)
dtype = query.dtype
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
if self.scale_attn_weights:
if self.attention_multiplier is None:
scale_factor = 1 / self.head_dim**0.5
else:
scale_factor = self.attention_multiplier
else:
scale_factor = 1
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, head_dim, key_length)
# value -> (batch_size, num_key_value_heads, key_length, head_dim)
# ==========================================================================================
batch_size = query.shape[0]
query_length = query.shape[2]
key_length = key.shape[-1]
key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
# Always copies
query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
        # No copy when past key/value states are provided.
key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)
# ==========================================================================================
# query -> (batch_size * num_heads, query_length, head_dim)
# key -> (batch_size * num_heads, head_dim, key_length)
# value -> (batch_size, num_heads, key_length, head_dim)
# ==========================================================================================
attn_weights = torch.empty(
(batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype
)
attn_weights = torch.baddbmm(attn_weights, query, key, beta=0, alpha=scale_factor).view(
batch_size, self.num_heads, query_length, key_length
)
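        # Note (added for clarity): with beta=0 the uninitialized `attn_weights` buffer is ignored, so baddbmm
        # computes alpha * (query @ key), i.e. scale_factor * Q @ K^T.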
# ==========================================================================================
# attn_weights -> (batch_size, num_heads, query_length, key_length)
# ==========================================================================================
attn_weights = attn_weights.to(softmax_dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1).to(dtype)
attn_weights = self.attn_dropout(attn_weights)
# ==========================================================================================
# value -> (batch_size, num_heads, key_length, head_dim)
# attn_weights -> (batch_size, num_heads, query_length, key_length)
# ==========================================================================================
attn_output = torch.matmul(attn_weights, value)
# ==========================================================================================
# attn_output -> (batch_size, num_heads, query_length, head_dim)
# ==========================================================================================
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
# ==========================================================================================
# attn_output -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
return attn_output
class GraniteSDPA(GraniteAttention):
def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
# ==========================================================================================
# hidden_states -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
query, key, value = self._prepare_qkv_for_forward(hidden_states)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, query_length, head_dim)
# value -> (batch_size, num_key_value_heads, query_length, head_dim)
# ==========================================================================================
if self.position_embedding_type == PositionEmbeddingType.rope:
query = apply_rotary_pos_emb(query, rope_cos_sin)
key = apply_rotary_pos_emb(key, rope_cos_sin)
if past_key_values is not None:
key, value = past_key_values.update(key, value, self.layer_idx)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, key_length, head_dim)
# value -> (batch_size, num_key_value_heads, key_length, head_dim)
# ==========================================================================================
key = repeat_key_value(key, self.num_heads, self.num_key_value_heads)
value = repeat_key_value(value, self.num_heads, self.num_key_value_heads)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_heads, key_length, head_dim)
# value -> (batch_size, num_heads, key_length, head_dim)
# ==========================================================================================
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0,
is_causal=self.causal if attention_mask is None else False,
scale=self.attention_multiplier if self.scale_attn_weights else 1,
)
# ==========================================================================================
# attn_output -> (batch_size, num_heads, query_length, head_dim)
# ==========================================================================================
batch_size = attn_output.shape[0]
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
# ==========================================================================================
# attn_output -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
return attn_output
class GraniteFlashAttention2(GraniteAttention):
def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
# ==========================================================================================
# hidden_states -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
query, key, value = self._prepare_qkv_for_forward(hidden_states)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, query_length, head_dim)
# value -> (batch_size, num_key_value_heads, query_length, head_dim)
# ==========================================================================================
if self.position_embedding_type == PositionEmbeddingType.rope:
query = apply_rotary_pos_emb(query, rope_cos_sin)
key = apply_rotary_pos_emb(key, rope_cos_sin)
if past_key_values is not None:
key, value = past_key_values.update(key, value, self.layer_idx)
# ==========================================================================================
# query -> (batch_size, num_heads, query_length, head_dim)
# key -> (batch_size, num_key_value_heads, key_length, head_dim)
# value -> (batch_size, num_key_value_heads, key_length, head_dim)
# ==========================================================================================
# TODO avoid this extra transpose
query = query.transpose(1, 2)
if self.attention_head_type == AttentionHeadType.mqa:
key = key.squeeze(1).unsqueeze(2)
value = value.squeeze(1).unsqueeze(2)
else:
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# ==========================================================================================
# query -> (batch_size, query_length, num_heads, head_dim)
# key -> (batch_size, key_length, num_heads, head_dim)
# value -> (batch_size, key_length, num_heads, head_dim)
# ==========================================================================================
batch_size, query_length = query.shape[:2]
key_length = key.shape[1]
indices_k, cu_seqlens_k, max_seqlen_k = get_unpad_data(attention_mask)
key = IndexFirstAxis.apply(
key.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
)
value = IndexFirstAxis.apply(
value.reshape(batch_size * key_length, self.num_key_value_heads, self.head_dim), indices_k
)
if query_length == key_length:
query = IndexFirstAxis.apply(
query.reshape(batch_size * key_length, self.num_heads, self.head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_q = max_seqlen_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query = query.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query, attention_mask)
# ==========================================================================================
# query -> (total_q, num_heads, head_dim)
# key -> (total_q, num_heads, head_dim)
# value -> (total_q, num_heads, head_dim)
# ==========================================================================================
attn_output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=self.attn_pdrop if self.training else 0,
softmax_scale=self.attention_multiplier if self.scale_attn_weights else 1,
causal=self.causal,
)
# ==========================================================================================
# attn_output -> (total_q, num_heads, head_dim)
# ==========================================================================================
attn_output = pad_input(attn_output, indices_q, batch_size, query_length)
attn_output = attn_output.view(batch_size, query_length, -1)
# ==========================================================================================
# attn_output -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
return attn_output
_ATTENTION_MODULES = {
"eager": GraniteAttention,
"sdpa": GraniteSDPA,
"flash_attention_2": GraniteFlashAttention2,
}
def get_attention_module(
config: GraniteConfig, causal: bool, attention_implementation: str, layer_idx: int
) -> GraniteAttention:
if attention_implementation in _ATTENTION_MODULES:
return _ATTENTION_MODULES[attention_implementation](config, causal=causal, layer_idx=layer_idx)
raise ValueError(f"unexpected `attention_implementation` {attention_implementation}")
##################################################
# position embeddings
class Alibi(nn.Module):
def __init__(self, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.reset_parameters()
def forward(
self, attention_mask: torch.Tensor, batch_size: int, key_length: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
attention_mask (torch.Tensor): attention_mask tensor of shape (`batch_size`, `key_length`)
num_heads (int): `num_heads` for the model
batch_size (int): `batch_size`
key_length (int): `key_length`
device (torch.device): device for the tensors
dtype (torch.dtype): dtype to use for the tensors
Returns:
torch.Tensor: alibi tensor of shape (`batch_size`, `num_heads`, `key_length`)
"""
        # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
if attention_mask is None:
arange_tensor = (
torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1)
)
else:
arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1)
alibi = self.slopes.unsqueeze(1) * arange_tensor
return alibi.to(dtype)
def reset_parameters(self) -> None:
closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != self.num_heads:
extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
self.register_buffer("slopes", slopes, persistent=False)
class RoPE(nn.Module):
def __init__(
self,
head_dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
) -> None:
super().__init__()
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.mscale = 1
self.reset_parameters()
def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)
cos = self.cos_cached[:seq_len].to(dtype)
sin = self.sin_cached[:seq_len].to(dtype)
return cos, sin
def reset_parameters(self) -> None:
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
@torch.no_grad()
def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
        # This differs from the paper's interleaved layout, but together with `_rotate_half` it yields the same rotation.
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)
def apply_rotary_pos_emb(x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = cos_sin
x = (x * cos) + (_rotate_half(x) * sin)
return x
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
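# RoPE note (added for clarity): apply_rotary_pos_emb computes x * cos + _rotate_half(x) * sin using the
# "half split" layout (the first half of head_dim is paired with the second half), as in GPT-NeoX / Llama. Each
# pair (x_i, x_{i + head_dim/2}) is rotated by the angle position * inv_freq_i, which matches the interleaved
# formulation of the RoFormer paper up to a fixed permutation of the feature dimension.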
##################################################
# MLP
class GraniteMLP(nn.Module):
def __init__(self, config: GraniteConfig) -> None:
super().__init__()
hidden_size = config.n_embd
intermediate_size = config.n_inner
activation_function = config.activation_function
add_bias = config.add_bias
residual_dropout = config.resid_pdrop
self.c_fc = nn.Linear(
hidden_size,
2 * intermediate_size if is_glu(activation_function) else intermediate_size,
bias=add_bias,
)
self.act = get_activation_function(activation_function)
self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=add_bias)
self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
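# Note (added for clarity): when `activation_function` is a GLU variant, c_fc projects to 2 * intermediate_size and
# the activation splits that back to intermediate_size, so c_proj always maps intermediate_size -> hidden_size.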
##################################################
# transformer layer
class GraniteBlock(nn.Module):
def __init__(
self,
config: GraniteConfig,
attention_implementation: str,
layer_idx: Optional[int] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
self.inner_dim = config.n_inner
self.layer_idx = layer_idx
self.ln_1 = get_normalization_function(
config.normalization_function,
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = get_attention_module(config, True, attention_implementation, layer_idx)
self.ln_2 = get_normalization_function(
config.normalization_function,
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = GraniteMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
rope_cos_sin: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
rope_cos_sin=rope_cos_sin,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
##################################################
# model classes
class GranitePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GraniteConfig
base_model_prefix = "transformer"
causal = True
_no_split_modules = ["GraniteBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_sdpa = True
_supports_flash_attn_2 = True
def __init__(self, config: GraniteConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.attention_implementation = self.config._attn_implementation
self._use_eager_attention = self.attention_implementation == "eager"
self._use_sdpa = self.attention_implementation == "sdpa"
self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2"
self.initializer_range = config.initializer_range
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, (nn.LayerNorm, RMSNorm, RoPE)):
module.reset_parameters()
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
if module.bias is not None:
module.bias.zero_()
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0, std=self.initializer_range)
if module.padding_idx is not None:
module.weight[module.padding_idx].zero_()
class GraniteModel(GranitePreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
mask_value = None
def __init__(self, config: GraniteConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.attention_head_type = AttentionHeadType(config.attention_head_type)
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
assert (
self.embed_dim % self.num_heads == 0
), f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})"
self.head_dim = self.embed_dim // self.num_heads
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList(
[GraniteBlock(config, self.attention_implementation, layer_idx=i) for i in range(config.num_hidden_layers)]
)
self.ln_f = get_normalization_function(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
)
self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
elif self.position_embedding_type == PositionEmbeddingType.alibi:
assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention"
self.alibi = Alibi(self.num_heads)
elif self.position_embedding_type == PositionEmbeddingType.rope:
self.rope = RoPE(self.head_dim, max_position_embeddings=config.n_positions, base=config.rope_theta)
else:
raise NotImplementedError()
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.wte
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.wte = new_embeddings
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
(
output_hidden_states,
use_cache,
return_dict,
input_shape,
hidden_states,
attention_mask,
position_ids,
rope_cos_sin,
past_key_values,
) = self._prepare_a_bunch_of_stuff(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# ==========================================================================================
# flash:
# attention_mask -> (batch_size, key_length)
# else:
# attention_mask -> (batch_size, 1, query_length, key_length)
# ==========================================================================================
output_shape = input_shape + (hidden_states.size(-1),)
past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values
all_hidden_states = () if output_hidden_states else None
for block in self.h:
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states = block(
hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
rope_cos_sin=rope_cos_sin,
)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
)
def _get_position_ids(
self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device
) -> torch.Tensor:
if attention_mask is not None and len(attention_mask.shape) == 2:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
if past_length > 0:
                position_ids = position_ids[:, past_length:key_length]
else:
position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, query_length)
return position_ids
def _get_alibi_bias(
self,
attention_mask: torch.Tensor,
batch_size: int,
query_length: int,
key_length: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
if self.position_embedding_type != PositionEmbeddingType.alibi:
return None
alibi_bias = self.alibi(attention_mask, batch_size, key_length, device, dtype)
# ==========================================================================================
# alibi_bias -> (batch_size, num_heads, key_length)
# ==========================================================================================
alibi_bias = alibi_bias.unsqueeze(2)
if query_length != 1:
alibi_bias = alibi_bias.expand(-1, -1, query_length, -1)
# ==========================================================================================
# alibi_bias -> (batch_size, num_heads, query_length, key_length)
# ==========================================================================================
return alibi_bias
def _get_rope_cos_sin(
self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
if self.position_embedding_type == PositionEmbeddingType.rope:
cos, sin = self.rope(key_length, dtype=dtype, device=device)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
return cos, sin
def _prepare_causal_attention_mask(
self, attention_mask: torch.Tensor, batch_size: int, query_length: int, key_length: int, device: torch.device
) -> torch.Tensor:
past_length = key_length - query_length
# ==========================================================================================
# attention_mask -> (batch_size, key_length)
# ==========================================================================================
if query_length > 1:
# (query_length, key_length)
causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device)
causal_mask[:, past_length:] = torch.tril(
torch.ones(query_length, query_length, dtype=torch.bool, device=device)
)
if past_length > 0:
causal_mask[:, :past_length] = True
# (query_length, key_length) -> (1, query_length, key_length)
causal_mask = causal_mask.unsqueeze(0)
if attention_mask is None:
# (1, query_length, key_length) -> (batch_size, query_length, key_length)
causal_mask = causal_mask.expand(batch_size, -1, -1)
else:
# (1, query_length, key_length) & (batch_size, 1, key_length) -> (batch_size, query_length, key_length)
causal_mask = causal_mask & attention_mask.unsqueeze(1).to(torch.bool)
else:
if attention_mask is None:
# (batch_size, query_length, key_length)
causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device)
else:
# (batch_size, query_length, key_length)
causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device)
# ==========================================================================================
# attention_mask -> (batch_size, query_length, key_length)
# ==========================================================================================
causal_mask = causal_mask.unsqueeze(1)
# ==========================================================================================
# attention_mask -> (batch_size, 1, query_length, key_length)
# ==========================================================================================
return causal_mask
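    # Worked example (added for clarity): with query_length = 3, key_length = 5 (past_length = 2) and no padding
    # mask, the boolean mask built above is
    #   [[True, True, True, False, False],
    #    [True, True, True, True,  False],
    #    [True, True, True, True,  True ]]
    # where True means "attend". It is later converted into an additive bias (0 / mask_value, or the alibi bias)
    # via torch.where in _prepare_a_bunch_of_stuff.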
def _get_initial_hidden_state(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: torch.Tensor,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
inputs_embeds = inputs_embeds + self.wpe(position_ids)
if token_type_ids is not None:
inputs_embeds = inputs_embeds + self.wte(token_type_ids)
inputs_embeds = self.drop(inputs_embeds)
return inputs_embeds
def _prepare_a_bunch_of_stuff(
self,
input_ids: torch.Tensor = None,
past_key_values: DynamicCache = None,
attention_mask: torch.Tensor = None,
token_type_ids: torch.Tensor = None,
position_ids: torch.Tensor = None,
inputs_embeds: torch.Tensor = None,
use_cache: bool = None,
output_hidden_states: bool = None,
return_dict: bool = None,
) -> Tuple[
bool,
bool,
bool,
torch.Size,
torch.Tensor,
torch.Tensor,
torch.Tensor,
Optional[Tuple[torch.Tensor, torch.Tensor]],
DynamicCache,
]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = self.config.use_cache if use_cache is None else use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
# TODO special handling for padding free transformer needed here if we support inputs_embeds argument
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size = input_shape[0]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if self.position_embedding_type == PositionEmbeddingType.alibi:
if position_ids is not None:
warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
# ==========================================================================================
# input_ids -> (batch_size, query_length)
# attention_mask -> None or (batch_size, key_length)
# position_ids -> None or (batch_size, key_length)
# ==========================================================================================
past_length = 0 if past_key_values is None else past_key_values.get_seq_length()
query_length = input_shape[-1]
key_length = past_length + query_length
if position_ids is None:
position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device)
# ==========================================================================================
# input_ids -> (batch_size, query_length)
# attention_mask -> None or (batch_size, key_length)
# position_ids -> (batch_size, query_length)
# ==========================================================================================
hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids)
# ==========================================================================================
# hidden_states -> (batch_size, query_length, num_heads * head_dim)
# ==========================================================================================
alibi_bias = self._get_alibi_bias(
attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype
)
# ==========================================================================================
# alibi_bias -> (batch_size, num_heads, query_length, key_length)
# ==========================================================================================
rope_cos_sin = self._get_rope_cos_sin(
key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device
)
# ==========================================================================================
        # rope_cos_sin -> None or 2 * (batch_size, 1, query_length, head_dim)
# ==========================================================================================
# prepare causal mask only if not using flash attention
if self._use_flash_attention_2:
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
elif self._use_sdpa:
# we use the causal/non-causal argument of SDPA for attention in this case
if attention_mask is not None:
attention_mask = self._prepare_causal_attention_mask(
attention_mask, batch_size, query_length, key_length, device
)
attention_mask = torch.where(
attention_mask,
~attention_mask if alibi_bias is None else alibi_bias,
self._get_mask_value(attention_mask.device, hidden_states.dtype),
)
else:
attention_mask = self._prepare_causal_attention_mask(
attention_mask, batch_size, query_length, key_length, device
)
attention_mask = torch.where(
attention_mask,
~attention_mask if alibi_bias is None else alibi_bias,
self._get_mask_value(attention_mask.device, hidden_states.dtype),
)
return (
output_hidden_states,
use_cache,
return_dict,
input_shape,
hidden_states,
attention_mask,
position_ids,
rope_cos_sin,
past_key_values,
)
def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
# torch.where expects a tensor. We use a cache to avoid recreating it every time.
if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
self.mask_value = torch.full([], torch.finfo(torch.float16).min, dtype=dtype, device=device)
return self.mask_value
class GraniteForCausalLM(GranitePreTrainedModel):
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
def __init__(self, config: GraniteConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.transformer = GraniteModel(config, **kwargs)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.transformer.wte
def set_input_embeddings(self, value: nn.Embedding) -> None:
self.transformer.wte = value
def get_output_embeddings(self) -> nn.Linear:
return self.lm_head
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.lm_head = new_embeddings
# FIXME typing
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
past_key_values: Optional[DynamicCache] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
if past_key_values:
past_length = past_key_values.get_seq_length()
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
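    # Generation note (added for clarity): once a cache exists, only the not-yet-processed input ids are kept
    # (usually just the last token), position_ids are rebuilt from the attention mask's cumulative sum, and
    # `inputs_embeds` is only forwarded on the very first step, before any cache is present.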
def forward(
self,
        input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[DynamicCache] = None,
attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# ==========================================================================================
# input_ids -> (batch_size, query_length)
# attention_mask -> None or (batch_size, key_length)
# position_ids -> None or (batch_size, key_length)
# ==========================================================================================
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
# Shift so that tokens < n predict n
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
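# Usage sketch (added for clarity; a minimal example assuming the checkpoint is published under the
# "ibm-granite/granite-3b-code-base-2k" repo id and ships this file as remote code):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-base-2k")
#   model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3b-code-base-2k", trust_remote_code=True)
#   inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
#   outputs = model.generate(**inputs, max_new_tokens=32)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=True))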