| |
| import torch |
| from torch.utils.data import DataLoader |
| from transformers import ( |
| AutoTokenizer, |
| TrainingArguments, |
| Trainer, |
| default_data_collator, |
| ) |
| from datasets import load_dataset |
| from myolmoe import MyOlmoeForCausalLM, OlmoeConfig |
| import os |
| from transformers import TrainerCallback |
| import subprocess |
|
|
def main():
    """Fine-tune only the small-expert / small-gate parameters of a MyOlmoe model.

    Loads a MyOlmoe checkpoint from ./myolmoe, freezes every base weight,
    unfreezes only parameters whose names contain ``small_experts`` or
    ``small_gate``, tokenizes the Tulu v2 SFT mixture, sanity-checks gradient
    flow with one manual forward/backward, then trains with the HF Trainer.
    Checkpoints are git-pushed and the small-expert tensors are saved
    separately alongside each checkpoint and at the end.
    """
    import re  # used below for checkpoint-directory discovery before resuming

    print("Starting my COOL OLMoE training script for small experts")

    # Prefer the local config.json so hand edits take precedence; otherwise
    # fall back to whatever from_pretrained resolves for "myolmoe".
    config_path = os.path.join("myolmoe", "config.json")
    if os.path.exists(config_path):
        config = OlmoeConfig.from_json_file(config_path)
    else:
        config = OlmoeConfig.from_pretrained("myolmoe")

    # ignore_mismatched_sizes: the small experts/gates were added after the
    # base checkpoint was written, so some tensor shapes legitimately differ.
    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        ignore_mismatched_sizes=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    tokenizer.pad_token = tokenizer.eos_token  # model ships without a pad token

    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        """Render each chat transcript to plain text, tokenize, build labels."""
        texts = []
        for message_list in examples["messages"]:
            formatted = ""
            for msg in message_list:
                role = msg["role"]
                content = msg["content"]
                if role == "user":
                    formatted += f"User: {content}\n"
                elif role == "assistant":
                    formatted += f"Assistant: {content}\n"
                else:
                    formatted += f"{role.capitalize()}: {content}\n"
            texts.append(formatted)

        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # BUG FIX: labels used to be a verbatim copy of input_ids. Because
        # pad_token == eos_token and every sequence is padded to 4096, that
        # trained the model on long EOS/padding runs. Mask padded positions
        # with -100 so the cross-entropy loss ignores them.
        tokenized["labels"] = [
            [tok if keep == 1 else -100 for tok, keep in zip(ids, mask)]
            for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    training_args = TrainingArguments(
        output_dir="./checkpoints",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        num_train_epochs=3,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=20,
        save_total_limit=1,
        bf16=True,
        gradient_checkpointing=False,
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Freeze everything, then selectively unfreeze the small experts/gates.
    for param in model.parameters():
        param.requires_grad = False

    trainable_params = []
    for name, param in model.named_parameters():
        if "small_experts" in name or "small_gate" in name:
            param.requires_grad = True
            trainable_params.append(name)

    if trainable_params:
        print(f"[INFO] Found {len(trainable_params)} small_expert/small_gate parameters.")
    else:
        print("[WARNING] No small_expert or small_gate parameters found in model!")

    # Cross-check: list exactly what is still trainable after the freeze pass.
    unfrozen = [name for name, param in model.named_parameters() if param.requires_grad]
    if unfrozen:
        print(f"[INFO] {len(unfrozen)} parameters are unfrozen and trainable.")
        for name in unfrozen:
            print(f" - {name}")
    else:
        print("[ERROR] No parameters were unfrozen! Training will not update anything.")

    print(f"Total trainable parameters: {len(trainable_params)}")

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Parameter {name} requires grad: {param.requires_grad}")

    def data_collator(features):
        """Default collation plus a flag asking the model for router logits."""
        batch = default_data_collator(features)
        # Forwarded as a model kwarg so the MoE auxiliary (router) loss is
        # included in outputs.loss during training.
        batch["output_router_logits"] = True
        return batch

    class CustomTrainer(Trainer):
        """Trainer that strips Trainer-internal keys and sanity-checks the loss."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # 'num_items_in_batch' is injected by recent Trainer versions and
            # is not a model forward() argument.
            inputs = {k: v for k, v in inputs.items() if k != "num_items_in_batch"}

            model.train()

            with torch.set_grad_enabled(True):
                outputs = model(**inputs)
                loss = outputs.loss

            # Fail loudly if freezing accidentally disconnected the loss from
            # every trainable parameter.
            if not loss.requires_grad:
                raise RuntimeError("Loss doesn't require gradients. Check model parameters.")

            return (loss, outputs) if return_outputs else loss

    class GitPushCallback(TrainerCallback):
        """Commit and push the working tree after every checkpoint save."""

        def on_save(self, args, state, control, **kwargs):
            try:
                print("Saving checkpoint to Git repo...")

                subprocess.run(["git", "add", "."], check=True)

                # `git diff --cached --quiet` exits 0 when nothing is staged,
                # so skip the commit in that case (check=False on purpose).
                result = subprocess.run(["git", "diff", "--cached", "--quiet"])
                if result.returncode == 0:
                    print("No changes to commit.")
                    return

                subprocess.run(
                    ["git", "commit", "-m", f"Checkpoint at step {state.global_step}"],
                    check=True,
                )
                subprocess.run(["git", "push"], check=True)
                print("Checkpoint pushed successfully.")
            except subprocess.CalledProcessError as e:
                print(f"Git push failed: {e}")

    class SmallExpertSaveCallback(TrainerCallback):
        """Save only the small-expert/small-gate tensors next to each checkpoint."""

        def __init__(self, model, trainable_params):
            self.model = model
            # Set for O(1) membership tests when filtering named_parameters.
            self.trainable_params = set(trainable_params)

        def on_save(self, args, state, control, **kwargs):
            checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            small_expert_path = os.path.join(checkpoint_dir, "small_experts_and_gates.bin")

            # BUG FIX: save detached CPU copies rather than live nn.Parameter
            # objects, so the file is device-independent and carries no
            # autograd state.
            small_expert_state_dict = {
                name: param.detach().cpu().clone()
                for name, param in self.model.named_parameters()
                if name in self.trainable_params
            }

            if small_expert_state_dict:
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(small_expert_state_dict, small_expert_path)
                print(f"[INFO] Saved {len(small_expert_state_dict)} small_expert/small_gate parameters "
                      f"to {small_expert_path}")
            else:
                print("[ERROR] No small_expert or small_gate parameters found to save!")

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[
            GitPushCallback(),
            SmallExpertSaveCallback(model, trainable_params),
        ],
    )

    # One manual forward/backward to confirm gradients actually reach the
    # unfrozen parameters before committing to a full training run.
    print("Testing gradient flow...")
    test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
    test_batch = next(iter(test_loader))

    # device_map="auto" may shard the model; move inputs to the first shard.
    device = next(model.parameters()).device
    test_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}

    model.train()
    outputs = model(**test_batch)
    loss = outputs.loss
    print(f"Initial loss: {loss.item()}")

    loss.backward()
    print("Gradients computed successfully")

    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"Parameter {name} received gradients")

    # Clear the probe gradients so the optimizer starts from a clean slate.
    model.zero_grad()

    # Resume from the highest-numbered checkpoint-* directory, if one exists.
    checkpoint_dir = None
    if os.path.isdir(training_args.output_dir):
        checkpoints = [
            os.path.join(training_args.output_dir, d)
            for d in os.listdir(training_args.output_dir)
            if re.match(r"checkpoint-\d+", d)
        ]
        if checkpoints:
            # Parse the step from the directory basename so a '-' in
            # output_dir itself cannot break the int() conversion.
            checkpoint_dir = max(
                checkpoints,
                key=lambda path: int(os.path.basename(path).split("-")[-1]),
            )
            print(f"Resuming from checkpoint: {checkpoint_dir}")

    print("Starting training...")
    trainer.train(resume_from_checkpoint=checkpoint_dir)

    # Final export: only the small-expert/small-gate tensors (detached, on
    # CPU) plus the config needed to reconstruct the model.
    print("Saving small experts and gates...")
    wanted = set(trainable_params)
    small_expert_state_dict = {
        name: param.detach().cpu().clone()
        for name, param in model.named_parameters()
        if name in wanted
    }

    os.makedirs("./final_model", exist_ok=True)
    torch.save(small_expert_state_dict, "./final_model/small_experts_and_gates.bin")
    config.save_pretrained("./final_model")
|
|
| if __name__ == "__main__": |
| main() |