How to use the model

The MLX model definition is available under Files and versions. Download it and place it in the same folder as your notebook. Support for standard PyTorch is a work in progress (WIP).

To use the model:

# Import the model definition FIRST: the configuration below refers to `mx`,
# which in this snippet is only brought into scope by this star import.
# Importing before defining `config` also keeps the star import from
# overwriting `config` or `dtype` if the model module happens to define them.
# NOTE(review): assumes model.py does `import mlx.core as mx`; if it does not,
# add that import here explicitly.
from model import * # Replace 'model' with your file name if renamed (e.g. if the model def is called XYZ.py import from XYZ)

# Set up the configuration for this model
config = {
    'layers': 8,
    'num_heads': 8,
    'vocab_size': 256,
    'input_dims': 512,
    'hidden_dims': None,
    'mtp_heads': 0,
    'dtype': mx.bfloat16,
}

# Instantiate the model
model = RWKV_v7(**config)
dtype = config['dtype']

# Force evaluation of the freshly initialized parameters.
mx.eval(model.parameters())
# Load the model weights from a PyTorch checkpoint.
import torch  # the checkpoint is a PyTorch state dict

ckpt_path = 'Downloads/TinyCorpus_10414_0.9765625_Aug_16_2025_weights.pt' # Insert your path to RWKV 7 model

state_dict = torch.load(ckpt_path, map_location='cpu', weights_only=True) # IMPORTANT: map_location='cpu' since weights are mapped to 'cuda'!

# The MLX definition expects the d_lora / a_lora parameters under an extra
# `.layers.` path component: "<prefix>.d_lora.<rest>" -> "<prefix>.d_lora.layers.<rest>".
KEY_RENAMES = (
    (".d_lora.", ".d_lora.layers."),
    (".a_lora.", ".a_lora.layers."),
)


def _rename_key(key):
    """Map a PyTorch checkpoint key to the corresponding MLX parameter name."""
    for old, new in KEY_RENAMES:
        key = key.replace(old, new)
    return key


# load_weights takes a list of (parameter name, mx.array) pairs.
mlx_weights = [(_rename_key(name), mx.array(weight)) for name, weight in state_dict.items()]
# strict=True makes load_weights check that the weights exactly match the model's parameters.
model.load_weights(mlx_weights, strict=True)

Example: Evaluate the model on some text

def eval_loss(text_: str, model, config, per_token=False):
    """Evaluate next-byte cross-entropy loss and accuracy of ``model`` on ``text_``.

    The text is UTF-8 encoded and fed to the model as one sequence of byte
    tokens (batch size 1), starting from an all-zero recurrent state. The
    logits at position ``t`` are scored against byte ``t + 1``. The final
    position has no next byte, so it is excluded from both the loss and the
    accuracy.

    Args:
        text_: Text to evaluate; must encode to at least 2 UTF-8 bytes.
        model: The ``RWKV_v7`` model. It is put in eval mode for the forward
            pass and switched back to train mode afterwards, also on error.
        config: Model configuration; reads ``'layers'``, ``'num_heads'``,
            ``'input_dims'`` and ``'dtype'``.
        per_token: If True, additionally return the per-position losses.

    Returns:
        ``(loss, accuracy)``: the mean cross-entropy (computed in float32)
        and the fraction of positions whose argmax prediction equals the
        next byte. When ``per_token`` is True, returns
        ``(loss, accuracy, per_token_loss)``, where ``per_token_loss`` has
        shape ``(1, n_bytes - 1)``.

    Raises:
        ValueError: If ``text_`` encodes to fewer than 2 bytes.
    """
    data = text_.encode('utf-8')
    if len(data) < 2:
        raise ValueError('text_ must encode to at least 2 bytes to have a next-byte target')

    # Take the dtype from config rather than from a module-level `dtype`.
    dtype = config['dtype']
    n_layers, n_heads, dim = config['layers'], config['num_heads'], config['input_dims']
    head_dim = dim // n_heads

    # All-zero initial state: the x_prev_* buffers use the model dtype, and the
    # per-head state matrices are kept in float32.
    x_prev_0s = mx.zeros([n_layers, 1, 1, dim], dtype=dtype)
    state_prevs = mx.zeros([n_layers, 1, n_heads, head_dim, head_dim], dtype=mx.float32)
    x_prev_1s = mx.zeros([n_layers, 1, 1, dim], dtype=dtype)

    txt_btch = mx.array(data).reshape(1, -1).astype(mx.uint8)

    model.eval()
    try:
        logits, _ = model(txt_btch, x_prev_0s=x_prev_0s, state_prevs=state_prevs, x_prev_1s=x_prev_1s)
    finally:
        # Restore training mode even if the forward pass raises.
        model.train()

    # Position t predicts byte t + 1. Slicing (instead of mx.roll) avoids a
    # wrap-around pair between the last position and the first byte, which
    # would otherwise be counted in the accuracy.
    preds = logits[:, :-1]
    targets = txt_btch[:, 1:]

    # Upcast so the loss and its mean are not computed in a low-precision
    # dtype such as bfloat16.
    token_losses = nn.losses.cross_entropy(preds.astype(mx.float32), targets)
    loss = token_losses.mean()
    accuracy = (mx.argmax(preds, axis=-1) == targets).mean()

    if per_token:
        return loss, accuracy, token_losses
    return loss, accuracy

The text should look like '[STX]def to_char(x): ...', because '[STX]' (the \x02 control character) is my start token. If it is missing, insert the actual \x02 character (for example, the escape '\x02' in a Python string), NOT its visible picture glyph.

The STX character should appear in bright red; in the comparison image, the version on the right is the correct one. image/png

# '\x02' (STX) is the model's start token. It is written as an escape so it
# cannot silently disappear when the snippet is copied (a raw STX is invisible).
text_ = '''\x02def to_char(x):
    try:
        return bytes([x]).decode('utf-8')
    except:
        return f'{x}'
'''

print(eval_loss(text_, model, config)) # returns (CE Loss, Accuracy of next character)
(array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32))

Example: Visualize the attention maps (beta)

You must have plotly installed for this; matplotlib is not interactive and is very slow for high-resolution images.

def features(text_: str, model, config):
    """Run ``model`` on ``text_`` and return its attention maps and FFN activations.

    The text is UTF-8 encoded and fed to the model as one sequence of byte
    tokens (batch size 1), starting from an all-zero recurrent state. The
    model is called with ``attn_map=True`` and ``ffn_act=True``.

    Args:
        text_: Text to run through the model.
        model: The ``RWKV_v7`` model. It is put in eval mode for the forward
            pass and switched back to train mode afterwards, also on error.
        config: Model configuration; reads ``'layers'``, ``'num_heads'``,
            ``'input_dims'`` and ``'dtype'``.

    Returns:
        ``(attn_maps, ffn_acts)`` exactly as returned by the model.
    """
    data = text_.encode('utf-8')

    # Take the dtype from config rather than from a module-level `dtype`.
    dtype = config['dtype']
    n_layers, n_heads, dim = config['layers'], config['num_heads'], config['input_dims']
    head_dim = dim // n_heads

    # All-zero initial state: the x_prev_* buffers use the model dtype, and the
    # per-head state matrices are kept in float32.
    x_prev_0s = mx.zeros([n_layers, 1, 1, dim], dtype=dtype)
    state_prevs = mx.zeros([n_layers, 1, n_heads, head_dim, head_dim], dtype=mx.float32)
    x_prev_1s = mx.zeros([n_layers, 1, 1, dim], dtype=dtype)

    txt_btch = mx.array(data).reshape(1, -1).astype(mx.uint8)

    model.eval()
    try:
        _logits, _state, attn_maps, ffn_acts = model(
            txt_btch,
            x_prev_0s=x_prev_0s,
            state_prevs=state_prevs,
            x_prev_1s=x_prev_1s,
            attn_map=True,
            ffn_act=True,
        )
    finally:
        # Restore training mode even if the forward pass raises.
        model.train()

    return attn_maps, ffn_acts
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def to_char(x):
    """Return byte value ``x`` as a one-character label string.

    Values that do not decode to a character on their own fall back to their
    decimal representation. This covers non-ASCII bytes (e.g. pieces of a
    multi-byte UTF-8 sequence), ints outside 0..255 and non-int inputs.
    For example, ``to_char(65)`` returns ``'A'`` and ``to_char(200)`` returns
    ``'200'``.
    """
    try:
        return bytes([x]).decode('utf-8')
    except (ValueError, TypeError):
        # UnicodeDecodeError is a ValueError subclass. bytes() raises ValueError
        # for ints outside 0..255 and TypeError for non-integers. Catching only
        # these keeps KeyboardInterrupt/SystemExit from being swallowed.
        return f'{x}'

# '\x02' (STX) is the model's start token. It is written as an escape so it
# cannot silently disappear when the snippet is copied (a raw STX is invisible).
text_ = '''\x02def to_char(x):
    try:
        return bytes([x]).decode('utf-8')
    except:
        return f'{x}'
'''

# Magnitude of the attention maps. Cast to float32 before converting, because
# NumPy has no bfloat16 dtype and np.array() cannot ingest a bfloat16 MLX array.
attn_maps = np.array(features(text_, model, config)[0].abs().astype(mx.float32))

# One tick label per input byte.
text_bytes = list(text_.encode('utf-8'))
text_labels = [to_char(i) for i in text_bytes]

# Layout assumed from the indexing below: (layer, batch, target, source, head).
# TODO confirm against the model definition.
n_layers = attn_maps.shape[0]
n_heads = attn_maps.shape[4]
sequence_length = attn_maps.shape[2]

print(f"Detected {n_layers} layers and {n_heads} heads.")
print(f"Attention map shape for each head: ({sequence_length}, {sequence_length})")

subplot_base_size = 512

fig = make_subplots(
    rows=n_layers,
    cols=n_heads,
    subplot_titles=[f"L{layer_idx} H{head_idx}" for layer_idx in range(n_layers) for head_idx in range(n_heads)],
    horizontal_spacing=0.01,
    vertical_spacing=0.01
)

x_indices = list(range(sequence_length))
y_indices = list(range(sequence_length))

# Axis settings shared by every subplot, applied once to all axes.
fig.update_xaxes(
    showticklabels=False,
    tickangle=45,
    constrain='domain',
    tickmode='array',
    tickvals=x_indices,
    ticktext=text_labels
)
fig.update_yaxes(
    showticklabels=False,
    scaleratio=1,
    constrain='domain',
    tickmode='array',
    tickvals=y_indices,
    ticktext=text_labels,
    autorange="reversed"
)

for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        row, col = layer_idx + 1, head_idx + 1

        fig.add_trace(
            go.Heatmap(
                z=attn_maps[layer_idx, 0, :, :, head_idx],
                x=x_indices,
                y=y_indices,
                colorscale='Viridis',
                showscale=False,
                hovertemplate="Source: %{x}<br>Target: %{y}<br>Attention: %{z:.4f}<extra></extra>"
            ),
            row=row,
            col=col,
        )

        # Keep each heatmap square by locking its y-axis to the x-axis of the
        # SAME subplot. make_subplots numbers axes row-major ('x', 'x2', 'x3',
        # ...). Anchoring every y-axis to 'x' would instead tie all subplots
        # to the first subplot's x-axis.
        axis_num = (row - 1) * n_heads + col
        fig.update_yaxes(
            scaleanchor='x' if axis_num == 1 else f'x{axis_num}',
            row=row,
            col=col,
        )

fig_height = n_layers * subplot_base_size + 128
fig_width = n_heads * subplot_base_size + 128

fig.update_layout(
    title_text="Attention Maps per Layer and Head",
    height=fig_height,
    width=fig_width,
    showlegend=False,
    hovermode='closest',
    margin=dict(l=50, r=50, t=80, b=50)
)

fig.show()

Example RWKV-7 attention maps (layers L0, L2, L4, L6): image/png

Example softmax attention maps (layers L1, L3, L5, L7): image/png

We can already observe interesting patterns just from these heads!

L0 H1 shows triangular attention patterns that line up with the characters of individual words. Perhaps the model has learned to segment words so that it can map character sequences to semantic meanings?

L7 H1 behaves like a weak induction head: it uses patterns seen earlier in the context to predict upcoming tokens. For example, if the sequence 5873 appeared earlier in the context and the model now sees 587 again, the induction head pushes the model to output 3.

Stronger induction heads can be found in L3 H0, L3 H3, L3 H4, L3 H7, L5 H0, L5 H4, L5 H6 and L5 H7. (Note that all of these induction heads use softmax attention; perhaps RWKV attention is not well suited to sharp, precise long-context recall?)

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support