MoAMetricLM‑100M — Mixture of Attentions (MoA)

A geometry‑aware Transformer that mixes several attention mechanisms and routes them with a metric‑based router.

  • Parameters: ~185 M (≈ 100 M effective due to the mixture)
  • Task: Causal language modeling (decoder‑only)
  • Library: 🤗 Transformers
  • KV cache: Not yet implemented (generation recomputes the full context at every step)

Model card

  • Model ID: reaperdoesntknow/MoA-100M
  • Architecture: moa_metric (custom)
  • Tokenizer: GPT‑2 (gpt2), with pad_token set to eos_token
  • Context length: 2048 tokens
  • Training data: 2 × ≈ 256 k tokens from two datasets (listed on the model's Hugging Face Hub page)
  • Training compute: CPU‑only (Intel), FP32
  • Training hyper‑parameters: LR = 5e‑4 (AdamW), batch = 4, seq ≤ 512, 500 k total tokens
  • Final loss: ≈ 0.30 (train)
  • License: Apache‑2.0
  • Safety: no alignment or safety fine‑tuning; outputs may be biased or inaccurate.
  • Intended use: research on geometry‑aware attention, structured sparsity, and mixture‑of‑attention models.
  • Limitations:
    • No KV‑cache, so generation is slower.
    • Small token budget, so it is not a general‑purpose LM.
    • No safety/alignment training.
  • Out‑of‑scope: high‑stakes applications (medical, legal, etc.) without further evaluation.

Overview

MoA replaces the classic dot‑product attention with metric‑based attention and blends four distinct heads per Transformer block:

  • LocalConvHead: depthwise‑separable 1‑D convolution that captures short‑range context.
  • Metric Multi‑Head Attention (MetricMHAttention): soft‑min over L2, cosine, or diagonal‑Mahalanobis distances. For the L2 metric, with a per‑head scale α_h:
    $$\text{attn}_h(i,j) \propto \exp\!\big(-\alpha_h \lVert q_i - k_j \rVert^2\big)$$
  • Metric MQA: multi‑query attention (shared K/V) in the same metric space; cheaper than full MHA.
  • ChannelMixHead: per‑token MLP that mixes channel dimensions (no positional mixing).
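The metric heads above can be sketched in NumPy as follows. This is a minimal single‑head illustration under stated assumptions, not the repository's code: the function names are hypothetical, cosine distance is taken as 1 − cosine similarity, the diagonal Mahalanobis form uses a positive per‑channel weight vector `w`, and a causal mask is applied because the model is decoder‑only.

```python
import numpy as np

def pairwise_distance(q, k, metric="l2", w=None):
    """Distances d[i, j] between queries q (T, d) and keys k (T, d)."""
    if metric == "l2":
        # squared Euclidean distance ||q_i - k_j||^2
        diff = q[:, None, :] - k[None, :, :]
        return (diff ** 2).sum(-1)
    if metric == "cosine":
        # assumed form: 1 - cosine similarity, so identical directions give 0
        qn = q / np.linalg.norm(q, axis=-1, keepdims=True)
        kn = k / np.linalg.norm(k, axis=-1, keepdims=True)
        return 1.0 - qn @ kn.T
    if metric == "maha_diag":
        # assumed form: sum_c w_c * (q_ic - k_jc)^2 with learned w_c > 0
        diff = q[:, None, :] - k[None, :, :]
        return (w * diff ** 2).sum(-1)
    raise ValueError(f"unknown metric: {metric}")

def metric_attention(q, k, v, alpha=1.0, metric="l2", w=None):
    """Causal soft-min attention: weights proportional to exp(-alpha * d(q_i, k_j)), j <= i."""
    T = q.shape[0]
    logits = -alpha * pairwise_distance(q, k, metric, w)
    causal = np.tril(np.ones((T, T), dtype=bool))
    logits = np.where(causal, logits, -np.inf)        # hide future keys
    logits -= logits.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)    # normalise over keys
    return weights @ v, weights
```

Metric MQA would reuse one shared `k`, `v` pair for every query head instead of projecting separate keys and values per head.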

A token‑wise router decides, for each token, which head(s) to use and applies feature‑gates (FiLM‑style) and router‑bias gates for up/down‑scaling.
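A minimal NumPy sketch of token‑wise routing over the four head outputs, assuming a single linear router scored with a softmax at temperature `router_temperature` and an optional FiLM‑style scale/shift gate. Per the configuration, the real router also has a hidden layer (`router_hidden`) and dropout (`router_dropout`), which this sketch omits; `route_heads`, `W_router`, `gamma` and `beta` are hypothetical names.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def route_heads(x, head_outputs, W_router, temperature=1.0, gamma=None, beta=None):
    """
    x:            (T, d) token features fed to the router
    head_outputs: (H, T, d) outputs of the H heads (e.g. conv, MHA, MQA, channel-mix)
    W_router:     (d, H) hypothetical linear router weights
    Returns the per-token mixture (T, d) and the routing weights (T, H).
    """
    probs = softmax(x @ W_router / temperature, axis=-1)   # (T, H), rows sum to 1
    mixed = np.einsum("th,htd->td", probs, head_outputs)   # per-token weighted sum of heads
    if gamma is not None:
        # FiLM-style feature gate: per-channel scale and optional shift
        mixed = gamma * mixed + (0.0 if beta is None else beta)
    return mixed, probs
```

Lowering the temperature pushes each token's routing weights toward a single head; raising it spreads them more evenly.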

The FFN is a HyperFFN – three parallel branches (SwiGLU MLP, separable‑conv, low‑rank) combined by a branch router. LayerScale and optional DropPath keep training stable.
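The HyperFFN data flow can be sketched as below, assuming a linear branch router with top‑k selection and a per‑channel LayerScale vector on the residual branch. The branch functions are placeholders for the SwiGLU, separable‑conv and low‑rank branches, DropPath is omitted (it only acts during training), and all names are hypothetical.

```python
import numpy as np

def hyper_ffn_block(x, branches, W_branch, layer_scale, top_k=1):
    """
    x:           (T, d) input tokens
    branches:    list of B callables, each mapping (T, d) -> (T, d)
    W_branch:    (d, B) hypothetical branch-router weights
    layer_scale: (d,) per-channel LayerScale vector (typically initialised small)
    top_k:       number of branches kept per token (ties may keep more)
    """
    logits = x @ W_branch                                   # (T, B)
    kth = np.sort(logits, axis=-1)[:, -top_k][:, None]      # k-th largest logit per token
    logits = np.where(logits >= kth, logits, -np.inf)       # drop non-selected branches
    logits -= logits.max(axis=-1, keepdims=True)
    gates = np.exp(logits)
    gates /= gates.sum(axis=-1, keepdims=True)              # renormalise over kept branches
    # every branch is computed here for clarity; a sparse implementation
    # would evaluate only the selected branches for each token
    outs = np.stack([f(x) for f in branches], axis=1)       # (T, B, d)
    ffn = (gates[:, :, None] * outs).sum(axis=1)            # (T, d)
    return x + layer_scale * ffn                            # LayerScale on the residual branch
```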

Regularisation (optional)

  • Triangle‑inequality (TI) penalty on sampled triples to encourage true‑metric behaviour.
  • Ball pruning – each head learns an origin (o_h) and radius (r_h); keys outside the ball are masked, giving structured sparsity.
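Both regularisers can be sketched in a few lines of NumPy. The triangle‑inequality penalty below is a hinge on d(a, c) − d(a, b) − d(b, c) averaged over given triples, and the ball mask keeps keys with ‖k_j − o_h‖ ≤ r_h; the exact penalty form and the function names are assumptions, not the repository's implementation.

```python
import numpy as np

def triangle_penalty(points, triples, dist):
    """Mean hinge violation of d(a, c) <= d(a, b) + d(b, c) over index triples (a, b, c)."""
    violations = [
        max(dist(points[a], points[c]) - dist(points[a], points[b]) - dist(points[b], points[c]), 0.0)
        for a, b, c in triples
    ]  # only violations contribute; satisfied triples give 0
    return sum(violations) / len(violations)

def ball_mask(keys, origin, radius):
    """True for keys inside the ball ||k_j - o_h|| <= r_h; the rest would be masked out."""
    return np.linalg.norm(keys - origin, axis=-1) <= radius
```

With squared L2 distance the penalty is generally non‑zero (for the points 0, 1, 2 on a line it is 4 − 1 − 1 = 2), whereas plain Euclidean distance always satisfies the inequality. In training, triples could be drawn with `rng.integers(0, n, size=(ti_reg_samples, 3))` and the penalty added as `ti_reg_weight * penalty` (both are 0 in the example configuration, which disables it). To apply the ball mask, logits of masked keys would be set to −inf before the soft‑min; each query still needs at least one key inside the ball, otherwise the normalisation is undefined.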

Architecture diagram (high‑level)

Input → Embedding → (PreNorm) → Block₁ → … → Blockₙ → LM‑Head → Output
                     │
                     ├─ LocalConvHead
                     ├─ MetricMHAttention
                     ├─ MetricMQA
                     └─ ChannelMixHead
                     (router decides per‑token)

Each Block also contains:
  → HyperFFN (SwiGLU | Conv | Low‑rank)  ← branch router
  → LayerScale + DropPath

Configuration (example)

The "metric" field selects the distance used by the metric heads: "l2", "cosine", or "maha_diag".

{
  "model_type": "moa_metric",
  "vocab_size": 50257,
  "dim": 768,
  "num_layers": 12,
  "attn_heads": 8,
  "mqa_q_heads": 8,
  "mixer_hidden": 3072,
  "ffn_hidden": 3072,
  "metric": "l2",
  "alpha_init": 1.0,
  "learn_alpha": true,
  "use_balls": true,
  "radius_init": 3.0,
  "learn_radius": true,
  "origin_init_scale": 0.0,
  "maha_init": 1.0,
  "ti_reg_weight": 0.0,
  "ti_reg_samples": 0,
  "router_hidden": 128,
  "router_dropout": 0.1,
  "router_temperature": 1.0,
  "attn_drop": 0.1,
  "proj_drop": 0.1,
  "drop_path": 0.0,
  "max_position_embeddings": 2048,
  "pad_token_id": 50256,
  "bos_token_id": 50256,
  "eos_token_id": 50256
}

Tip: If you use the GPT‑2 tokenizer, set pad_token = eos_token and make sure vocab_size matches the tokenizer (50257).
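These constraints can be checked before building a model. This stdlib‑only sketch uses a hypothetical helper `check_moa_config`; the requirement that `dim` be divisible by the head counts is an assumption based on the usual head‑splitting convention, not something this card states.

```python
import json

ALLOWED_METRICS = {"l2", "cosine", "maha_diag"}

def check_moa_config(cfg, tokenizer_vocab_size=50257):
    """Return a list of problems found in a moa_metric config dict (empty if none)."""
    problems = []
    if cfg.get("metric") not in ALLOWED_METRICS:
        problems.append(f"metric must be one of {sorted(ALLOWED_METRICS)}")
    if cfg.get("vocab_size") != tokenizer_vocab_size:
        problems.append("vocab_size does not match the tokenizer")
    if cfg.get("pad_token_id") != cfg.get("eos_token_id"):
        problems.append("pad_token_id should equal eos_token_id for the GPT-2 tokenizer")
    for key in ("attn_heads", "mqa_q_heads"):
        # assumption: heads split dim evenly
        if cfg["dim"] % cfg[key] != 0:
            problems.append(f"dim must be divisible by {key}")
    return problems

example = json.loads("""{"dim": 768, "attn_heads": 8, "mqa_q_heads": 8, "metric": "l2",
                         "vocab_size": 50257, "pad_token_id": 50256, "eos_token_id": 50256}""")
print(check_moa_config(example))  # prints []
```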


Quick‑start (inference)

>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "reaperdoesntknow/MoA-100M"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> tokenizer.pad_token = tokenizer.eos_token   # needed for the GPT‑2 tokenizer

>>> model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)  # allow the repo's custom moa_metric code

>>> prompt = "Explain metric‑based attention in simple terms:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> output_ids = model.generate(
...     **inputs,
...     max_new_tokens=128,
...     do_sample=False,          # greedy decoding; pass do_sample=True (optionally with temperature) to sample
...     pad_token_id=tokenizer.pad_token_id,
... )
>>> print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

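Note: Because KV‑cache is not implemented, every decoding step re‑processes the full context, so each new token costs more than the last and total generation time grows faster than linearly with the number of generated tokens.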
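The effect of the missing cache can be estimated by counting how many token positions pass through the network. A small sketch, assuming greedy decoding of `new_tokens` tokens (at least one) from a prompt of `prompt_len` tokens; `tokens_processed` is a hypothetical helper.

```python
def tokens_processed(prompt_len: int, new_tokens: int, kv_cache: bool) -> int:
    """Token positions fed through the network to generate `new_tokens` tokens."""
    if kv_cache:
        # step 0 encodes the whole prompt; each later step feeds only the newest token
        return prompt_len + (new_tokens - 1)
    # without a cache, step t re-encodes the prompt plus the t tokens generated so far
    return sum(prompt_len + t for t in range(new_tokens))

print(tokens_processed(16, 128, kv_cache=True))   # 143
print(tokens_processed(16, 128, kv_cache=False))  # 10176
```

For a 16‑token prompt and max_new_tokens=128 (as in the quick‑start), that is 143 positions with a cache versus 10,176 without one, and each uncached step also pays an attention cost that is quadratic in its current length.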


Training (custom loop sketch)

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

def collate_fn(examples):
    batch = tokenizer(
        [ex["text"] for ex in examples],
        padding=True,          # pad to the longest sequence in the batch
        truncation=True,
        max_length=512,        # seq ≤ 512, as in the reported run
        return_tensors="pt",
    )
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # exclude padding from the loss
    batch["labels"] = labels
    return batch

model = AutoModelForCausalLM.from_pretrained(
    "reaperdoesntknow/MoA-100M",
    trust_remote_code=True,  # custom moa_metric architecture
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-4,
    betas=(0.9, 0.95),
    weight_decay=0.01,
)

def train_one_epoch(model, loader, optimizer, max_grad_norm=1.2):
    model.train()  # enable dropout / DropPath
    for batch in loader:
        out = model(**batch)  # loss computed from batch["labels"] (HF causal-LM convention)
        out.loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

# from datasets import load_dataset
# dataset = load_dataset(..., split="train")  # must contain a 'text' field
# loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
# train_one_epoch(model, loader, optimizer)

Evaluation checklist

  • Perplexity on a held‑out split of the two training datasets.
  • Ablation studies (keep total token budget constant):
    • L2 vs. cosine vs. diagonal‑Mahalanobis distance.
    • With / without ball pruning.
    • With / without HyperFFN branch router.
    • With / without TI regulariser.
  • Speed / memory comparison against a vanilla GPT‑2‑size model (same dim/layers).
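For the perplexity item, a minimal sketch that aggregates per‑batch mean losses into a corpus perplexity. It assumes each batch loss is the mean natural‑log cross‑entropy over that batch's predicted tokens (the Hugging Face causal‑LM convention, where labels equal to -100 are ignored), so batches are weighted by their token counts; `corpus_perplexity` is a hypothetical helper.

```python
import math

def corpus_perplexity(batch_losses, batch_token_counts):
    """
    Perplexity from per-batch mean cross-entropy losses (natural log),
    weighted by the number of predicted tokens in each batch.
    """
    total_nll = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    total_tokens = sum(batch_token_counts)
    return math.exp(total_nll / total_tokens)
```

For reference, the reported training loss of ≈ 0.30 corresponds to a training perplexity of exp(0.30) ≈ 1.35.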

Efficiency notes

  • Ball pruning: masks keys that lie outside a learned radius, giving structured sparsity that can reduce the quadratic attention cost when masked keys are skipped rather than merely zeroed.
  • Metric MQA: shares K/V across heads, so there are fewer projection matrices and lower FLOPs.
  • HyperFFN branch router: token‑wise top‑k routing, so only the selected branch(es) need to be evaluated for each token.
  • CPU tips: set OMP_NUM_THREADS / MKL_NUM_THREADS to the number of physical cores; use torch.set_num_threads() if needed.
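The Metric MQA saving can be quantified with simple arithmetic. This sketch assumes standard multi‑query attention, where all query heads share a single K/V head of size `dim / mqa_q_heads`, and counts only the K and V projection weights (no biases) per layer for the example configuration (dim = 768, 8 query heads); `kv_projection_params` is a hypothetical helper.

```python
def kv_projection_params(dim: int, q_heads: int, shared_kv: bool) -> int:
    """Weights in the K and V projections of one layer (biases ignored)."""
    head_dim = dim // q_heads
    kv_heads = 1 if shared_kv else q_heads   # MQA: one K/V head shared by all query heads
    return 2 * dim * kv_heads * head_dim     # K and V each map dim -> kv_heads * head_dim

mha = kv_projection_params(768, 8, shared_kv=False)
mqa = kv_projection_params(768, 8, shared_kv=True)
print(mha, mqa, mha // mqa)  # 1179648 147456 8
```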

Future roadmap: metric‑aware KV‑cache, kernelised distance approximations (e.g., Random Fourier Features), quantisation & mixed‑precision inference.
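For the kernelised‑distance item: the L2 soft‑min numerator exp(−α‖q − k‖²) is a Gaussian kernel, so Random Fourier Features can approximate it as a dot product of feature maps. A minimal NumPy sketch, assuming the L2 metric and a single shared scale `alpha`; `make_rff_map` is a hypothetical helper, not part of the model.

```python
import numpy as np

def make_rff_map(dim, alpha, num_features, rng):
    """Random Fourier Features z with z(x) . z(y) approximating exp(-alpha * ||x - y||^2)."""
    # exp(-alpha * r^2) is a Gaussian kernel with 1 / (2 * sigma^2) = alpha,
    # whose spectral density is N(0, 2 * alpha * I)
    W = rng.normal(scale=np.sqrt(2.0 * alpha), size=(dim, num_features))
    b = rng.uniform(0.0, 2.0 * np.pi, size=num_features)

    def z(x):
        # x: (n, dim) -> (n, num_features)
        return np.sqrt(2.0 / num_features) * np.cos(x @ W + b)

    return z

rng = np.random.default_rng(0)
z = make_rff_map(dim=4, alpha=0.5, num_features=20000, rng=rng)
q = rng.normal(scale=0.5, size=(5, 4))
k = rng.normal(scale=0.5, size=(5, 4))
exact = np.exp(-0.5 * ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1))
approx = z(q) @ z(k).T
print(np.abs(exact - approx).max())  # small approximation error
```

Because exp(−α‖q_i − k_j‖²) ≈ z(q_i)·z(k_j), attention could then be computed in linear‑attention style; RFF estimates can be slightly negative, which a practical version would have to handle.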


Safety, Bias & Risks

  • The model has not been fine‑tuned for safety or alignment.
  • Outputs may contain biases, profanity, or factual errors.
  • Do not deploy in high‑stakes contexts without additional evaluation, moderation, and possibly further fine‑tuning.

License

Apache‑2.0 – see the LICENSE file in the repository.


Citation

@misc{moametriclm185m,
  title   = {reaperdoesntknow/MoA-100M: A Geometry-Aware Mixture-of-Attentions Language Model},
  author  = {Colca, Roy Shawn and collaborators},
  year    = {2025},
  url     = {https://huggingface.co/reaperdoesntknow/MoA-100M}
}

Changelog

  • v0.2 (2025‑09‑20): 500 k‑token CPU run, GPT‑2 tokenizer, LR = 5e‑4, final loss ≈ 0.30.
  • v0.1 (2025‑09‑20): initial public release with metric heads, MQA, ball pruning, HyperFFN, router and gates; HF‑compatible; no KV cache.

Maintainers


Special Remarks

  • This model is still in an extremely experimental state (as are most of my models), but I'm working on stabilizing this one for general inference.
  • I design, create, and train all of my models using my mathematical research and a pure disgust for the dot product!
  • For those of you who actually read this and use my models: you make my day every time I see another download, so thank you for being awesome!