How to use the model
The MLX model definition is under "Files and versions". Make sure you have downloaded it and placed it in the same folder as your notebook. Support for general PyTorch is a work in progress.
To use the model:
# Explicit imports for the names used in this guide. They come first so that
# `mx` exists before `config` references `mx.bfloat16`. The star-import of the
# model file still runs afterwards, so anything it re-exports takes precedence.
import mlx.core as mx
import mlx.nn as nn  # not used in this snippet; needed by the evaluation helper
import mlx.utils as utils
import torch

# Setup the configuration for this model
config = {
    'layers': 8,
    'num_heads': 8,
    'vocab_size': 256,
    'input_dims': 512,
    'hidden_dims': None,
    'mtp_heads': 0,
    'dtype': mx.bfloat16,
}

# Import the model definition:
from model import *  # Replace 'model' with your file name if renamed (e.g. if the model def is called XYZ.py import from XYZ)

# Instantiate the model
model = RWKV_v7(**config)
dtype = config['dtype']  # kept as a module-level name for later snippets
mx.eval(model.parameters())  # MLX is lazy: force the parameters to be materialised

# Load the model
ckpt_path = 'Downloads/TinyCorpus_10414_0.9765625_Aug_16_2025_weights.pt'  # Insert your path to RWKV 7 model
state_dict = torch.load(ckpt_path, map_location='cpu', weights_only=True)  # IMPORTANT: map_location='cpu' since weights are mapped to 'cuda'!

# Insert a `layers` level under every `d_lora` / `a_lora` key so the checkpoint
# names match the MLX parameter names.
# NOTE(review): presumably the MLX LoRA modules keep their sub-layers in a
# `layers` container -- confirm against the model definition.
new_state_dict = {
    old_key.replace(".d_lora.", ".d_lora.layers.").replace(".a_lora.", ".a_lora.layers."): weight
    for old_key, weight in state_dict.items()
}

# Convert every checkpoint tensor to an MLX array, then hand the weights over
# as (name, array) pairs. strict=True makes any name mismatch an error.
th_p = utils.tree_map(lambda x: mx.array(x), new_state_dict)
th_p_ = list(th_p.items())
model.load_weights(th_p_, strict=True)
Example: Evaluate the model on some text
def eval_loss(text_: str, model, config, per_token=False):
    """Score `model` on `text_` as a next-byte prediction task.

    The text is UTF-8 encoded and fed to the model as a single batch of bytes,
    starting from all-zero recurrent state. The target at each position is the
    following byte. The final position has no successor (its rolled target
    would wrap around to the first byte), so it is excluded from both the loss
    and the accuracy.

    Args:
        text_: Text to evaluate.
        model: Model called as
            `model(tokens, x_prev_0s=..., state_prevs=..., x_prev_1s=...)`
            and returning `(logits, new_state)`.
        config: Model configuration; reads 'layers', 'num_heads',
            'input_dims' and 'dtype'.
        per_token: If true, also return the unreduced per-position losses.

    Returns:
        `(mean_loss, accuracy)`, or `(mean_loss, accuracy, per_token_losses)`
        when `per_token` is true.

    Side effects:
        The model is switched to eval mode for the forward pass and is always
        left in train mode afterwards, even if the forward pass raises.
    """
    text_ = text_.encode('utf-8')

    # Read the dtype from the config that was passed in rather than from a
    # module-level global, so the function is self-contained.
    act_dtype = config['dtype']
    n_layers = config['layers']
    n_heads = config['num_heads']
    width = config['input_dims']
    head_dim = width // n_heads

    # Zero initial state: two per-layer "previous x" buffers in the model
    # dtype, and one float32 (head_dim x head_dim) matrix per layer and head.
    x_prev_0s = mx.zeros([n_layers, 1, 1, width], dtype=act_dtype)
    state_prevs = mx.zeros([n_layers, 1, n_heads, head_dim, head_dim], dtype=mx.float32)
    x_prev_1s = mx.zeros([n_layers, 1, 1, width], dtype=act_dtype)

    txt_btch = mx.array(text_).reshape(1, -1).astype(mx.uint8)

    model.eval()
    try:
        logits, _ = model(txt_btch, x_prev_0s=x_prev_0s, state_prevs=state_prevs, x_prev_1s=x_prev_1s)
    finally:
        model.train()

    # Next-byte targets: shift left by one. The last column wraps around to
    # the first byte and is therefore dropped below.
    targets = mx.roll(txt_btch, -1, axis=1)
    token_losses = nn.losses.cross_entropy(logits, targets)[:, :-1]
    hits = (mx.argmax(logits, axis=-1) == targets)[:, :-1]

    loss = token_losses.mean()
    accuracy = hits.mean()
    if per_token:
        return loss, accuracy, token_losses
    return loss, accuracy
The text should show something like '[STX]def to_char(x): ...', since '[STX]' (the \x02 control character) is my start token. If it is missing, add the actual \x02 character, NOT the picture version.
The STX character should appear bright red; the version on the right is the correct one.
# Sample text to score. It should begin with the model's start token, the STX
# control character (\x02); add it if your copy of this literal lacks it.
text_ = '''def to_char(x):
try:
return bytes([x]).decode('utf-8')
except:
return f'{x}'
'''
# Prints the tuple returned by eval_loss with per_token left at its default.
print(eval_loss(text_, model, config)) # returns (CE Loss, Accuracy of next character)
(array(0.738281, dtype=bfloat16), array(0.77451, dtype=float32))
Example: Visualize the attention maps (beta)
You must have plotly installed for this; matplotlib is not interactive and is very slow for high-resolution images.
def features(text_: str, model, config):
    """Run `model` on `text_` and return its attention maps and FFN activations.

    The text is UTF-8 encoded and fed as a single batch of bytes, starting
    from all-zero recurrent state, with `attn_map=True` and `ffn_act=True`
    passed to the model.

    Args:
        text_: Text to run through the model.
        model: Model that, when called with `attn_map=True, ffn_act=True`,
            returns `(logits, new_state, attn_maps, ffn_acts)`.
        config: Model configuration; reads 'layers', 'num_heads',
            'input_dims' and 'dtype'.

    Returns:
        `(attn_maps, ffn_acts)` exactly as produced by the model.

    Side effects:
        The model is switched to eval mode for the forward pass and is always
        left in train mode afterwards, even if the forward pass raises.
    """
    text_ = text_.encode('utf-8')

    # Read the dtype from the config that was passed in rather than from a
    # module-level global, so the function is self-contained.
    act_dtype = config['dtype']
    n_layers = config['layers']
    n_heads = config['num_heads']
    width = config['input_dims']
    head_dim = width // n_heads

    # Zero initial state: two per-layer "previous x" buffers in the model
    # dtype, and one float32 (head_dim x head_dim) matrix per layer and head.
    x_prev_0s = mx.zeros([n_layers, 1, 1, width], dtype=act_dtype)
    state_prevs = mx.zeros([n_layers, 1, n_heads, head_dim, head_dim], dtype=mx.float32)
    x_prev_1s = mx.zeros([n_layers, 1, 1, width], dtype=act_dtype)

    txt_btch = mx.array(text_).reshape(1, -1).astype(mx.uint8)

    model.eval()
    try:
        # Logits and the updated state are not needed here.
        _, _, attn_maps, ffn_acts = model(
            txt_btch,
            x_prev_0s=x_prev_0s,
            state_prevs=state_prevs,
            x_prev_1s=x_prev_1s,
            attn_map=True,
            ffn_act=True,
        )
    finally:
        model.train()
    return attn_maps, ffn_acts
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
def to_char(x):
    """Return a printable label for the byte value `x`.

    A value that is a valid single-byte UTF-8 sequence (0-127) is returned as
    the corresponding character. Anything else -- a byte that is not valid
    UTF-8 on its own, an integer outside 0-255, or a non-integer -- falls back
    to `f'{x}'`.
    """
    try:
        return bytes([x]).decode('utf-8')
    except (ValueError, TypeError):
        # ValueError covers UnicodeDecodeError (its subclass) and out-of-range
        # integers; TypeError covers non-integer input. Unlike a bare
        # `except:`, this does not swallow KeyboardInterrupt or SystemExit.
        return f'{x}'
# Text whose attention maps will be plotted.
text_ = '''def to_char(x):
try:
return bytes([x]).decode('utf-8')
except:
return f'{x}'
'''

# First element returned by features() is the attention maps; take absolute
# values and convert to a NumPy array for plotting.
attn_maps = features(text_, model, config)[0]
attn_maps = np.array(attn_maps.abs())

# One printable label per input byte.
text_bytes = list(text_.encode('utf-8'))
text_labels = list(map(to_char, text_bytes))

# Axis 0 is read as layers, axis 4 as heads, axis 2 as sequence length.
n_layers, n_heads, sequence_length = (attn_maps.shape[axis] for axis in (0, 4, 2))

print(f"Detected {n_layers} layers and {n_heads} heads.")
print(f"Attention map shape for each head: ({sequence_length}, {sequence_length})")
# Side length, in pixels, budgeted for each subplot when sizing the figure.
subplot_base_size = 512

# One subplot per (layer, head): layers run down the rows, heads across the columns.
fig = make_subplots(
    rows=n_layers,
    cols=n_heads,
    subplot_titles=[f"L{layer_idx} H{head_idx}" for layer_idx in range(n_layers) for head_idx in range(n_heads)],
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)

x_indices = list(range(sequence_length))
y_indices = list(range(sequence_length))

for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        # (sequence, sequence) slice for this layer/head, batch element 0.
        current_attention_map = attn_maps[layer_idx, 0, :, :, head_idx]
        trace = go.Heatmap(
            z=current_attention_map,
            x=x_indices,
            y=y_indices,
            colorscale='Viridis',
            showscale=False,
            hovertemplate="Source: %{x}<br>Target: %{y}<br>Attention: %{z:.4f}<extra></extra>",
        )
        fig.add_trace(trace, row=layer_idx + 1, col=head_idx + 1)

# Every subplot gets the same axis settings, so apply them to all axes with a
# single call each instead of one row/col-selected call per subplot.
fig.update_xaxes(
    showticklabels=False,
    tickangle=45,
    constrain='domain',
    tickmode='array',
    tickvals=x_indices,
    ticktext=text_labels,
)
# NOTE(review): scaleanchor="x" ties every y-axis to the first subplot's
# x-axis, not to its own -- confirm this coupling is intended.
fig.update_yaxes(
    showticklabels=False,
    scaleanchor="x",
    scaleratio=1,
    constrain='domain',
    tickmode='array',
    tickvals=y_indices,
    ticktext=text_labels,
    autorange="reversed",  # put position 0 at the top
)

# Overall figure size: one base-size square per subplot plus 128 px of slack.
fig_height = n_layers * subplot_base_size + 128
fig_width = n_heads * subplot_base_size + 128
fig.update_layout(
    title_text="Attention Maps per Layer and Head",
    height=fig_height,
    width=fig_width,
    showlegend=False,
    hovermode='closest',
    margin=dict(l=50, r=50, t=80, b=50),
)
fig.show()
Example of RWKV7 attention map (Layers are L0, L2, L4, L6)
Example of Softmax attention map (Layers are L1, L3, L5, L7)
We can already observe interesting patterns just from these heads!
L0 H1 has triangular attention patterns that coincide with the characters in words; perhaps the model has learned to segment words to map character sequences to semantic meanings?
L7 H1 shows a weak induction head, which uses past patterns to guess future tokens. For example, if the sequence 5873 appears in the previous context, and the model sees 587 again, the induction head tells the model to output 3.
Stronger induction heads can be found in L3 H0, L3 H3, L3 H4, L3 H7, L5 H0, L5 H4, L5 H6, L5 H7. (Note that all the induction heads use softmax attention; perhaps RWKV attention is not good enough for sharp, precise long-context recall?)