Tilelli-llm / scripts /train.py
TilelliLab's picture
Mirror small files (code, paper, results)
f86dc09 verified
raw
history blame contribute delete
19.9 kB
#!/usr/bin/env python3
"""scripts/train.py — real Tilelli/Vanilla trainer on TinyStories.
Replaces the smoke ``train_demo.py``. Adds the things a serious run needs:
* train/val split (separate ``.bin`` files produced by ``prepare_tinystories.py``)
* AdamW + cosine LR with warmup
* gradient clipping
* periodic eval-loss against val
* periodic checkpointing + resume from last
* deterministic seed
* a per-run directory under ``runs/`` with config.json + log.jsonl
Models supported via ``--model`` (registered in ``MODEL_CFGS`` below):
* ``tilelli-fp32`` — TilelliLM with quantize=False (architecture, FP32 weights)
* ``tilelli-ternary`` — TilelliLM with quantize=True (the default Tilelli model)
* ``vanilla-fp32`` — pre-norm Transformer baseline at the same param budget
* ``tilelli-lite-fp32`` / ``tilelli-lite-ternary`` — TilelliLiteLM with FP32 /
  ternary weights (same architecture as the deployed v4 chat checkpoint)
The first three are param-matched at ~10 M each.
"""
from __future__ import annotations

import argparse
import contextlib
import json
import math
import os
import random
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Iterator
# Allow running directly without `pip install -e .`
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import numpy as np
import torch
from torch import Tensor
from tilelli.baselines.vanilla import VanillaLM
from tilelli.core.tilelli_lm import TilelliLM
def _make_tilelli_lite(cfg, max_seq_len):
    """Construct a ``TilelliLiteLM`` from a ``ModelCfg``-like object.

    The Lite model class is imported lazily, only when this builder is used.
    Falsy config values (0 / None) fall back to the Lite defaults: 8 heads,
    top-k 16 and an FFN expansion of 4. The vocabulary is fixed at 256.
    """
    from tilelli.core.tilelli_lite import TilelliLiteLM

    # ``or`` maps both a missing attribute's default and an explicit 0 to 8.
    heads = getattr(cfg, "n_heads", 8) or 8
    kwargs = dict(
        vocab_size=256,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=heads,
        top_k=cfg.top_k or 16,
        ffn_expand=cfg.dense_expand or 4,
        max_seq_len=max_seq_len,
        quantize=cfg.quantize,
    )
    return TilelliLiteLM(**kwargs)
# ---------------------------------------------------------------------- #
# Model configs — registered variants (the first three are param-matched at ~10M)
# ---------------------------------------------------------------------- #
@dataclass
class ModelCfg:
    """Hyper-parameters of one named model variant.

    ``build_model`` dispatches on ``builder``; each builder reads only the
    fields marked for it below and ignores the rest.
    """

    name: str  # registry key in MODEL_CFGS, i.e. the --model choice
    builder: str  # "tilelli" | "vanilla" | "tilelli_lite"
    quantize: bool  # read by "tilelli" and "tilelli_lite"; "vanilla" ignores it
    d_model: int  # all builders
    n_layers: int  # all builders
    d_head: int  # "tilelli" only
    top_k: int  # "tilelli"; "tilelli_lite" (0 falls back to 16)
    n_heads: int  # "vanilla"; "tilelli_lite" (0 falls back to 8)
    expand: int  # "vanilla" only
    n_banks: int = 1  # "tilelli" only
    per_row: bool = False  # "tilelli" only
    hadamard: bool = False  # "tilelli" only
    lsq: bool = False  # "tilelli" only
    dense_expand: int = 2  # "tilelli"; "tilelli_lite" as ffn_expand (0 falls back to 4)
    fp_attention: bool = False  # "tilelli" only
    top_k_routing: int = 0  # "tilelli" only
# Registry of model variants. Keys are the ``--model`` choices and are taken
# from each config's own ``name`` so the two can never drift apart.
MODEL_CFGS: dict[str, ModelCfg] = {
    cfg.name: cfg
    for cfg in (
        ModelCfg(
            name="tilelli-fp32",
            builder="tilelli",
            quantize=False,
            d_model=512,
            n_layers=7,
            d_head=64,
            top_k=8,
            n_heads=0,
            expand=0,
        ),
        ModelCfg(
            name="tilelli-ternary",
            builder="tilelli",
            quantize=True,
            d_model=512,
            n_layers=7,
            d_head=64,
            top_k=8,
            n_heads=0,
            expand=0,
        ),
        ModelCfg(
            name="vanilla-fp32",
            builder="vanilla",
            quantize=False,
            d_model=320,
            n_layers=8,
            d_head=40,  # 320 / 8 heads
            top_k=0,
            n_heads=8,
            expand=4,
        ),
        # Tilelli Lite — clean 3-pathway sibling (same arch as the deployed
        # v4 chat checkpoint), in FP32 and ternary flavours.
        ModelCfg(
            name="tilelli-lite-fp32",
            builder="tilelli_lite",
            quantize=False,
            d_model=256,
            n_layers=8,
            d_head=32,
            top_k=16,
            n_heads=8,
            expand=0,
            dense_expand=4,
        ),
        ModelCfg(
            name="tilelli-lite-ternary",
            builder="tilelli_lite",
            quantize=True,
            d_model=256,
            n_layers=8,
            d_head=32,
            top_k=16,
            n_heads=8,
            expand=0,
            dense_expand=4,
        ),
    )
}
def build_model(cfg: ModelCfg, max_seq_len: int) -> torch.nn.Module:
    """Instantiate the network described by *cfg*.

    Every builder uses a 256-entry vocabulary.

    Raises:
        ValueError: if ``cfg.builder`` is not a known builder name.
    """
    kind = cfg.builder
    if kind == "tilelli_lite":
        return _make_tilelli_lite(cfg, max_seq_len)
    if kind == "vanilla":
        return VanillaLM(
            vocab_size=256,
            d_model=cfg.d_model,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            expand=cfg.expand,
            max_seq_len=max_seq_len,
        )
    if kind != "tilelli":
        raise ValueError(f"unknown builder {cfg.builder!r}")
    return TilelliLM(
        vocab_size=256,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        d_head=cfg.d_head,
        top_k=cfg.top_k,
        max_seq_len=max_seq_len,
        quantize=cfg.quantize,
        n_banks=cfg.n_banks,
        per_row=cfg.per_row,
        hadamard=cfg.hadamard,
        lsq=cfg.lsq,
        dense_expand=cfg.dense_expand,
        fp_attention=cfg.fp_attention,
        top_k_routing=cfg.top_k_routing,
    )
# ---------------------------------------------------------------------- #
# Data — memmap byte arrays, sample random windows
# ---------------------------------------------------------------------- #
class ByteShard:
    """Read-only memmap of a packed ``uint8`` token shard.

    Each byte is one token, so ``self.n`` (the number of bytes) is also the
    number of tokens in the shard.
    """

    def __init__(self, path: Path) -> None:
        self.path = path
        self.data = np.memmap(path, dtype=np.uint8, mode="r")
        self.n = int(self.data.size)

    def sample_batch(self, batch_size: int, seq_len: int, rng: np.random.Generator) -> Tensor:
        """Return ``batch_size`` uniformly random windows of ``seq_len + 1`` tokens.

        The extra token is the next-token target for the last input position;
        callers split a batch into ``[:, :-1]`` inputs and ``[:, 1:]`` targets.

        Returns:
            An ``int64`` tensor of shape ``(batch_size, seq_len + 1)``.

        Raises:
            ValueError: if the shard holds fewer than ``seq_len + 1`` tokens.
        """
        window = seq_len + 1
        # Valid start offsets are 0 .. n - window inclusive.
        n_starts = self.n - window + 1
        if n_starts <= 0:
            raise ValueError(
                f"{self.path}: shard has {self.n} tokens, need at least {window} "
                f"for seq_len={seq_len}"
            )
        # ``Generator.integers`` excludes ``high``; passing the count of valid
        # starts (not the last start) keeps the final window reachable.
        starts = rng.integers(0, n_starts, size=batch_size)
        # Gather every window with one fancy-index instead of a Python loop.
        idx = starts[:, None] + np.arange(window)
        out = np.asarray(self.data[idx], dtype=np.int64)
        return torch.from_numpy(out)

    def iter_eval_batches(
        self, batch_size: int, seq_len: int, n_batches: int, rng: np.random.Generator
    ) -> Iterator[Tensor]:
        """Yield ``n_batches`` random batches drawn with ``sample_batch``.

        Windows are sampled at random rather than sweeping the shard, so pass
        an identically-seeded *rng* to score the same windows on every call.
        """
        for _ in range(n_batches):
            yield self.sample_batch(batch_size, seq_len, rng)
class InductionStream:
    """In-memory source of synthetic induction-heads sequences.

    Interface-compatible with ``ByteShard`` (``sample_batch`` /
    ``iter_eval_batches``). Each call generates a fresh batch with
    ``tilelli.sherlock.induction_heads.make_dense_induction_batch``, using a
    torch seed drawn from the caller's numpy ``rng``, so batches are
    reproducible from the rng state (assuming the generator only uses the
    torch ``Generator`` it is given). The model is trained with ordinary
    next-token prediction on the whole sequence; the planted key/value
    patterns are the part that rewards in-context recall.

    ``n`` is a notional size kept only for parity with ``ByteShard.n``; it
    does not describe an actual amount of data.
    """

    def __init__(self, vocab_size: int = 256, min_gap: int = 8, n_keys: int = 16) -> None:
        self.vocab_size = vocab_size
        self.min_gap = min_gap
        # Forwarded as ``n_keys`` to make_dense_induction_batch (this value
        # used to be hard-coded to 16, which remains the default).
        self.n_keys = n_keys
        self.n = 1_000_000  # notional; see class docstring

    def sample_batch(self, batch_size: int, seq_len: int, rng: np.random.Generator) -> Tensor:
        """Generate one batch with ``seq_len + 1`` tokens per row.

        Uses the DENSE generator (many patterns per sequence) rather than the
        one-pattern-per-sequence eval variant; per the original author this
        gives learnable signal at ~50% of positions instead of ~0.4%, so the
        LM cross-entropy actually drives induction-head learning.

        Returns whatever ``make_dense_induction_batch`` returns — presumably
        an integer tensor of shape ``(batch_size, seq_len + 1)`` like
        ``ByteShard.sample_batch``; TODO confirm.
        """
        from tilelli.sherlock.induction_heads import make_dense_induction_batch

        # Derive a torch seed from the numpy rng so runs are reproducible.
        seed = int(rng.integers(0, 2**31 - 1))
        tgen = torch.Generator()
        tgen.manual_seed(seed)
        return make_dense_induction_batch(
            batch_size=batch_size,
            seq_len=seq_len + 1,  # +1 for the next-token target slot
            rng=tgen,
            vocab_size=self.vocab_size,
            n_keys=self.n_keys,
            min_gap=self.min_gap,
        )

    def iter_eval_batches(
        self, batch_size: int, seq_len: int, n_batches: int, rng: np.random.Generator
    ) -> Iterator[Tensor]:
        """Yield ``n_batches`` freshly generated batches from ``sample_batch``."""
        for _ in range(n_batches):
            yield self.sample_batch(batch_size, seq_len, rng)
# ---------------------------------------------------------------------- #
# Multi-optimizer wrapper (Muon for 2D weights + AdamW for 1D)
# ---------------------------------------------------------------------- #
class _MultiOptim:
"""Forwards zero_grad / step / state_dict / load_state_dict to a list of
underlying optimizers. Exposes a concatenated param_groups, with each group
annotated with its own peak_lr so the cosine schedule can scale them
proportionally (Muon's effective LR is ~60× AdamW's).
"""
def __init__(self, optims, peak_lrs):
assert len(optims) == len(peak_lrs)
self._optims = list(optims)
for opt, peak in zip(self._optims, peak_lrs):
for g in opt.param_groups:
g["peak_lr"] = peak
@property
def param_groups(self):
groups = []
for opt in self._optims:
groups.extend(opt.param_groups)
return groups
def zero_grad(self, set_to_none=True):
for opt in self._optims:
opt.zero_grad(set_to_none=set_to_none)
def step(self, closure=None):
for opt in self._optims:
opt.step()
def state_dict(self):
return {"optims": [opt.state_dict() for opt in self._optims]}
def load_state_dict(self, sd):
for opt, s in zip(self._optims, sd["optims"]):
opt.load_state_dict(s)
# ---------------------------------------------------------------------- #
# LR schedule
# ---------------------------------------------------------------------- #
def lr_at(step: int, total_steps: int, peak_lr: float, warmup: int, min_ratio: float) -> float:
    """Return the learning rate at 0-based *step*.

    Linear warmup over the first ``warmup`` steps (reaching ``peak_lr`` at
    step ``warmup - 1``), then a half-cosine decay from ``peak_lr`` down to
    ``peak_lr * min_ratio`` at ``total_steps``; later steps stay at that floor.
    """
    if step >= warmup:
        span = max(1, total_steps - warmup)
        # Clamp so steps past ``total_steps`` remain at the floor.
        frac = max(0.0, min(1.0, (step - warmup) / span))
        cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
        return peak_lr * (min_ratio + (1.0 - min_ratio) * cosine)
    # Still warming up: linear ramp; ``max`` guards against warmup == 0.
    return peak_lr * (step + 1) / max(1, warmup)
# ---------------------------------------------------------------------- #
# Train loop
# ---------------------------------------------------------------------- #
def evaluate(
    model: torch.nn.Module,
    val: ByteShard | InductionStream,
    batch_size: int,
    seq_len: int,
    n_batches: int,
    rng: np.random.Generator,
    device: torch.device,
    autocast_dtype=None,
) -> float:
    """Mean next-token loss of *model* over ``n_batches`` batches from *val*.

    Each batch of ``seq_len + 1`` tokens is split into inputs ``[:, :-1]``
    and targets ``[:, 1:]`` and scored with ``model.loss`` under
    ``torch.no_grad``. The model is put in eval mode for the duration, and
    its previous train/eval mode is restored afterwards, even on error.

    Args:
        val: any source exposing ``iter_eval_batches`` (ByteShard-like).
        rng: drives batch sampling; pass an identically-seeded generator to
            score the same batches on every call.
        autocast_dtype: if not None, run the forward pass under
            ``torch.amp.autocast`` with this dtype.

    Returns:
        The unweighted mean of the per-batch losses, or ``nan`` when no
        batches were evaluated.
    """
    was_training = model.training
    model.eval()
    losses: list[float] = []
    try:
        with torch.no_grad():
            for chunk in val.iter_eval_batches(batch_size, seq_len, n_batches, rng):
                chunk = chunk.to(device, non_blocking=True)
                amp = (
                    torch.amp.autocast(device.type, dtype=autocast_dtype)
                    if autocast_dtype is not None
                    else contextlib.nullcontext()
                )
                with amp:
                    loss = model.loss(chunk[:, :-1], chunk[:, 1:])
                losses.append(float(loss.item()))
    finally:
        model.train(was_training)
    if not losses:
        # Avoid numpy's "mean of empty slice" RuntimeWarning.
        return float("nan")
    return float(np.mean(losses))
def _parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, choices=list(MODEL_CFGS.keys()))
    ap.add_argument("--data-dir", type=Path, default=Path("data/tinystories"))
    ap.add_argument("--steps", type=int, default=50_000)
    ap.add_argument("--seq-len", type=int, default=256)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--peak-lr", type=float, default=3e-4)
    ap.add_argument("--min-lr-ratio", type=float, default=0.01)
    ap.add_argument("--warmup", type=int, default=500)
    ap.add_argument("--weight-decay", type=float, default=0.01)
    ap.add_argument("--grad-clip", type=float, default=1.0)
    ap.add_argument("--eval-every", type=int, default=1000)
    ap.add_argument("--eval-batches", type=int, default=20)
    ap.add_argument("--ckpt-every", type=int, default=2000)
    ap.add_argument("--log-every", type=int, default=50)
    ap.add_argument("--seed", type=int, default=1234)
    ap.add_argument("--threads", type=int, default=8)
    ap.add_argument("--device", default="auto",
                    help="auto | cuda | cpu | cuda:0 etc.")
    ap.add_argument("--autocast", default="off",
                    choices=["off", "bf16", "fp16"],
                    help="Mixed-precision autocast for forward+backward (CUDA only); "
                         "fp16 also enables gradient scaling")
    ap.add_argument("--run-dir", type=Path, default=None,
                    help="Directory for this run. Defaults to runs/<model>_<timestamp>.")
    ap.add_argument("--resume", action="store_true",
                    help="Resume from <run-dir>/last.pt if present (pass --run-dir).")
    ap.add_argument("--optimizer", default="adamw", choices=["adamw", "muon"],
                    help="adamw (default) | muon (Muon for 2D+, AdamW for 1D)")
    ap.add_argument("--muon-lr-mult", type=float, default=60.0,
                    help="Muon LR multiplier vs AdamW peak_lr; per Keller Jordan ~60×")
    ap.add_argument("--data-source", default="bin",
                    choices=["bin", "induction"],
                    help="bin: memmap train.bin/valid.bin (default). "
                         "induction: generate synthetic induction-heads sequences "
                         "on the fly (no data-dir needed).")
    return ap.parse_args(argv)


def _configure_device(args: argparse.Namespace) -> torch.device:
    """Resolve ``--device`` (writing the result back to *args*) and tune backends."""
    if args.device == "auto":
        args.device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(args.device)
    if device.type == "cpu":
        torch.set_num_threads(args.threads)
    if device.type == "cuda":
        torch.set_float32_matmul_precision("high")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    return device


def _build_data(args: argparse.Namespace):
    """Return the ``(train, val)`` data sources selected by ``--data-source``."""
    if args.data_source == "induction":
        # Synthetic induction-heads task, generated in-process. Train and val
        # are sampled from differently-seeded RNGs in main(), so eval sees
        # held-out random patterns.
        train = InductionStream(vocab_size=256, min_gap=8)
        val = InductionStream(vocab_size=256, min_gap=8)
        print("data: induction-heads (synthetic, vocab=256, min_gap=8)")
        return train, val
    train = ByteShard(args.data_dir / "train.bin")
    val = ByteShard(args.data_dir / "valid.bin")
    print(f"train: {train.n:,} tokens val: {val.n:,} tokens")
    return train, val


def _build_optimizer(model: torch.nn.Module, args: argparse.Namespace):
    """Plain AdamW, or Muon (2D+ params) + AdamW (rest) behind ``_MultiOptim``."""
    if args.optimizer != "muon":
        return torch.optim.AdamW(
            model.parameters(),
            lr=args.peak_lr,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.95),
        )
    from tilelli.optimisers import Muon, split_params_for_muon

    muon_params, adamw_params = split_params_for_muon(model)
    muon_peak_lr = args.peak_lr * args.muon_lr_mult
    optim_muon = Muon(
        muon_params, lr=muon_peak_lr, momentum=0.95,
        weight_decay=args.weight_decay, nesterov=True, ns_steps=5,
    )
    optim_adamw = torch.optim.AdamW(
        adamw_params, lr=args.peak_lr,
        weight_decay=args.weight_decay, betas=(0.9, 0.95),
    )
    print(f"optimizer: muon ({len(muon_params)} 2D params, lr {muon_peak_lr:.1e}) + adamw ({len(adamw_params)} 1D params, lr {args.peak_lr:.1e})")
    return _MultiOptim([optim_muon, optim_adamw], peak_lrs=[muon_peak_lr, args.peak_lr])


def _make_grad_scaler(enabled: bool):
    """Return a CUDA ``GradScaler``; when disabled it is a transparent pass-through.

    ``torch.amp.GradScaler`` is preferred where it exists (newer PyTorch);
    older releases only provide ``torch.cuda.amp.GradScaler``.
    """
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def _autocast(device: torch.device, dtype):
    """Autocast context for *dtype* on *device*, or a no-op when *dtype* is None."""
    if dtype is None:
        return contextlib.nullcontext()
    return torch.amp.autocast(device.type, dtype=dtype)


def main() -> None:
    """Train one model variant end to end; see the module docstring."""
    args = _parse_args()
    device = _configure_device(args)
    autocast_dtype = {"off": None, "bf16": torch.bfloat16, "fp16": torch.float16}[args.autocast]
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Run dir
    if args.run_dir is None:
        if args.resume:
            print("warning: --resume without --run-dir starts a fresh timestamped run")
        ts = time.strftime("%Y-%m-%d_%H-%M-%S")
        args.run_dir = Path("runs") / f"{args.model}_{ts}"
    args.run_dir.mkdir(parents=True, exist_ok=True)
    log_path = args.run_dir / "log.jsonl"
    cfg_path = args.run_dir / "config.json"
    last_ckpt = args.run_dir / "last.pt"
    best_ckpt = args.run_dir / "best.pt"

    train, val = _build_data(args)

    # Model + optimizer
    cfg = MODEL_CFGS[args.model]
    model = build_model(cfg, max_seq_len=args.seq_len).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"model {cfg.name}: {n_params:,} params ({n_params/1e6:.2f}M) on {device}")
    optim = _build_optimizer(model, args)
    # fp16 needs loss scaling to avoid gradient underflow; bf16/fp32 do not.
    scaler = _make_grad_scaler(autocast_dtype is torch.float16 and device.type == "cuda")

    rng_train = np.random.default_rng(args.seed + 1)
    # Every eval re-seeds from this value, so all evals (including the final
    # one) score the same validation batches and their losses are comparable.
    eval_seed = args.seed + 2

    # Resume
    start_step = 0
    best_val = float("inf")
    if args.resume and last_ckpt.exists():
        sd = torch.load(last_ckpt, map_location="cpu")
        model.load_state_dict(sd["model"])
        optim.load_state_dict(sd["optim"])
        start_step = int(sd.get("step", 0))
        best_val = float(sd.get("best_val", float("inf")))
        if "rng_train" in sd:
            # Continue the training-batch stream instead of replaying it.
            rng_train.bit_generator.state = json.loads(sd["rng_train"])
        if sd.get("scaler"):
            scaler.load_state_dict(sd["scaler"])
        print(f"resumed from {last_ckpt} at step {start_step}, best_val {best_val:.4f}")
    elif args.resume:
        print(f"no checkpoint at {last_ckpt}; starting from scratch")

    # Persist config
    cfg_path.write_text(json.dumps({
        "model_cfg": asdict(cfg),
        "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
        "n_params": n_params,
    }, indent=2))

    def run_eval() -> float:
        return evaluate(model, val, args.batch_size, args.seq_len, args.eval_batches,
                        np.random.default_rng(eval_seed), device, autocast_dtype)

    def save_best(step_done: int, best: float) -> None:
        torch.save({
            "model": model.state_dict(),
            "step": step_done,
            "best_val": best,
            "model_cfg": asdict(cfg),
        }, best_ckpt)

    def save_last(step_done: int, best: float) -> None:
        torch.save({
            "model": model.state_dict(),
            "optim": optim.state_dict(),
            "step": step_done,
            "best_val": best,
            "model_cfg": asdict(cfg),
            # JSON string: the PCG64 state holds 128-bit ints.
            "rng_train": json.dumps(rng_train.bit_generator.state),
            "scaler": scaler.state_dict(),
        }, last_ckpt)

    model.train()
    t0 = time.time()
    last_log_t = t0
    running_loss = 0.0
    running_n = 0
    with log_path.open("a", buffering=1, encoding="utf-8") as log:
        for step in range(start_step, args.steps):
            # Each group follows the schedule from its own peak (Muon groups
            # carry a larger ``peak_lr``); ``lr`` is the AdamW-scale value
            # used for logging.
            lr = lr_at(step, args.steps, args.peak_lr, args.warmup, args.min_lr_ratio)
            for g in optim.param_groups:
                peak = g.get("peak_lr", args.peak_lr)
                g["lr"] = lr_at(step, args.steps, peak, args.warmup, args.min_lr_ratio)

            chunk = train.sample_batch(args.batch_size, args.seq_len, rng_train).to(device, non_blocking=True)
            optim.zero_grad()
            with _autocast(device, autocast_dtype):
                loss = model.loss(chunk[:, :-1], chunk[:, 1:])
            scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true grads.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optim)
            scaler.update()

            running_loss += float(loss.item())
            running_n += 1
            if (step + 1) % args.log_every == 0:
                now = time.time()
                # Divide by the steps actually run; after a resume the first
                # interval can be shorter than ``log_every``.
                ms = (now - last_log_t) / max(1, running_n) * 1000
                avg = running_loss / max(1, running_n)
                print(f"step {step+1:>6d}/{args.steps} loss {avg:.4f} lr {lr:.2e} {ms:.0f} ms/step")
                log.write(json.dumps({"event": "train", "step": step+1, "loss": avg, "lr": lr, "ms_per_step": ms}) + "\n")
                running_loss = 0.0
                running_n = 0
                last_log_t = now

            if (step + 1) % args.eval_every == 0:
                v = run_eval()
                print(f" val loss {v:.4f} best {min(best_val, v):.4f}")
                log.write(json.dumps({"event": "val", "step": step+1, "val_loss": v, "best_val": min(best_val, v)}) + "\n")
                if v < best_val:
                    best_val = v
                    save_best(step + 1, best_val)

            if (step + 1) % args.ckpt_every == 0:
                save_last(step + 1, best_val)

        # Final eval + ckpt
        v_final = run_eval()
        if v_final < best_val:
            # Keep best.pt in sync with the best_val reported below.
            best_val = v_final
            save_best(args.steps, best_val)
        log.write(json.dumps({"event": "final", "step": args.steps, "val_loss": v_final, "best_val": best_val, "wall_seconds": time.time()-t0}) + "\n")
        save_last(args.steps, best_val)
    print(f"done. final val {v_final:.4f} best val {best_val:.4f} wall {(time.time()-t0)/3600:.2f}h")


if __name__ == "__main__":
    main()