print("Starting...")

# ---------------------------------------------------------------------------
# Data / tokenizer
# ---------------------------------------------------------------------------
TXT_PATH = "data.txt"             # UTF-8 text corpus used for vocab + training
TOKENIZER_NAME = "gpt2"           # tiktoken encoding name
REDUCE_VOCAB = True               # NOTE(review): not read anywhere in this script
VOCAB_SAVE_PATH = "vocab_map.pt"  # where build_dataset_vocab saves the id remap

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
EPOCHS = 25
MICRO_BATCH_SIZE = 1              # sequences per forward/backward pass
GRAD_ACCUM_STEPS = 8              # micro-batches accumulated per optimizer step
LEARNING_RATE = 3e-4              # AdamW learning rate

# ---------------------------------------------------------------------------
# Model shape
# ---------------------------------------------------------------------------
D_MODEL = 256
N_LAYERS = 4
MAX_SEQ_LEN = 8192                # training window length and size of the positional table

# ---------------------------------------------------------------------------
# Convolutional mixers
# ---------------------------------------------------------------------------
LOCAL_KERNEL_SIZE = 5             # taps of the causal depthwise conv in every block
GLOBAL_KERNEL_SIZE = 256          # maximum taps of the FFT-based long convolution
USE_GLOBAL_EVERY_N_LAYERS = 2     # layers whose index % N == 0 get a global conv

# Overlap-save segment length for the global conv; should be >= GLOBAL_KERNEL_SIZE
# so every FFT segment still yields at least one valid output sample.
FFT_SIZE = 1024

# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------
SAVE_PATH = "model.pt"
SAVE_N_EPOCHS = 1                 # save every N epochs (falsy disables periodic saves)

# ---------------------------------------------------------------------------
# Runtime
# ---------------------------------------------------------------------------
USE_DEVICE = "cuda"               # falls back to CPU when CUDA is unavailable
USE_AMP = True                    # mixed precision, only effective on CUDA
USE_ACTIVATION_CHECKPOINTING = False  # NOTE(review): not read anywhere in this script

# ---------------------------------------------------------------------------
# torch.compile (only attempted on CUDA)
# ---------------------------------------------------------------------------
COMPILE = False
COMPILE_MODE = "reduce-overhead"
COMPILE_BACKEND = "eager"
| |
|
| | |
| | |
| | |
| |
|
| | import os |
| |
|
| | |
# Ask PyTorch's CUDA caching allocator for expandable segments (per the PyTorch
# CUDA memory docs this can reduce fragmentation). The variable is set here,
# before `import torch` below, and skipped on Windows (os.name == "nt").
# setdefault() keeps any value the user already exported.
if os.name != "nt":
    os.environ.setdefault(
        "PYTORCH_CUDA_ALLOC_CONF",
        "expandable_segments:True",
    )
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from tqdm import tqdm |
| | import tiktoken |
| |
|
| | |
# Permit TensorFloat-32 for float32 matmuls and cuDNN convolutions: faster on
# GPUs that support it, at slightly reduced float32 precision.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
| |
|
| | |
| | |
| | |
| |
|
# Reserved ids at the bottom of the remapped vocabulary. build_dataset_vocab
# maps every real tokenizer id to a new id starting at OFFSET, i.e. one past
# the last special id. PAD_ID is also the ignore_index of the training loss.
PAD_ID, SEP_ID, EOS_ID = 0, 1, 2
OFFSET = EOS_ID + 1
| |
|
| | |
| | |
| | |
| |
|
def build_dataset_vocab(txt_path, tokenizer, save_path):
    """Build a compact vocabulary from the token ids that occur in a text file.

    The file is tokenized once; every distinct token id is assigned a new,
    dense id starting at ``OFFSET`` (ids below ``OFFSET`` are reserved for
    ``PAD_ID``/``SEP_ID``/``EOS_ID``). The mapping and the special ids are
    written to ``save_path`` with ``torch.save``.

    Args:
        txt_path: path of the UTF-8 text file to scan.
        tokenizer: object with an ``encode(str) -> list[int]`` method.
        save_path: destination file for the saved mapping.

    Returns:
        ``(used, id2new)``: the sorted list of distinct original token ids,
        and a dict mapping each of them to its new id.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(txt_path, "r", encoding="utf-8") as f:
        text = f.read()

    used = sorted(set(tokenizer.encode(text)))
    # New ids are dense and start right after the reserved special ids.
    id2new = {tok: new_id for new_id, tok in enumerate(used, start=OFFSET)}

    torch.save({
        "used_tokens": used,
        "id2new": id2new,
        "PAD_ID": PAD_ID,
        "SEP_ID": SEP_ID,
        "EOS_ID": EOS_ID,
    }, save_path)

    print(f"[OK] Vocab size: {len(used) + OFFSET}")
    return used, id2new
| |
|
| |
|
| | |
| | |
| | |
| |
|
class RemappedTextDataset(Dataset):
    """Sliding-window next-token dataset over a single text file.

    The file is tokenized once and every token id is remapped through
    ``id2new``. Ids missing from the map become ``PAD_ID``, which ``train()``
    passes as ``ignore_index`` to its loss. ``EOS_ID`` is appended at the end.
    Item ``i`` is ``(ids[i : i + max_len], ids[i + 1 : i + max_len + 1])``,
    i.e. an input window and the same window shifted by one token.
    """

    def __init__(self, path, tokenizer, id2new, max_len):
        # Context manager so the file handle is closed deterministically.
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        raw = tokenizer.encode(text)
        self.ids = [id2new.get(i, PAD_ID) for i in raw] + [EOS_ID]
        self.max_len = max_len

    def __len__(self):
        # A window starting at i needs ids[i + max_len] as its final target, so
        # valid starts are 0 .. len(ids) - max_len - 1: len(ids) - max_len of
        # them (the previous "- 1" dropped the window ending at EOS). Clamp at 0
        # so a corpus shorter than one window gives an empty dataset rather than
        # a negative length.
        return max(0, len(self.ids) - self.max_len)

    def __getitem__(self, i):
        end = i + self.max_len
        x = self.ids[i:end]
        y = self.ids[i + 1:end + 1]
        return torch.tensor(x), torch.tensor(y)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GlobalConv1D(nn.Module):
    """Causal depthwise long convolution computed with overlap-save FFTs.

    Each of the ``d_model`` channels owns a learnable kernel of up to
    ``kernel_size`` taps. For an input of length T the kernel is truncated
    to K = min(kernel_size, T) taps and
    ``out[b, c, t] = sum_{j < K} kernel[c, j] * x[b, c, t - j]`` with zeros
    before t = 0, so no output depends on future positions.

    The sequence is cut into overlapping segments of ``fft_size`` samples
    (overlap-save). All segments are transformed in one batched FFT instead
    of a Python loop.

    Args:
        d_model: number of channels C.
        kernel_size: maximum kernel length (>= 1).
        fft_size: FFT segment length; must be >= ``kernel_size`` so that each
            segment yields at least one valid output sample.

    Raises:
        ValueError: if ``kernel_size < 1`` or ``fft_size < kernel_size``.
    """

    def __init__(self, d_model, kernel_size, fft_size):
        super().__init__()
        # With fft_size < kernel_size a segment has no alias-free outputs and
        # the segment stride (fft_size - K + 1) becomes <= 0.
        if not 1 <= kernel_size <= fft_size:
            raise ValueError(
                f"need 1 <= kernel_size <= fft_size, got kernel_size={kernel_size}, "
                f"fft_size={fft_size}"
            )
        self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
        self.kernel_size = kernel_size
        self.fft_size = fft_size

    def forward(self, x):
        """Convolve ``x`` of shape (B, C, T) causally; returns shape (B, C, T)."""
        B, C, T = x.shape
        if T == 0:
            # Nothing to convolve (and K would be 0 below).
            return x.clone()

        K = min(self.kernel_size, T)
        overlap = K - 1                    # history samples each segment needs
        block = self.fft_size - overlap    # alias-free outputs per segment (>= 1)
        n_blocks = -(-T // block)          # ceil(T / block)

        # Left-pad K-1 zeros (causal history for t = 0) and right-pad so the last
        # segment is full: total length is (n_blocks - 1) * block + fft_size.
        total = (n_blocks - 1) * block + self.fft_size
        x = F.pad(x, (overlap, total - overlap - T))
        # Overlapping segments with stride `block`: (B, C, n_blocks, fft_size).
        segs = x.unfold(-1, self.fft_size, block)

        # rfft zero-pads the truncated kernel to fft_size: shape (C, F).
        k_f = torch.fft.rfft(self.kernel[:, :K], n=self.fft_size)
        y = torch.fft.irfft(
            torch.fft.rfft(segs, n=self.fft_size) * k_f.unsqueeze(1),
            n=self.fft_size,
        )
        # Circular wrap-around only corrupts the first K-1 samples of a segment.
        y = y[..., overlap:overlap + block]
        return y.reshape(B, C, n_blocks * block)[..., :T]
| |
|
| |
|
class LocalConv1D(nn.Module):
    """Causal depthwise-separable convolution over the time axis.

    A per-channel (depthwise) convolution of width ``k`` is followed by a ReLU
    and a 1x1 (pointwise) convolution that mixes channels. The input is padded
    with ``k - 1`` zeros on the left only, so output position t sees inputs
    t-k+1 .. t and the time length is preserved.

    Args:
        d_model: number of channels (input and output).
        k: depthwise kernel width.
    """

    def __init__(self, d_model, k):
        super().__init__()
        self.k = k
        self.dw = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=k,
            groups=d_model,
        )
        self.pw = nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=1)

    def forward(self, x):
        left_context = self.k - 1
        h = self.dw(F.pad(x, (left_context, 0)))
        h = F.relu(h)
        return self.pw(h)
| |
|
| |
|
class Block(nn.Module):
    """One residual layer: causal local conv, optional global conv, then an MLP.

    Each sub-layer is pre-normalised with its own LayerNorm and added back to
    the residual stream. ``x`` is channels-last, (batch, time, d_model), as
    produced by GCLM's embeddings. The convolutions want channels-first, hence
    the transposes.

    Args:
        d_model: width of the residual stream.
        use_global: whether this layer also gets a GlobalConv1D mixer.
    """

    def __init__(self, d_model, use_global):
        super().__init__()
        self.use_global = use_global

        # Short-range token mixing.
        self.ln1 = nn.LayerNorm(d_model)
        self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)

        # Long-range token mixing, only on selected layers.
        if use_global:
            self.ln2 = nn.LayerNorm(d_model)
            self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)

        # Position-wise feed-forward network with a 4x hidden expansion.
        hidden = 4 * d_model
        self.ln3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    @staticmethod
    def _over_time(conv, h):
        """Apply a channels-first module to a channels-last tensor."""
        return conv(h.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        h = x + self._over_time(self.local, self.ln1(x))
        if self.use_global:
            h = h + self._over_time(self.global_conv, self.ln2(h))
        ffn_out = self.ff(self.ln3(h))
        return h + ffn_out
| |
|
| |
|
class GCLM(nn.Module):
    """Convolutional language model.

    Structure:
    - token embeddings plus learned positional embeddings;
    - ``N_LAYERS`` Blocks, where layer ``idx`` gets a global conv when
      ``idx % USE_GLOBAL_EVERY_N_LAYERS == 0``;
    - a final LayerNorm and a linear head producing ``vocab`` logits.

    ``forward`` takes integer ids of shape (batch, seq) and returns logits of
    shape (batch, seq, vocab). ``seq`` must not exceed ``MAX_SEQ_LEN``, the
    size of the positional table.

    Args:
        vocab: number of token ids (output classes).
    """

    def __init__(self, vocab):
        super().__init__()
        self.emb = nn.Embedding(vocab, D_MODEL)
        self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)

        self.layers = nn.ModuleList(
            [
                Block(D_MODEL, use_global=(idx % USE_GLOBAL_EVERY_N_LAYERS == 0))
                for idx in range(N_LAYERS)
            ]
        )

        self.ln = nn.LayerNorm(D_MODEL)
        self.head = nn.Linear(D_MODEL, vocab)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.emb(x) + self.pos(positions)
        for block in self.layers:
            h = block(h)
        h = self.ln(h)
        return self.head(h)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def train():
    """Train GCLM on ``TXT_PATH`` using the module-level configuration.

    Steps:
    1. Build and save a compact vocabulary.
    2. Resume weights from ``SAVE_PATH`` if that file exists.
    3. Train for ``EPOCHS`` epochs with ``GRAD_ACCUM_STEPS``-way gradient
       accumulation and optional AMP / ``torch.compile`` on CUDA.
    4. Checkpoint every ``SAVE_N_EPOCHS`` epochs and once more at the end.
    """
    device = USE_DEVICE if torch.cuda.is_available() else "cpu"
    print("[INFO] Device:", device)
    # Compare the device *type* so strings such as "cuda:0" still enable AMP/compile.
    on_cuda = torch.device(device).type == "cuda"
    use_amp = on_cuda and USE_AMP

    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    used, id2new = build_dataset_vocab(TXT_PATH, tok, VOCAB_SAVE_PATH)
    vocab = len(used) + OFFSET

    ds = RemappedTextDataset(TXT_PATH, tok, id2new, MAX_SEQ_LEN)
    dl = DataLoader(ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)

    model = GCLM(vocab).to(device)

    if os.path.exists(SAVE_PATH):
        # Checkpoints are plain state_dicts of tensors, so the restricted
        # weights-only unpickler is sufficient (and safer).
        state = torch.load(SAVE_PATH, map_location=device, weights_only=True)
        model.load_state_dict(state)
        print(f"[RESUME] Loaded existing checkpoint from {SAVE_PATH}")

    # Keep a handle on the uncompiled module for saving: the torch.compile
    # wrapper prefixes every state_dict key with "_orig_mod.", which the
    # resume code above (run on an uncompiled model) could not load.
    raw_model = model
    if on_cuda and COMPILE:
        print("[INFO] Compiling model with torch.compile...")
        model = torch.compile(
            model,
            mode=COMPILE_MODE,
            fullgraph=False,
            backend=COMPILE_BACKEND
        )

    opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    n_batches = len(dl)
    for ep in range(EPOCHS):
        print(f"\nEpoch {ep+1}/{EPOCHS}")
        opt.zero_grad(set_to_none=True)

        for i, (x, y) in enumerate(tqdm(dl)):
            x, y = x.to(device), y.to(device)

            # The final accumulation window of an epoch may hold fewer than
            # GRAD_ACCUM_STEPS micro-batches; normalise by its real size so each
            # optimizer step averages over the micro-batches it actually saw.
            window_start = i - i % GRAD_ACCUM_STEPS
            window_len = min(GRAD_ACCUM_STEPS, n_batches - window_start)

            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = model(x)
                loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
                loss = loss / window_len

            scaler.scale(loss).backward()

            # Also step on the last batch, otherwise a trailing partial window's
            # gradients would be discarded by the next epoch's zero_grad.
            if (i + 1) % GRAD_ACCUM_STEPS == 0 or i + 1 == n_batches:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

        if SAVE_N_EPOCHS and (ep + 1) % SAVE_N_EPOCHS == 0:
            torch.save(raw_model.state_dict(), SAVE_PATH)
            print("[OK] Saved checkpoint.")

    torch.save(raw_model.state_dict(), SAVE_PATH)
    print("[DONE] Training complete.")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Script entry point: training runs only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
| |
|