CALM-TinyStories-v1
Continuous Autoregressive Language Model (CALM) trained by Mohamed Yasser (@yasserrmd)
Reference Paper
Continuous Autoregressive Language Models (CALM)
Weizhe Huang, Zhengxuan Wu, Qian Liu, Amanpreet Singh, et al.
arXiv preprint arXiv:2510.27688 (2025)
This model is a faithful implementation inspired by the CALM framework:
generation happens in continuous latent space, predicting the next latent vector instead of the next discrete token,
which allows smoother interpolation and faster convergence than traditional autoregressive LLMs.
Model Overview
| Component | Description |
|---|---|
| ChunkAutoencoder | Encodes fixed-size token chunks (K=4) into latent vectors and reconstructs them. |
| CALM Transformer | Learns to predict the next latent vector in sequence. |
| Energy Head (optional) | Computes latent energy scores for contrastive fine-tuning or uncertainty modeling. |
The CALM-TinyStories-v1 model learns continuous sequence dynamics; it was trained on the roneneldan/TinyStories dataset with the GPT-2 tokenizer.
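To make the chunking concrete: with K = 4, a T-token prompt maps to ceil(T / 4) latent vectors of dimension 256, each autoregressive step adds one more latent, and decoding expands every latent back into 4 token positions. A tiny arithmetic sketch (plain Python, no model weights required):

K = 4                      # ae_chunk_size_k
latent_dim = 256           # ae_latent_dim
T = 10                     # prompt length in tokens
T_lat = -(-T // K)         # ceil division: 10 tokens -> 3 latent positions (2 padded slots)
new_latents = 32           # latent steps generated by CALM
decoded_positions = (T_lat + new_latents) * K
print(T_lat, decoded_positions)   # 3 latents for the prompt, 140 token positions after decoding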
Training Details (Full 1 → 41K Steps)
| Stage | Task | Samples | Steps | Final Loss | Hardware |
|---|---|---|---|---|---|
| Autoencoder (AE) | Reconstruction of token sequences | 100K | 6K | ≈ 0.0085 | NVIDIA L4 (24 GB VRAM) |
| CALM (Latent Model) | Continuous next-latent prediction | 100K | 1 → 41K | 0.0003 – 0.01 | NVIDIA L4 (24 GB VRAM) |
Training Summary
The model was trained end-to-end from step 1 to 41,000, combining latent autoencoding and autoregressive latent forecasting, with memory optimizations keeping usage under 18 GB of VRAM.
Peak memory stayed at ≈ 11.28 GB, with steady convergence and no instability across runs.
Excerpt from training logs:
CALM Training (Full 1 → 41K Steps)
CALM step 100/40000 | loss 0.0015 | elapsed 0.1 m
CALM step 500/40000 | loss 0.0003 | elapsed 0.6 m
CALM step 2000/40000 | loss 0.0009 | elapsed 2.3 m
CALM step 5000/40000 | loss 0.0044 | elapsed 5.6 m
CALM step 9000/40000 | loss 0.0003 | elapsed 10.1 m
CALM step 11000/40000 | loss 0.0014 | elapsed 12.4 m
Saved final checkpoint at calm_41000.pt
Configuration (config.json)
{
"model_type": "CALM",
"ae_hidden_dim": 512,
"ae_latent_dim": 256,
"ae_chunk_size_k": 4,
"lm_hidden_dim": 768,
"lm_ffn_dim": 3072,
"lm_num_layers": 12,
"lm_num_heads": 12,
"vocab_size": 50257,
"energy_head_dim": 128,
"energy_num_blocks": 2
}
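Note that the ae_hidden_dim entry above is not consumed by the code in this card: ae_model.py below reads ae_embed_dim, ae_enc_hidden, and ae_dec_hidden, falling back to built-in defaults (512, 1024, 1024) when those keys are absent. A quick way to inspect the effective dimensions, assuming config.json sits in the working directory:

import json

with open("config.json") as f:
    cfg = json.load(f)

# Keys consumed by TextAE, together with the defaults used when a key is missing.
print("embed_dim :", cfg.get("ae_embed_dim", 512))
print("latent_dim:", cfg.get("ae_latent_dim", 256))
print("chunk K   :", cfg.get("ae_chunk_size_k", 4))
print("enc_hidden:", cfg.get("ae_enc_hidden", 1024))
print("dec_hidden:", cfg.get("ae_dec_hidden", 1024))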
Model Architecture (ae_model.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
import json, os
class TextAE(nn.Module):
"""
Lightweight text AutoEncoder used at inference:
- encode(tokens) -> latent sequence z [B, T_latent, D]
- decode(z) -> logits over vocab for each recovered token position
This matches the interfaces used during training: chunk K tokens into one latent.
"""
def __init__(self, config_path="config.json"):
super().__init__()
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config not found: {config_path}")
with open(config_path) as f:
cfg = json.load(f)
self.vocab_size = cfg.get("vocab_size", 50257)
self.embed_dim = cfg.get("ae_embed_dim", 512)
self.latent_dim = cfg.get("ae_latent_dim", 256)
self.chunk_k = cfg.get("ae_chunk_size_k", 4)
self.enc_hidden = cfg.get("ae_enc_hidden", 1024)
self.dec_hidden = cfg.get("ae_dec_hidden", 1024)
# Token embedding for encoder side
self.tok_embed = nn.Embedding(self.vocab_size, self.embed_dim)
# Encoder: pooled token embeddings -> latent z
self.encoder = nn.Sequential(
nn.Linear(self.embed_dim, self.enc_hidden),
nn.GELU(),
nn.Linear(self.enc_hidden, self.latent_dim),
)
# Decoder: latent z -> per-token logits for K tokens
# We predict K token logits per latent by sharing a projector then a per-slot head
self.dec_proj = nn.Sequential(
nn.Linear(self.latent_dim, self.dec_hidden),
nn.GELU(),
)
# one classifier per position in the chunk
self.slot_heads = nn.ModuleList([
nn.Linear(self.dec_hidden, self.vocab_size) for _ in range(self.chunk_k)
])
def _chunk_tokens(self, tok_emb: torch.Tensor):
"""
tok_emb: [B, T, E]
returns [B, T_latent, K, E] where T_latent = ceil(T / K), right-padded if needed
"""
B, T, E = tok_emb.shape
K = self.chunk_k
pad = (K - (T % K)) % K
if pad > 0:
pad_emb = torch.zeros(B, pad, E, device=tok_emb.device, dtype=tok_emb.dtype)
tok_emb = torch.cat([tok_emb, pad_emb], dim=1)
T = T + pad
tok_emb = tok_emb.view(B, T // K, K, E)
return tok_emb, pad
def encode(self, input_ids: torch.Tensor):
"""
input_ids: [B, T]
returns:
z: [B, T_latent, D]
meta: dict with padding info
"""
emb = self.tok_embed(input_ids) # [B, T, E]
chunked, pad = self._chunk_tokens(emb) # [B, T_lat, K, E]
pooled = chunked.mean(dim=2) # [B, T_lat, E]
z = self.encoder(pooled) # [B, T_lat, D]
return z, {"pad_tokens": pad}
def decode(self, z: torch.Tensor, meta=None):
"""
z: [B, T_latent, D]
returns:
logits: [B, T_rec, V] where T_rec = T_latent * K
"""
B, T_lat, D = z.shape
h = self.dec_proj(z) # [B, T_lat, H]
# produce K distributions per latent step
slot_logits = []
for head in self.slot_heads:
slot_logits.append(head(h)) # each [B, T_lat, V]
# interleave along the time axis: [B, T_lat, K, V] -> [B, T_lat*K, V]
stacked = torch.stack(slot_logits, dim=2) # [B, T_lat, K, V]
logits = stacked.reshape(B, T_lat * len(self.slot_heads), -1)
return logits
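A minimal round-trip of the interfaces above, run with untrained weights just to confirm shapes, assuming config.json is available in the working directory:

import torch
from ae_model import TextAE

ae = TextAE(config_path="config.json").eval()
ids = torch.randint(0, ae.vocab_size, (1, 10))  # 10 tokens -> right-padded to 12 with K = 4
with torch.no_grad():
    z, meta = ae.encode(ids)    # z: [1, 3, 256], meta["pad_tokens"] == 2
    logits = ae.decode(z)       # [1, 12, 50257]
print(z.shape, logits.shape, meta)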
Model Architecture (calm_model.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
import json, os
class CALMBlock(nn.Module):
"""Transformer block for CALM latent prediction."""
def __init__(self, hidden_dim, num_heads, ffn_dim, dropout=0.1):
super().__init__()
self.ln1 = nn.LayerNorm(hidden_dim)
self.attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
self.ln2 = nn.LayerNorm(hidden_dim)
self.ff = nn.Sequential(
nn.Linear(hidden_dim, ffn_dim),
nn.GELU(),
nn.Linear(ffn_dim, hidden_dim),
)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
attn_out, _ = self.attn(self.ln1(x), self.ln1(x), self.ln1(x))
x = x + self.dropout(attn_out)
ff_out = self.ff(self.ln2(x))
x = x + self.dropout(ff_out)
return x
class CALM(nn.Module):
"""Continuous Autoregressive Language Model (CALM)."""
def __init__(self, config_path="config.json"):
super().__init__()
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(config_path) as f:
cfg = json.load(f)
self.latent_dim = cfg.get("ae_latent_dim", 256)
self.hidden_dim = cfg.get("lm_hidden_dim", 768)
self.chunk_k = cfg.get("ae_chunk_size_k", 4)
self.vocab_size = cfg.get("vocab_size", 50257)
self.ffn_dim = cfg.get("lm_ffn_dim", 3072)
self.layers = cfg.get("lm_num_layers", 12)
self.heads = cfg.get("lm_num_heads", 12)
self.energy_dim = cfg.get("energy_head_dim", 128)
self.energy_blocks = cfg.get("energy_num_blocks", 2)
self.latent_proj = nn.Linear(self.latent_dim, self.hidden_dim)
self.blocks = nn.ModuleList([
CALMBlock(self.hidden_dim, self.heads, self.ffn_dim)
for _ in range(self.layers)
])
self.ln_final = nn.LayerNorm(self.hidden_dim)
self.out_proj = nn.Linear(self.hidden_dim, self.latent_dim)
self.energy_head = nn.Sequential(
nn.Linear(self.latent_dim, self.energy_dim),
nn.ReLU(),
*[
nn.Sequential(nn.Linear(self.energy_dim, self.energy_dim), nn.ReLU())
for _ in range(max(self.energy_blocks - 1, 0))
],
nn.Linear(self.energy_dim, 1)
)
def forward(self, z_seq):
x = self.latent_proj(z_seq)
for blk in self.blocks:
x = blk(x)
x = self.ln_final(x)
z_pred = self.out_proj(x[:, -1])
return z_pred
def energy(self, z):
return self.energy_head(z).mean()
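And a matching smoke test for the latent predictor and the optional energy head, again with random weights and inputs, shapes only:

import torch
from calm_model import CALM

calm = CALM(config_path="config.json").eval()
z_seq = torch.randn(1, 3, 256)     # three past latents of dimension ae_latent_dim = 256
with torch.no_grad():
    z_next = calm(z_seq)           # [1, 256]: predicted next latent vector
    energy = calm.energy(z_next)   # scalar energy score from the optional head
print(z_next.shape, float(energy))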
Inference
import os, json, torch
from huggingface_hub import snapshot_download
from transformers import GPT2TokenizerFast
from calm_model import CALM
from ae_model import TextAE
@torch.no_grad()
def calm_generate_latents(calm, z_past, steps=32, temperature=0.0):
"""
Autoregress in latent space.
z_past: [B, T_lat, D]
returns: [B, T_lat+steps, D]
"""
device = next(calm.parameters()).device
z_seq = z_past.clone()
for _ in range(steps):
z_next = calm(z_seq) # [B, D]
if temperature and temperature > 0:
# optional Gaussian noise in latent space
z_next = z_next + torch.randn_like(z_next) * temperature
z_seq = torch.cat([z_seq, z_next.unsqueeze(1)], dim=1)
return z_seq
@torch.no_grad()
def ae_decode_tokens(ae, logits, pad_tokens=0, trim_to_multiple_of_k=True):
"""
Convert AE decoder logits to token ids and drop right padding added during encode.
logits: [B, T_rec, V]
"""
token_ids = logits.argmax(dim=-1) # [B, T_rec]
if trim_to_multiple_of_k and pad_tokens > 0:
token_ids = token_ids[:, :-pad_tokens]
return token_ids
def main():
repo_id = "yasserrmd/CALM-TinyStories-v1"
local_dir = snapshot_download(repo_id=repo_id)
# Load config
cfg_path = os.path.join(local_dir, "config.json")
with open(cfg_path) as f:
cfg = json.load(f)
chunk_k = cfg.get("ae_chunk_size_k", 4)
vocab = cfg.get("vocab_size", 50257)
# Tokenizer (GPT2-compatible by default)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# add pad token if missing
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.vocab_size != vocab:
        print("Warning: tokenizer vocab size differs from config.json; make sure it matches the tokenizer used in training.")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load AE
ae = TextAE(config_path=cfg_path).to(device)
ae_ckpt = torch.load(os.path.join(local_dir, "ae_final.pt"), map_location=device)
ae.load_state_dict(ae_ckpt, strict=True)
ae.eval()
# Load CALM
calm = CALM(config_path=cfg_path).to(device)
calm_ckpt = torch.load(os.path.join(local_dir, "calm_final.pt"), map_location=device)
calm.load_state_dict(calm_ckpt, strict=True)
calm.eval()
print("β Loaded AE and CALM")
# Example prompt -> latents
prompt = "Once upon a time"
enc = tokenizer(prompt, return_tensors="pt")
input_ids = enc["input_ids"].to(device) # [1, T]
# AE encode to latents
z_init, meta = ae.encode(input_ids) # [1, T_lat, D], meta['pad_tokens']
# Autoregress in latent space
z_full = calm_generate_latents(calm, z_init, steps=32, temperature=0.0) # extend by 32 latent steps
# Decode latents back to tokens
dec_logits = ae.decode(z_full, meta) # [1, T_rec, V] where T_rec = T_lat_total * K
out_ids = ae_decode_tokens(ae, dec_logits, pad_tokens=meta.get("pad_tokens", 0))
text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
print("\n=== Generated Text ===\n", text)
if __name__ == "__main__":
main()
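Two behaviors of this script are worth noting: temperature > 0 in calm_generate_latents adds Gaussian noise to each predicted latent and is the only source of randomness in the pipeline (temperature=0.0 makes generation fully deterministic), and ae.decode runs over the entire latent sequence, so the decoded text contains a reconstruction of the prompt followed by the newly generated continuation.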
Performance Summary
- Training range: 1 → 41K steps (AE + CALM)
- Final AE reconstruction loss: 0.0085
- Final CALM latent prediction loss: 0.0003 – 0.01
- Dataset: roneneldan/TinyStories
- Tokenizer: GPT-2 (AutoTokenizer.from_pretrained("gpt2"))
- GPU usage: 11.28 GB peak
- Runtime: ~12 minutes
License
MIT License © 2025 Mohamed Yasser
When citing this repository, please credit both:
- Mohamed Yasser: CALM-TinyStories-v1 implementation
- Original paper: Continuous Autoregressive Language Models (CALM), arXiv:2510.27688 (2025)