CALM-TinyStories-v1

Continuous Autoregressive Language Model (CALM), trained by Mohamed Yasser (@yasserrmd)

Reference Paper

Continuous Autoregressive Language Models (CALM)
Weizhe Huang, Zhengxuan Wu, Qian Liu, Amanpreet Singh, et al.
arXiv preprint arXiv:2510.27688 (2025)

This model is an implementation inspired by the CALM framework,
in which generation happens in continuous latent space: the model predicts the next latent vector instead of discrete tokens,
allowing smoother interpolation and faster convergence than traditional autoregressive LLMs.
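
In practice this means the autoregressive loop operates on latent vectors rather than token ids. A minimal conceptual sketch (illustrative only, with a made-up function name; the full runnable version is in the Inference section below):

import torch

def generate_sketch(ae, calm, prompt_ids, new_latents=8):
    # 1) compress each chunk of K tokens into one continuous latent vector
    z, _ = ae.encode(prompt_ids)                 # [B, T_lat, D]
    # 2) autoregress in latent space: each step emits one vector, i.e. K tokens' worth
    for _ in range(new_latents):
        z_next = calm(z)                         # [B, D]
        z = torch.cat([z, z_next[:, None, :]], dim=1)
    # 3) decode every latent back into K token logits and read off token ids
    return ae.decode(z).argmax(dim=-1)           # [B, T_lat_total * K]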


Model Overview

| Component | Description |
|---|---|
| ChunkAutoencoder | Encodes fixed-size token chunks (K=4) into latent vectors and reconstructs them. |
| CALM Transformer | Learns to predict the next latent vector in the sequence. |
| Energy Head (optional) | Computes latent energy scores for contrastive fine-tuning or uncertainty modeling. |

CALM-TinyStories-v1 learns these continuous sequence dynamics from the roneneldan/TinyStories dataset, tokenized with the GPT-2 tokenizer.
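
With the defaults used here (chunk size K=4, latent dimension 256, GPT-2 vocabulary of 50,257 tokens), the shapes work out as in this small sketch (the example sentence is arbitrary):

import math
from transformers import GPT2TokenizerFast

K, D, V = 4, 256, 50257                          # from config.json
tok = GPT2TokenizerFast.from_pretrained("gpt2")

ids = tok("Once upon a time there was a tiny robot", return_tensors="pt")["input_ids"]
T = ids.shape[1]                                 # number of prompt tokens
T_lat = math.ceil(T / K)                         # latents after chunking (right-padded)
print(f"tokens [1, {T}] -> latents [1, {T_lat}, {D}] -> decoded logits [1, {T_lat * K}, {V}]")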


Training Details (Full 1 → 41K Steps)

| Stage | Task | Samples | Steps | Final Loss | Hardware |
|---|---|---|---|---|---|
| Autoencoder (AE) | Reconstruction of token sequences | 100K | 6K | ↓ 0.0085 | NVIDIA L4 (24 GB VRAM) |
| CALM (Latent Model) | Continuous next-latent prediction | 100K | 1 → 41K | 0.0003 – 0.01 | NVIDIA L4 (24 GB VRAM) |

Training Summary

The model was trained end-to-end from step 1 to 41,000, combining latent autoencoding and autoregressive latent forecasting, with the run configured to fit within an 18 GB VRAM budget.
Peak memory stayed at ≈ 11.28 GB, with steady convergence and no instability across runs.
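
For reference, below is a minimal sketch of what such a two-stage loop could look like, assuming a cross-entropy token-reconstruction objective for the AE stage and an MSE next-latent objective for the CALM stage. The exact objectives behind the reported losses are not documented in this card, and the function names are illustrative.

import torch.nn.functional as F

def ae_step(ae, input_ids, opt):
    # assumption: the AE is trained to reconstruct its input tokens with cross-entropy
    z, meta = ae.encode(input_ids)                        # [B, T_lat, D]
    logits = ae.decode(z, meta)                           # [B, T_lat * K, V]
    T = input_ids.shape[1]
    loss = F.cross_entropy(logits[:, :T].reshape(-1, logits.size(-1)),
                           input_ids.reshape(-1))         # ignore the right-padding slots
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

def calm_step(calm, z_seq, opt):
    # assumption: teacher-forced next-latent regression on precomputed AE latents
    z_ctx, z_tgt = z_seq[:, :-1], z_seq[:, -1]            # [B, T_lat-1, D], [B, D]
    loss = F.mse_loss(calm(z_ctx), z_tgt)
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()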

Excerpt from training logs:


🚀 CALM Training (Full 1 → 41K Steps)
CALM step 100/40000  | loss 0.0015  | elapsed 0.1 m
CALM step 500/40000  | loss 0.0003  | elapsed 0.6 m
CALM step 2000/40000 | loss 0.0009  | elapsed 2.3 m
CALM step 5000/40000 | loss 0.0044  | elapsed 5.6 m
CALM step 9000/40000 | loss 0.0003  | elapsed 10.1 m
CALM step 11000/40000| loss 0.0014  | elapsed 12.4 m
✓ Saved final checkpoint at calm_41000.pt

Configuration (config.json)

{
  "model_type": "CALM",
  "ae_hidden_dim": 512,
  "ae_latent_dim": 256,
  "ae_chunk_size_k": 4,
  "lm_hidden_dim": 768,
  "lm_ffn_dim": 3072,
  "lm_num_layers": 12,
  "lm_num_heads": 12,
  "vocab_size": 50257,
  "energy_head_dim": 128,
  "energy_num_blocks": 2
}
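
Note that the model code below also reads a few keys that are not present in this file (e.g. ae_embed_dim, ae_enc_hidden, ae_dec_hidden) and falls back to the defaults hard-coded in the classes, while ae_hidden_dim from the config is not read by either class as shipped. A quick sketch of how the values resolve:

import json

with open("config.json") as f:
    cfg = json.load(f)

# keys present in config.json win; missing keys fall back to the code defaults
resolved = {
    "ae_embed_dim":    cfg.get("ae_embed_dim", 512),    # not in config -> default 512
    "ae_latent_dim":   cfg.get("ae_latent_dim", 256),
    "ae_chunk_size_k": cfg.get("ae_chunk_size_k", 4),
    "ae_enc_hidden":   cfg.get("ae_enc_hidden", 1024),  # not in config -> default 1024
    "lm_hidden_dim":   cfg.get("lm_hidden_dim", 768),
    "lm_num_layers":   cfg.get("lm_num_layers", 12),
}
print(resolved)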

Model Architecture (ae_model.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
import json, os

class TextAE(nn.Module):
    """
    Lightweight text AutoEncoder used at inference:
    - encode(tokens) -> latent sequence z [B, T_latent, D]
    - decode(z) -> logits over vocab for each recovered token position
    This matches the interfaces used during training: chunk K tokens into one latent.
    """
    def __init__(self, config_path="config.json"):
        super().__init__()
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")
        with open(config_path) as f:
            cfg = json.load(f)

        self.vocab_size   = cfg.get("vocab_size", 50257)
        self.embed_dim    = cfg.get("ae_embed_dim", 512)
        self.latent_dim   = cfg.get("ae_latent_dim", 256)
        self.chunk_k      = cfg.get("ae_chunk_size_k", 4)
        self.enc_hidden   = cfg.get("ae_enc_hidden", 1024)
        self.dec_hidden   = cfg.get("ae_dec_hidden", 1024)

        # Token embedding for encoder side
        self.tok_embed = nn.Embedding(self.vocab_size, self.embed_dim)

        # Encoder: pooled token embeddings -> latent z
        self.encoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.enc_hidden),
            nn.GELU(),
            nn.Linear(self.enc_hidden, self.latent_dim),
        )

        # Decoder: latent z -> per-token logits for K tokens
        # We predict K token logits per latent by sharing a projector then a per-slot head
        self.dec_proj = nn.Sequential(
            nn.Linear(self.latent_dim, self.dec_hidden),
            nn.GELU(),
        )
        # one classifier per position in the chunk
        self.slot_heads = nn.ModuleList([
            nn.Linear(self.dec_hidden, self.vocab_size) for _ in range(self.chunk_k)
        ])

    def _chunk_tokens(self, tok_emb: torch.Tensor):
        """
        tok_emb: [B, T, E]
        returns [B, T_latent, K, E] where T_latent = ceil(T / K), right-padded if needed
        """
        B, T, E = tok_emb.shape
        K = self.chunk_k
        pad = (K - (T % K)) % K
        if pad > 0:
            pad_emb = torch.zeros(B, pad, E, device=tok_emb.device, dtype=tok_emb.dtype)
            tok_emb = torch.cat([tok_emb, pad_emb], dim=1)
            T = T + pad
        tok_emb = tok_emb.view(B, T // K, K, E)
        return tok_emb, pad

    def encode(self, input_ids: torch.Tensor):
        """
        input_ids: [B, T]
        returns:
          z: [B, T_latent, D]
          meta: dict with padding info
        """
        emb = self.tok_embed(input_ids)              # [B, T, E]
        chunked, pad = self._chunk_tokens(emb)       # [B, T_lat, K, E]
        pooled = chunked.mean(dim=2)                 # [B, T_lat, E]
        z = self.encoder(pooled)                     # [B, T_lat, D]
        return z, {"pad_tokens": pad}

    def decode(self, z: torch.Tensor, meta=None):
        """
        z: [B, T_latent, D]
        returns:
          logits: [B, T_rec, V] where T_rec = T_latent * K
        """
        B, T_lat, D = z.shape
        h = self.dec_proj(z)                         # [B, T_lat, H]
        # produce K distributions per latent step
        slot_logits = []
        for head in self.slot_heads:
            slot_logits.append(head(h))              # each [B, T_lat, V]
        # interleave along the time axis: [B, T_lat, K, V] -> [B, T_lat*K, V]
        stacked = torch.stack(slot_logits, dim=2)    # [B, T_lat, K, V]
        logits = stacked.reshape(B, T_lat * len(self.slot_heads), -1)
        return logits
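
A quick round-trip sanity check of the autoencoder on its own (illustrative; it assumes config.json and the ae_final.pt checkpoint used in the Inference section are in the working directory):

import torch
from transformers import GPT2TokenizerFast
from ae_model import TextAE

tok = GPT2TokenizerFast.from_pretrained("gpt2")
ae = TextAE("config.json")
ae.load_state_dict(torch.load("ae_final.pt", map_location="cpu"))
ae.eval()

ids = tok("The little dog found a shiny red ball.", return_tensors="pt")["input_ids"]
with torch.no_grad():
    z, meta = ae.encode(ids)                   # [1, ceil(T/K), 256]
    logits = ae.decode(z, meta)                # [1, ceil(T/K) * K, 50257]
rec = logits.argmax(dim=-1)[:, :ids.shape[1]]  # drop the right-padding slots
print("token accuracy:", (rec == ids).float().mean().item())
print(tok.decode(rec[0]))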

Model Architecture (calm_model.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
import json, os


class CALMBlock(nn.Module):
    """Transformer block for CALM latent prediction."""
    def __init__(self, hidden_dim, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ffn_dim),
            nn.GELU(),
            nn.Linear(ffn_dim, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Full (unmasked) self-attention over the latent prefix; CALM.forward below
        # only reads the final position, and generation appends one latent per step.
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h)
        x = x + self.dropout(attn_out)
        ff_out = self.ff(self.ln2(x))
        x = x + self.dropout(ff_out)
        return x


class CALM(nn.Module):
    """Continuous Autoregressive Language Model (CALM)."""
    def __init__(self, config_path="config.json"):
        super().__init__()
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        with open(config_path) as f:
            cfg = json.load(f)

        self.latent_dim = cfg.get("ae_latent_dim", 256)
        self.hidden_dim = cfg.get("lm_hidden_dim", 768)
        self.chunk_k = cfg.get("ae_chunk_size_k", 4)
        self.vocab_size = cfg.get("vocab_size", 50257)
        self.ffn_dim = cfg.get("lm_ffn_dim", 3072)
        self.layers = cfg.get("lm_num_layers", 12)
        self.heads = cfg.get("lm_num_heads", 12)
        self.energy_dim = cfg.get("energy_head_dim", 128)
        self.energy_blocks = cfg.get("energy_num_blocks", 2)

        self.latent_proj = nn.Linear(self.latent_dim, self.hidden_dim)
        self.blocks = nn.ModuleList([
            CALMBlock(self.hidden_dim, self.heads, self.ffn_dim)
            for _ in range(self.layers)
        ])
        self.ln_final = nn.LayerNorm(self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.latent_dim)
        self.energy_head = nn.Sequential(
            nn.Linear(self.latent_dim, self.energy_dim),
            nn.ReLU(),
            *[
                nn.Sequential(nn.Linear(self.energy_dim, self.energy_dim), nn.ReLU())
                for _ in range(max(self.energy_blocks - 1, 0))
            ],
            nn.Linear(self.energy_dim, 1)
        )

    def forward(self, z_seq):
        x = self.latent_proj(z_seq)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_final(x)
        z_pred = self.out_proj(x[:, -1])
        return z_pred

    def energy(self, z):
        return self.energy_head(z).mean()
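
The energy head is not exercised by the inference script below. One way it could be used, sketched here purely as an assumption (not the documented recipe, and with an illustrative function name), is to rerank noisy candidates around the deterministic next-latent prediction:

import torch

@torch.no_grad()
def pick_lowest_energy(calm, z_seq, n_candidates=8, noise=0.05):
    # hypothetical reranking: perturb the predicted latent and keep the candidate the
    # energy head scores lowest (treating lower energy as "preferred" is an assumption)
    z_base = calm(z_seq)                                        # [B, D]
    cands = z_base.unsqueeze(1) + noise * torch.randn(
        z_base.size(0), n_candidates, z_base.size(1), device=z_base.device
    )                                                           # [B, N, D]
    scores = calm.energy_head(cands).squeeze(-1)                # [B, N]
    best = scores.argmin(dim=1)
    return cands[torch.arange(cands.size(0)), best]             # [B, D]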

Inference

import os, json, torch
from huggingface_hub import snapshot_download
from transformers import GPT2TokenizerFast
from calm_model import CALM
from ae_model import TextAE

@torch.no_grad()
def calm_generate_latents(calm, z_past, steps=32, temperature=0.0):
    """
    Autoregress in latent space.
    z_past: [B, T_lat, D]
    returns: [B, T_lat+steps, D]
    """
    device = next(calm.parameters()).device
    z_seq = z_past.to(device)                   # move the context latents onto the model's device
    for _ in range(steps):
        z_next = calm(z_seq)                        # [B, D]
        if temperature and temperature > 0:
            # optional Gaussian noise in latent space
            z_next = z_next + torch.randn_like(z_next) * temperature
        z_seq = torch.cat([z_seq, z_next.unsqueeze(1)], dim=1)
    return z_seq

@torch.no_grad()
def ae_decode_tokens(ae, logits, pad_tokens=0, trim_to_multiple_of_k=True):
    """
    Convert AE decoder logits to token ids and drop right padding added during encode.
    logits: [B, T_rec, V]
    """
    token_ids = logits.argmax(dim=-1)               # [B, T_rec]
    if trim_to_multiple_of_k and pad_tokens > 0:
        token_ids = token_ids[:, :-pad_tokens]
    return token_ids

def main():
    repo_id = "yasserrmd/CALM-TinyStories-v1"
    local_dir = snapshot_download(repo_id=repo_id)

    # Load config
    cfg_path = os.path.join(local_dir, "config.json")
    with open(cfg_path) as f:
        cfg = json.load(f)
    chunk_k   = cfg.get("ae_chunk_size_k", 4)
    vocab     = cfg.get("vocab_size", 50257)

    # Tokenizer (GPT2-compatible by default)
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # add pad token if missing
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.vocab_size != vocab:
        print("Warning: tokenizer vocab size differs from config; ensure it matches training.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load AE
    ae = TextAE(config_path=cfg_path).to(device)
    ae_ckpt = torch.load(os.path.join(local_dir, "ae_final.pt"), map_location=device)
    ae.load_state_dict(ae_ckpt, strict=True)
    ae.eval()

    # Load CALM
    calm = CALM(config_path=cfg_path).to(device)
    calm_ckpt = torch.load(os.path.join(local_dir, "calm_final.pt"), map_location=device)
    calm.load_state_dict(calm_ckpt, strict=True)
    calm.eval()

    print("βœ“ Loaded AE and CALM")

    # Example prompt -> latents
    prompt = "Once upon a time"
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)         # [1, T]

    # AE encode to latents
    z_init, meta = ae.encode(input_ids)             # [1, T_lat, D], meta['pad_tokens']
    # Autoregress in latent space
    z_full = calm_generate_latents(calm, z_init, steps=32, temperature=0.0)  # extend by 32 latent steps
    # Decode latents back to tokens
    dec_logits = ae.decode(z_full, meta)            # [1, T_rec, V] where T_rec = T_lat_total * K
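    # Note: meta["pad_tokens"] describes right-padding added while encoding the prompt;
    # after appending new latents, that padding sits at the end of the prompt's last
    # chunk, so the simple end-trim in ae_decode_tokens is only exact when the prompt
    # length is a multiple of K.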
    out_ids = ae_decode_tokens(ae, dec_logits, pad_tokens=meta.get("pad_tokens", 0))

    text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
    print("\n=== Generated Text ===\n", text)

if __name__ == "__main__":
    main()
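
Note that with temperature=0.0 the latent rollout is deterministic and token decoding is greedy (argmax over the AE decoder's per-slot logits); setting temperature > 0 injects Gaussian noise into each predicted latent, trading determinism for diversity.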

Performance Summary

  • Training range: 1 β†’ 41K steps (AE + CALM)
  • Final AE reconstruction loss: 0.0085
  • Final CALM latent prediction loss: 0.0003 – 0.01
  • Dataset: roneneldan/TinyStories
  • Tokenizer: GPT-2 (AutoTokenizer.from_pretrained("gpt2"))
  • GPU usage: 11.28 GB peak
  • Runtime: ~12 minutes

License

MIT License © 2025 Mohamed Yasser

When citing this repository, please credit both:

  1. Mohamed Yasser – CALM-TinyStories-v1 implementation
  2. Original paper – Continuous Autoregressive Language Models (CALM), arXiv:2510.27688 (2025)
