# Tiny Stories Subword LLM: A 3 MB Efficient Language Model
This is a tiny, efficient, and fast autoregressive language model trained on 10,000 TinyStories, using a custom Selective Recurrent Layer (SRL), a linear-complexity alternative to Transformers, and a 5k-token SentencePiece Unigram tokenizer trained on the same dataset.

Unlike Transformers, which scale as O(n²), this model runs in O(n) time and memory, making it well suited to edge devices, mobile apps, and other low-resource environments while still generating coherent, story-like text.
- Size: ~10 MB
- Vocabulary: 5,000 subword tokens (SentencePiece; training sketch below)
- Architecture: 2-layer SRL with 256-dim hidden states
- Loss: ~2.42 after 20 epochs
- Speed: ~0.5 s per 100-token generation on CPU
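The tokenizer is a plain SentencePiece Unigram model trained on the same stories. Below is a minimal sketch of how a comparable 5k-vocab tokenizer can be trained; the `stories.txt` filename is a placeholder, not the actual training script used for this model.

```python
import sentencepiece as spm

# "stories.txt" is a placeholder: one training story per line.
spm.SentencePieceTrainer.train(
    input="stories.txt",
    model_prefix="tokenizer",  # writes tokenizer.model / tokenizer.vocab
    vocab_size=5000,
    model_type="unigram",
)

sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
print(sp.encode("once upon a time", out_type=str))
```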
## ✨ Usage Example
```python
import json

import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn
from flax import serialization
import numpy as np
import sentencepiece as spm

# Load model config
with open("config.json") as f:
    config = json.load(f)

# Load SentencePiece tokenizer
tokenizer = spm.SentencePieceProcessor(model_file="tokenizer.model")
# Define model architecture (same as training)
class SelectiveRecurrentLayer(nn.Module):
    d_model: int
    d_state: int = 16
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        x = x.astype(self.dtype)
        B, L, D = x.shape

        # Learnable decay rate A and input-dependent step size delta
        A_log = self.param("A_log", nn.initializers.zeros, (D,))
        A = -jnp.exp(A_log.astype(self.dtype))
        delta = nn.Dense(D, dtype=self.dtype)(x)
        delta = jax.nn.softplus(delta)

        # Input-dependent (selective) B and C projections
        B_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        C_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)

        # Discretize: A_bar lies in (0, 1), B_bar is scaled accordingly
        A_bar = jnp.exp(A * delta)
        inv_A = 1.0 / (-A)
        B_exp = B_ssm[:, :, :, None]   # (B, L, d_state, 1)
        A_exp = A_bar[:, :, None, :]   # (B, L, 1, D)
        x_exp = x[:, :, None, :]       # (B, L, 1, D)
        C_exp = C_ssm[:, :, :, None]   # (B, L, d_state, 1)
        B_bar = B_exp * ((1 - A_exp) * inv_A)

        # Scan over the time axis: linear in sequence length
        inputs = (A_exp, B_bar, x_exp, C_exp)
        inputs = jax.tree.map(lambda t: t.transpose(1, 0, 2, 3), inputs)

        def ssm_op(carry, inp):
            A_curr, B_curr, x_curr, C_curr = inp
            state = carry
            state = A_curr * state + B_curr * x_curr
            y = jnp.sum(C_curr * state, axis=1)
            return state, y

        init_state = jnp.zeros((B, self.d_state, D), dtype=self.dtype)
        _, y_seq = lax.scan(ssm_op, init_state, inputs)
        return y_seq.transpose(1, 0, 2)  # back to (B, L, D)
class SubwordLLM(nn.Module):
    vocab_size: int
    d_model: int = 256
    n_layers: int = 2
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input_ids):
        x = nn.Embed(self.vocab_size, self.d_model, dtype=jnp.float32)(input_ids)
        x = x.astype(self.dtype)
        for _ in range(self.n_layers):
            x = SelectiveRecurrentLayer(d_model=self.d_model, dtype=self.dtype)(x)
            x = nn.LayerNorm(dtype=self.dtype)(x)
        return nn.Dense(self.vocab_size, dtype=self.dtype)(x)
# Load weights
model = SubwordLLM(
    vocab_size=config["vocab_size"],
    d_model=config["d_model"],
    n_layers=config["n_layers"],
    dtype=jnp.dtype(config["dtype"]),
)
with open("flax_model.msgpack", "rb") as f:
    params = serialization.from_bytes(
        model.init(jax.random.key(0), jnp.ones((1, 128), dtype=jnp.int32)),
        f.read(),
    )
# Generation function
def generate(prompt, max_new_tokens=150, temperature=0.7, repetition_penalty=1.2, top_k=25):
    ids = tokenizer.encode(prompt)
    ids = [i for i in ids if i not in (tokenizer.pad_id(), tokenizer.eos_id())]
    generated = ids.copy()
    input_ids = jnp.array([generated], dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = model.apply(params, input_ids)
        next_token_logits = logits[0, -1, :]
        # Penalize tokens that already appeared: divide positive logits and
        # multiply negative ones, so the penalty always lowers their probability
        for tok in set(generated):
            logit = next_token_logits[tok]
            next_token_logits = next_token_logits.at[tok].set(
                jnp.where(logit > 0, logit / repetition_penalty, logit * repetition_penalty)
            )
        # Keep only the top-k logits for sampling
        if top_k > 0:
            top_k_vals, top_k_idx = jax.lax.top_k(next_token_logits, min(top_k, len(next_token_logits)))
            mask = jnp.full_like(next_token_logits, -1e10)
            mask = mask.at[top_k_idx].set(top_k_vals)
            next_token_logits = mask
        next_token_logits /= temperature
        key = jax.random.key(np.random.randint(0, 2**31 - 1))
        next_token = int(jax.random.categorical(key, next_token_logits))
        if next_token == tokenizer.eos_id():
            break
        generated.append(next_token)
        input_ids = jnp.array([generated], dtype=jnp.int32)
    return tokenizer.decode(generated)
# Generate!
print(generate("once upon a time"))
```
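To sanity-check the ~0.5 s per 100 tokens figure quoted above on your own hardware, here is a minimal timing sketch that reuses the `generate` function from the usage example. Note that generation may stop early at the EOS token, which shortens the measured run.

```python
import time

# Warm-up call so one-time tracing/dispatch overhead is mostly excluded.
generate("once upon a time", max_new_tokens=5)

start = time.perf_counter()
generate("once upon a time", max_new_tokens=100)
print(f"100 new tokens in {time.perf_counter() - start:.2f}s")
```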
### Sample Output

once upon a time, there was a little girl named Lily. She loved to play in the park. One day, she found a shiny rock. She showed it to her mom, who smiled and said, "That's magic!" Lily put it in her pocket and ran home. That night, the rock glowed under her pillow. She dreamed of dragons and stars, and woke up with a new friend beside her.
## Model Architecture

- No attention! Uses a Selective State Space Model (SSM) with linear complexity.
- Input: subword tokens (SentencePiece Unigram, 5k vocab)
- Hidden layers: 2 × SelectiveRecurrentLayer (256-dim)
- Memory: O(n), not O(n²), ideal for long contexts
- Training: 10k TinyStories, 20 epochs, batch size 32
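With the model and `params` from the usage example above in scope, a quick sketch for counting parameters; multiplying the count by 4 bytes per float32 weight gives the approximate checkpoint size.

```python
import jax

# Sum the sizes of every weight array in the loaded checkpoint.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"{n_params:,} parameters (~{n_params * 4 / 1e6:.1f} MB at float32)")
```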
## Training Details

| Item | Value |
|---|---|
| Dataset | roneneldan/TinyStories (50,000 samples) |
| Tokenizer | SentencePiece Unigram (vocab=5000) |
| Epochs | 50 |
| Loss | ~2.42 |
| Max Length | 128 tokens |
| Optimizer | AdamW + Cosine Decay |
| Hardware | T4 GPU (×2) |
| Training Time | ~15 minutes |
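The AdamW + cosine-decay setup in the table corresponds to a standard optax configuration. A minimal sketch follows, with placeholder hyperparameters, since the actual learning rate and step count are not listed above.

```python
import optax

# Placeholder values; the real run's learning rate and schedule length are not given.
schedule = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)
opt_state = optimizer.init(params)  # params as initialized in the usage example
```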