MoAMetricLM‑100M — Mixture of Attentions (MoA)
A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.
- Parameters: ~185 M (≈ 100 M effective due to the mixture)
- Task: Causal language modeling (decoder‑only)
- Library: 🤗 Transformers
- KV cache: Not yet implemented (generation recomputes the full context at every step)
Model card
| Field | Value |
|---|---|
| Model ID | reaperdoesntknow/MoA-100M |
| Architecture | moa_metric (custom) |
| Tokenizer | GPT‑2 (gpt2) – pad_token set to eos_token |
| Context length | 2048 tokens |
| Training data | 2 × ≈ 256 k tokens from the datasets listed above |
| Training compute | CPU‑only (Intel), FP32 |
| Training hyper‑parameters | LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens |
| Final loss | ≈ 0.30 (train) |
| License | Apache‑2.0 |
| Safety | No alignment or safety fine‑tuning – outputs may be biased or inaccurate. |
| Intended use | Research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models. |
| Limitations | • No KV‑cache → slower generation. • Small token budget → not a general‑purpose LM. • No safety/alignment training. |
| Out‑of‑scope | High‑stakes applications (medical, legal, etc.) without further evaluation. |
Overview
MoA replaces the classic dot‑product attention with metric‑based attention and blends four distinct heads per Transformer block:
| Head type | Description |
|---|---|
| LocalConvHead | Depthwise‑separable 1‑D convolution → captures short‑range context. |
| Metric Multi‑Head Attention (MetricMHAttention) | Soft‑min over L2 / cosine / diagonal‑Mahalanobis distances: \(\text{attn}_{h}(i,j) \propto \exp\!\big(-\alpha_h \lVert q_i - k_j \rVert^2\big)\) |
| Metric MQA | Multi‑Query attention (shared K/V) in the same metric space – cheaper than full MHA. |
| ChannelMixHead | Per‑token MLP that mixes channel dimensions (no positional mixing). |
A token‑wise router decides, for each token, which head(s) to use and applies feature‑gates (FiLM‑style) and router‑bias gates for up/down‑scaling.
The FFN is a HyperFFN – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a branch router. LayerScale and optional DropPath keep training stable.
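The score in the table above can be illustrated with a minimal single‑head sketch (L2 distance, softmax acting as the soft‑min). This is not the repository's implementation; tensor names and shapes are only illustrative:

```python
import torch

def metric_attention_weights(q, k, alpha=1.0):
    """attn(i, j) ∝ exp(-alpha * ||q_i - k_j||^2) for one head.

    q: (seq_q, dim), k: (seq_k, dim).
    """
    # Squared pairwise L2 distances via ||q||^2 - 2 q·k + ||k||^2
    d2 = (
        (q * q).sum(-1, keepdim=True)        # (seq_q, 1)
        - 2.0 * (q @ k.transpose(-1, -2))    # (seq_q, seq_k)
        + (k * k).sum(-1).unsqueeze(-2)      # (1, seq_k)
    )
    # Softmax over negative scaled distances == soft-min over distances
    return torch.softmax(-alpha * d2, dim=-1)

q, k = torch.randn(4, 64), torch.randn(6, 64)
weights = metric_attention_weights(q, k)     # (4, 6); each row sums to 1
```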
Regularisation (optional)
- Triangle‑inequality (TI) penalty on sampled triples to encourage true‑metric behaviour.
- Ball pruning – each head learns an origin (o_h) and radius (r_h); keys outside the ball are masked, giving structured sparsity. Both are sketched below.
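A minimal sketch of both ideas, assuming an L2 base distance; `origin`, `radius`, and `dist_fn` stand in for whatever the heads actually learn:

```python
import torch

def ball_pruned_weights(q, k, origin, radius, alpha=1.0):
    """Mask keys outside a learned ball around `origin`, then apply the metric softmax."""
    d2 = torch.cdist(q, k) ** 2                       # (seq_q, seq_k) squared L2 distances
    outside = (k - origin).norm(dim=-1) > radius      # (seq_k,) keys to prune
    scores = (-alpha * d2).masked_fill(outside.unsqueeze(0), float("-inf"))
    # Assumes at least one key per row stays inside the ball (otherwise the row is all -inf).
    return torch.softmax(scores, dim=-1)

def ti_penalty(x, dist_fn, num_triples=256):
    """Hinge penalty relu(d(a, c) - d(a, b) - d(b, c)) on sampled triples.

    A plain L2 distance satisfies the triangle inequality exactly, so this penalty
    only bites when a learned distance drifts away from true-metric behaviour.
    """
    idx = torch.randint(0, x.size(0), (num_triples, 3))
    a, b, c = x[idx[:, 0]], x[idx[:, 1]], x[idx[:, 2]]
    return torch.relu(dist_fn(a, c) - dist_fn(a, b) - dist_fn(b, c)).mean()
```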
Architecture diagram (high‑level)
Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output

Each Block contains:
- An attention mixture (the router decides per token):
  - LocalConvHead
  - MetricMHAttention
  - MetricMQA
  - ChannelMixHead
- HyperFFN (SwiGLU | Conv | Low‑rank), chosen by the branch router
- LayerScale + DropPath
Configuration (example)
{
"model_type": "moa_metric",
"vocab_size": 50257,
"dim": 768,
"num_layers": 12,
"attn_heads": 8,
"mqa_q_heads": 8,
"mixer_hidden": 3072,
"ffn_hidden": 3072,
"metric": "l2", // "l2" | "cosine" | "maha_diag"
"alpha_init": 1.0,
"learn_alpha": true,
"use_balls": true,
"radius_init": 3.0,
"learn_radius": true,
"origin_init_scale": 0.0,
"maha_init": 1.0,
"ti_reg_weight": 0.0,
"ti_reg_samples": 0,
"router_hidden": 128,
"router_dropout": 0.1,
"router_temperature": 1.0,
"attn_drop": 0.1,
"proj_drop": 0.1,
"drop_path": 0.0,
"max_position_embeddings": 2048,
"pad_token_id": 50256,
"bos_token_id": 50256,
"eos_token_id": 50256
}
Tip: If you use the GPT‑2 tokenizer, set `pad_token = eos_token` and make sure `vocab_size` matches the tokenizer (50257).
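A quick consistency check for the tip above, mirroring the loading style used in the quick‑start (a sketch; the assertions are only illustrative):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("reaperdoesntknow/MoA-100M")
assert model.config.vocab_size == len(tokenizer)            # 50257
assert model.config.pad_token_id == tokenizer.pad_token_id  # 50256
```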
Quick‑start (inference)
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token # needed for the GPT‑2 tokenizer
>>> model = AutoModelForCausalLM.from_pretrained(model_id)
>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=False,   # deterministic; set do_sample=True (with temperature/top_p) to sample
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
Note: Because KV‑cache is not implemented, generation time grows linearly with the total context length.
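For stochastic outputs, sampling has to be switched on explicitly; the snippet below uses standard `generate` arguments with illustrative values:

```python
output_ids = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,      # temperature/top_p only take effect when sampling is enabled
    temperature=0.8,
    top_p=0.95,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```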
Training (custom loop sketch)
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
# from datasets import load_dataset   # for the dataset line below
import torch
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
def collate_fn(examples):
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padding positions in the loss
    batch["labels"] = labels
    return batch
# dataset = load_dataset(..., split="train") # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
model = AutoModelForCausalLM.from_pretrained("reaperdoesntknow/MoA-100M")
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)
model.train()
for batch in loader:  # `loader` comes from the commented-out setup above
    out = model(**batch)
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.2)
    optimizer.step()
    optimizer.zero_grad()
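The documented run used a constant LR of 5e‑4; if you prefer warmup plus cosine decay instead, the scheduler helper shipped with 🤗 Transformers can be dropped in (a sketch; the step counts are illustrative):

```python
from transformers import get_cosine_schedule_with_warmup

num_training_steps = 1_000   # illustrative; use len(loader) * num_epochs in practice
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_training_steps // 10,
    num_training_steps=num_training_steps,
)
# then call scheduler.step() right after optimizer.step() in the loop above
```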
Evaluation checklist
- Perplexity on a held‑out split of the two training datasets (see the sketch after this checklist).
- Ablation studies (keep total token budget constant):
- L2 vs. cosine vs. diagonal‑Mahalanobis distance.
- With / without ball pruning.
- With / without HyperFFN branch router.
- With / without TI regulariser.
- Speed / memory comparison against a vanilla GPT‑2‑size model (same dim/layers).
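For the perplexity item, a minimal evaluation sketch (assumes a held‑out loader built with the same `collate_fn` as the training sketch above):

```python
import math
import torch

@torch.no_grad()
def evaluate_perplexity(model, eval_loader):
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in eval_loader:
        total_loss += model(**batch).loss.item()
        num_batches += 1
    # Mean of per-batch losses; a slight approximation when batches contain
    # different numbers of non-padding tokens.
    return math.exp(total_loss / num_batches)
```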
Efficiency notes
| Feature | What it does |
|---|---|
| Ball pruning | Masks keys that lie outside a learned radius → reduces the quadratic attention cost. |
| Metric MQA | Shares K/V across heads → fewer projection matrices, lower FLOPs. |
| HyperFFN branch router | Token‑wise top‑k routing means only the most useful branch is evaluated per token. |
| CPU tips | Set OMP_NUM_THREADS / MKL_NUM_THREADS to the number of physical cores; use torch.set_num_threads() if needed. |
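A minimal CPU setup sketch for the last row (thread counts are illustrative; ideally export the environment variables in the shell before launching Python):

```python
import os
import torch

os.environ.setdefault("OMP_NUM_THREADS", "8")   # set to the number of physical cores
os.environ.setdefault("MKL_NUM_THREADS", "8")
torch.set_num_threads(8)                        # intra-op parallelism in PyTorch
```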
Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference.
Safety, Bias & Risks
- The model has not been fine‑tuned for safety or alignment.
- Outputs may contain biases, profanity, or factual errors.
- Do not deploy in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning.
License
Apache‑2.0 – see the LICENSE file in the repository.
Citation
@misc{moametriclm185m,
title = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
author = {Colca, Roy Shawn and collaborators},
year = {2025},
url = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}
Changelog
| Version | Date | Notes |
|---|---|---|
| v0.2 | 2025‑09‑20 | 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30. |
| v0.1 | 2025‑09‑20 | Initial public release: metric heads, MQA, ball pruning, HyperFFN, router & gates; HF‑compatible; no KV cache. |
Maintainers
- Author: reaper (Convergent Intelligence LLC)
- Contact: Email ([email protected])
Special Remarks
- This model is still in an extremely experimental state (as most of them are), but I'm working on stabilizing this one for general inference.
- I design, create, and train all of my models using my mathematical research and a pure disgust for the dot product!
- For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!