🌟 Tiny Stories Subword LLM: A 3 MB Efficient Language Model

This is a tiny, fast autoregressive language model trained on 10,000 TinyStories. It pairs a custom Selective Recurrent Layer (SRL), a linear-complexity alternative to Transformer attention, with a 5k-token SentencePiece Unigram tokenizer trained on the same dataset.

Unlike Transformer attention, which scales as O(n²) in sequence length, this model runs in O(n) time and memory, making it well suited to edge devices, mobile apps, and other low-resource environments while still generating coherent, story-like text.
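
The tokenizer is a standalone SentencePiece Unigram model with a 5,000-token vocabulary. As a rough sketch of how such a tokenizer can be produced (the input file name and the omitted optional settings are placeholders, not the exact command used for this model):

import sentencepiece as spm

# Train a 5k-vocab Unigram tokenizer on one story per line (hypothetical file name).
spm.SentencePieceTrainer.train(
    input="tinystories.txt",
    model_prefix="tokenizer",   # writes tokenizer.model and tokenizer.vocab
    vocab_size=5000,
    model_type="unigram",
)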

✅ Size: ~10 MB
✅ Vocabulary: 5,000 subword tokens (SentencePiece)
✅ Architecture: 2-layer SRL with 256-dim hidden states
✅ Loss: ~2.42 after 20 epochs
✅ Speed: ~0.5 s per 100-token generation on CPU


✨ Usage Example

import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn
from flax import serialization
import sentencepiece as spm
import numpy as np
import json

# Load model config
with open("config.json") as f:
    config = json.load(f)

# Load SentencePiece tokenizer
tokenizer = spm.SentencePieceProcessor(model_file="tokenizer.model")
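
# Optional sanity check: the Unigram tokenizer should round-trip ordinary text
# (the example string here is arbitrary).
example_ids = tokenizer.encode("once upon a time")
print(example_ids, "->", tokenizer.decode(example_ids))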

# Define model architecture (same as training)
class SelectiveRecurrentLayer(nn.Module):
    """Selective SSM block: input-dependent gates applied through a linear-time scan."""
    d_model: int
    d_state: int = 16
    dtype: jnp.dtype = jnp.float32
    @nn.compact
    def __call__(self, x):
        x = x.astype(self.dtype)
        B, L, D = x.shape
        # Learned decay rate, kept negative via -exp so the state never blows up.
        A_log = self.param("A_log", nn.initializers.zeros, (D,))
        A = -jnp.exp(A_log.astype(self.dtype))
        # Input-dependent step size and state-space projections (the "selective" part).
        delta = nn.Dense(D, dtype=self.dtype)(x)
        delta = jax.nn.softplus(delta)
        B_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        C_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        # Discretization: A_bar = exp(A * delta), B_bar = B * (1 - A_bar) / (-A).
        A_bar = jnp.exp(A * delta)
        inv_A = 1.0 / (-A)
        B_exp = B_ssm[:, :, :, None]   # (B, L, d_state, 1)
        A_exp = A_bar[:, :, None, :]   # (B, L, 1, D)
        x_exp = x[:, :, None, :]       # (B, L, 1, D)
        C_exp = C_ssm[:, :, :, None]   # (B, L, d_state, 1)
        B_bar = B_exp * ((1 - A_exp) * inv_A)
        # Put the time axis first so lax.scan iterates over tokens.
        inputs = (A_exp, B_bar, x_exp, C_exp)
        inputs = jax.tree.map(lambda t: t.transpose(1, 0, 2, 3), inputs)
        def ssm_op(carry, inp):
            A_curr, B_curr, x_curr, C_curr = inp
            state = carry
            state = A_curr * state + B_curr * x_curr   # recurrent state update
            y = jnp.sum(C_curr * state, axis=1)        # readout -> (B, D)
            return state, y
        init_state = jnp.zeros((B, self.d_state, D), dtype=self.dtype)
        _, y_seq = lax.scan(ssm_op, init_state, inputs)  # O(L) time, constant state size
        return y_seq.transpose(1, 0, 2)                  # back to (B, L, D)

class SubwordLLM(nn.Module):
    """Token embedding -> n_layers x (SRL + LayerNorm) -> vocabulary logits."""
    vocab_size: int
    d_model: int = 256
    n_layers: int = 2
    dtype: jnp.dtype = jnp.float32
    @nn.compact
    def __call__(self, input_ids):
        x = nn.Embed(self.vocab_size, self.d_model, dtype=jnp.float32)(input_ids)
        x = x.astype(self.dtype)
        for _ in range(self.n_layers):
            x = SelectiveRecurrentLayer(d_model=self.d_model, dtype=self.dtype)(x)
            x = nn.LayerNorm(dtype=self.dtype)(x)
        return nn.Dense(self.vocab_size, dtype=self.dtype)(x)  # per-token logits

# Rebuild the model from config.json and load the trained weights
model = SubwordLLM(
    vocab_size=config["vocab_size"],
    d_model=config["d_model"],
    n_layers=config["n_layers"],
    dtype=jnp.dtype(config["dtype"])
)

with open("flax_model.msgpack", "rb") as f:
    params = serialization.from_bytes(
        model.init(jax.random.key(0), jnp.ones((1, 128), dtype=jnp.int32)),
        f.read()
    )
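
# Optional sanity check: parameter count and output shape. The exact numbers
# depend on config.json, so nothing specific is assumed here.
n_params = sum(p.size for p in jax.tree.leaves(params))
print(f"Parameters: {n_params:,}")
print("Logits shape:", model.apply(params, jnp.ones((1, 8), dtype=jnp.int32)).shape)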

# Generation: plain autoregressive sampling that re-runs the full forward pass
# over the growing sequence at every step.
def generate(prompt, max_new_tokens=150, temperature=0.7, repetition_penalty=1.2, top_k=25):
    ids = tokenizer.encode(prompt)
    ids = [i for i in ids if i not in (tokenizer.pad_id(), tokenizer.eos_id())]
    generated = ids.copy()
    input_ids = jnp.array([generated], dtype=jnp.int32)

    for _ in range(max_new_tokens):
        logits = model.apply(params, input_ids)
        next_token_logits = logits[0, -1, :]

        # Repetition penalty: always reduce the probability of tokens already
        # generated (scale negative logits up, positive logits down).
        for tok in set(generated):
            logit = next_token_logits[tok]
            next_token_logits = next_token_logits.at[tok].set(
                jnp.where(logit < 0, logit * repetition_penalty, logit / repetition_penalty)
            )

        # Top-k filtering: mask everything outside the k most likely tokens.
        if top_k > 0:
            top_k_vals, top_k_idx = jax.lax.top_k(next_token_logits, min(top_k, len(next_token_logits)))
            mask = jnp.full_like(next_token_logits, -1e10)
            next_token_logits = mask.at[top_k_idx].set(top_k_vals)

        # Temperature scaling, then sample from the resulting distribution.
        next_token_logits /= temperature
        key = jax.random.key(np.random.randint(0, 2**31 - 1))
        next_token = int(jax.random.categorical(key, next_token_logits))

        if next_token == tokenizer.eos_id():
            break
        generated.append(next_token)
        input_ids = jnp.array([generated], dtype=jnp.int32)

    return tokenizer.decode(generated)

# Generate!
print(generate("once upon a time"))
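
Each call to generate() recomputes the forward pass over the whole sequence at every step. A possible speed-up, sketched here rather than taken from the released script, is to JIT-compile the forward pass; note that jax.jit retraces whenever the input length changes, so padding inputs to a fixed length avoids repeated compilation.

# Hypothetical optimization: compile the forward pass (retraces per new sequence length).
jit_apply = jax.jit(model.apply)
example_ids = jnp.array([tokenizer.encode("once upon a time")], dtype=jnp.int32)
logits = jit_apply(params, example_ids)   # shape: (1, seq_len, vocab_size)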

📝 Sample Output:

once upon a time, there was a little girl named Lily. She loved to play in the park. One day, she found a shiny rock. She showed it to her mom, who smiled and said, “That’s magic!” Lily put it in her pocket and ran home. That night, the rock glowed under her pillow. She dreamed of dragons and stars, and woke up with a new friend beside her.

πŸ—οΈ Model Architecture

  • No attention! Uses a Selective State Space Model (SSM) recurrence with linear complexity (see the sketch after this list).
  • Input: subword tokens (SentencePiece Unigram, 5k vocab)
  • Hidden layers: 2 × SelectiveRecurrentLayer (256-dim)
  • Memory: O(n), not O(n²), which keeps long contexts practical
  • Training: 10k TinyStories, 20 epochs, batch size 32
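
A minimal, self-contained sketch of the per-token recurrence that each SelectiveRecurrentLayer runs inside lax.scan; the shapes and random inputs are illustrative only and are not taken from the trained checkpoint.

import jax.numpy as jnp
import numpy as np

seq_len, d_model, d_state = 8, 256, 16
rng = np.random.default_rng(0)
x     = jnp.asarray(rng.normal(size=(seq_len, d_model)))               # layer inputs
A_bar = jnp.asarray(rng.uniform(0.0, 1.0, size=(seq_len, d_model)))    # per-step decay
B_bar = jnp.asarray(rng.normal(size=(seq_len, d_state, d_model)))      # input projection
C     = jnp.asarray(rng.normal(size=(seq_len, d_state)))               # readout weights

state = jnp.zeros((d_state, d_model))     # fixed-size state: constant memory per token
outputs = []
for t in range(seq_len):                  # one update per token: linear total time
    state = A_bar[t][None, :] * state + B_bar[t] * x[t][None, :]
    outputs.append((C[t][:, None] * state).sum(axis=0))  # (d_model,) output for token t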

📚 Training Details

Dataset: roneneldan/TinyStories (50,000 samples)
Tokenizer: SentencePiece Unigram (vocab = 5,000)
Epochs: 50
Final loss: ~2.42
Max sequence length: 128
Optimizer: AdamW with cosine decay
Hardware: 2× T4 GPUs
Training time: ~15 minutes
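
A sketch of the optimizer setup listed above (AdamW with a cosine-decay schedule) using optax. The learning rate, warmup, decay horizon, and weight decay below are placeholder assumptions, not the values actually used in training.

import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # assumed warmup start
    peak_value=3e-4,      # assumed peak learning rate
    warmup_steps=200,     # assumed
    decay_steps=10_000,   # assumed total schedule length
)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)  # weight decay assumed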
