"""Eval script for DFlash LoRA: compute accepted length on a given dataset.

Usage:
    python scripts/eval_dflash_lora.py \
        --model-path /workspace/Qwen3-8B \
        --ckpt-dir outputs/qwen3-8b-dflash-lora/epoch_2_step_218500 \
        --data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
        --lora-config configs/qwen3-8b-dflash-lora.json \
        --block-size 16 \
        --max-length 2048 \
        --batch-size 1 \
        --attention-backend flex_attention \
        --chat-template qwen
"""

import argparse
import json
import logging
import os
import warnings
from typing import Tuple

import torch
import torch.distributed as dist
from transformers import AutoTokenizer

from datasets import load_dataset
from specforge.core.dflash_lora import OnlineDFlashLoRAModel
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
from specforge.modeling.draft.dflash_lora import DFlashLoRADraftModel
from specforge.utils import print_on_rank0, print_with_rank

def parse_args():
    parser = argparse.ArgumentParser(description="Eval DFlash LoRA: compute accepted length")

    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model-path", type=str, required=True,
                             help="Path to base model (e.g. /workspace/Qwen3-8B)")
    model_group.add_argument("--ckpt-dir", type=str, required=True,
                             help="Path to LoRA checkpoint directory (adapter_model.safetensors)")
    model_group.add_argument("--block-size", type=int, default=16)
    model_group.add_argument("--mask-token-id", type=int, default=None)
    model_group.add_argument("--context-len", type=int, default=0)
    model_group.add_argument("--trust-remote-code", action="store_true")
    model_group.add_argument("--attn-implementation", type=str, default="sdpa",
                             choices=["sdpa", "eager"])
    model_group.add_argument("--attention-backend", type=str, default="flex_attention",
                             choices=["flex_attention", "additive"])
    model_group.add_argument("--lm-head-chunk-size", type=int, default=256)

    lora_group = parser.add_argument_group("lora")
    lora_group.add_argument("--lora-rank", type=int, default=16)
    lora_group.add_argument("--lora-alpha", type=int, default=32)
    lora_group.add_argument("--lora-dropout", type=float, default=0.0)
    lora_group.add_argument("--lora-target-modules", type=str, nargs="+",
                            default=["q_proj", "k_proj", "v_proj", "o_proj"])
    lora_group.add_argument("--lora-config", type=str, default=None,
                            help="Path to JSON file with LoRA config")

    dataset_group = parser.add_argument_group("dataset")
    dataset_group.add_argument("--data-path", type=str, required=True)
    dataset_group.add_argument("--chat-template", type=str, default="qwen")
    dataset_group.add_argument("--is-preformatted", action="store_true")
    dataset_group.add_argument("--max-length", type=int, default=2048)
    dataset_group.add_argument("--batch-size", type=int, default=1)
    dataset_group.add_argument("--num-workers", type=int, default=8)
    dataset_group.add_argument("--num-samples", type=int, default=None,
                               help="Limit number of samples to evaluate (default: all)")
    dataset_group.add_argument("--build-dataset-num-proc", type=int,
                               default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)))

    misc_group = parser.add_argument_group("misc")
    misc_group.add_argument("--cache-dir", type=str, default="./cache")
    misc_group.add_argument("--log-interval", type=int, default=10)
    misc_group.add_argument("--dist-timeout", type=int, default=30)

    return parser.parse_args()


def build_model(args) -> Tuple[DFlashLoRADraftModel, OnlineDFlashLoRAModel]:
    print_on_rank0(f"Loading base model from {args.model_path}")

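    # CLI arguments provide the defaults; a JSON LoRA config, if given, overrides them.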
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    lora_target_modules = args.lora_target_modules

    if args.lora_config is not None:
        with open(args.lora_config) as f:
            lora_cfg = json.load(f)
        lora_rank = lora_cfg.get("lora_rank", lora_rank)
        lora_alpha = lora_cfg.get("lora_alpha", lora_alpha)
        lora_dropout = lora_cfg.get("lora_dropout", lora_dropout)
        lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules)
        print_on_rank0(f"Loaded LoRA config from {args.lora_config}")

    attn_impl = "flex_attention" if args.attention_backend == "flex_attention" else args.attn_implementation

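    # Instantiate the draft model with freshly initialized LoRA adapters; the trained
    # adapter weights are loaded from the checkpoint below. 151669 is only a fallback
    # mask token id, used if none has been resolved by main().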
    draft_model = DFlashLoRADraftModel.from_pretrained(
        pretrained_model_name_or_path=args.model_path,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
        block_size=args.block_size,
        mask_token_id=args.mask_token_id or 151669,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl,
    )

    print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}")
    from peft import PeftModel

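    # Re-wrap the underlying base model with the trained adapter from ckpt_dir,
    # replacing the freshly initialized LoRA weights.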
    draft_model.model = PeftModel.from_pretrained(
        draft_model.model.base_model.model, args.ckpt_dir
    )

    online_model = OnlineDFlashLoRAModel(
        draft_model=draft_model,
        block_size=args.block_size,
        mask_token_id=args.mask_token_id or 151669,
        loss_decay_gamma=None,
        attention_backend=args.attention_backend,
        lm_head_chunk_size=args.lm_head_chunk_size,
    )

    return draft_model, online_model


def build_dataloader(args, tokenizer):
    import hashlib

    cache_params_string = (
        f"{args.data_path}-{args.max_length}-{args.chat_template}-{args.model_path}"
    )
    cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

    rank = dist.get_rank()

    if os.path.isdir(args.data_path):
        dataset = load_dataset(args.data_path, split="train")
    else:
        dataset = load_dataset("json", data_files=args.data_path)["train"]

    if args.num_samples is not None:
        dataset = dataset.select(range(min(args.num_samples, len(dataset))))
        print_on_rank0(f"Using {len(dataset)} samples for eval")

    dataset_kwargs = dict(
        dataset=dataset,
        tokenizer=tokenizer,
        chat_template=args.chat_template,
        max_length=args.max_length,
        is_preformatted=args.is_preformatted,
        cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
        cache_key=cache_key,
        num_proc=args.build_dataset_num_proc,
    )

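    # Rank 0 builds and caches the processed dataset first; the remaining ranks wait
    # at the barrier and then load it from the shared cache.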
    if rank == 0:
        eval_dataset = build_eagle3_dataset(**dataset_kwargs)
    dist.barrier()
    if rank != 0:
        eval_dataset = build_eagle3_dataset(**dataset_kwargs)

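    # Drop samples whose loss mask covers fewer than two blocks' worth of tokens.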
    min_loss_tokens = 2 * args.block_size
    original_size = len(eval_dataset)
    eval_dataset = eval_dataset.filter(
        lambda x: x["loss_mask"].sum() >= min_loss_tokens
    )
    print_on_rank0(f"Filtered dataset: {original_size} -> {len(eval_dataset)} samples")

    dataloader = prepare_dp_dataloaders(
        eval_dataset,
        args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        process_group=get_dp_group(),
    )
    return dataloader


def main():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    warnings.filterwarnings(
        "ignore",
        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed",
    )

    args = parse_args()

    init_distributed(timeout=args.dist_timeout, tp_size=1)
    print_with_rank("Initialized distributed")

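    # Resolve the mask token id: explicit CLI value first, then the tokenizer's own
    # mask token, otherwise register a new <|MASK|> special token.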
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.mask_token_id is not None:
        mask_token_id = args.mask_token_id
    elif tokenizer.mask_token_id is not None:
        mask_token_id = tokenizer.mask_token_id
    else:
        tokenizer.add_special_tokens({"mask_token": "<|MASK|>"})
        mask_token_id = tokenizer.mask_token_id
    print_on_rank0(f"Using mask_token_id: {mask_token_id}")
    args.mask_token_id = mask_token_id

    draft_model, online_model = build_model(args)
    draft_model.mask_token_id = mask_token_id
    online_model.mask_token_id = mask_token_id

    dataloader = build_dataloader(args, tokenizer)

    draft_model.eval()
    online_model.eval()

    total_acc = 0.0
    total_loss = 0.0
    total_steps = 0

    print_on_rank0(f"Starting eval on {len(dataloader)} batches...")

    with torch.no_grad():
        for step, data in enumerate(dataloader):
            input_ids = data["input_ids"].cuda()
            attention_mask = data["attention_mask"].cuda()
            loss_mask = data["loss_mask"].cuda()

            loss, accuracy = online_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                loss_mask=loss_mask,
                context_len=args.context_len,
            )

            total_acc += accuracy.item()
            total_loss += loss.item()
            total_steps += 1

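            # accepted_length rescales the running accuracy by (block_size - 1),
            # i.e. the positions drafted per block.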
            if (step + 1) % args.log_interval == 0:
                avg_acc = total_acc / total_steps
                avg_accepted_length = avg_acc * (args.block_size - 1)
                print_on_rank0(
                    f"Step {step + 1}/{len(dataloader)} | "
                    f"loss: {total_loss / total_steps:.4f} | "
                    f"acc: {avg_acc:.4f} | "
                    f"accepted_length: {avg_accepted_length:.4f}"
                )

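    # Average the per-rank means across data-parallel ranks: all_reduce sums the
    # values, then we divide by the world size.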
    acc_t = torch.tensor(total_acc / total_steps, device="cuda")
    loss_t = torch.tensor(total_loss / total_steps, device="cuda")
    dist.all_reduce(acc_t)
    dist.all_reduce(loss_t)
    world_size = dist.get_world_size()

    final_acc = acc_t.item() / world_size
    final_loss = loss_t.item() / world_size
    final_accepted_length = final_acc * (args.block_size - 1)

    print_on_rank0(
        f"\n=== Eval Results ===\n"
        f" Loss: {final_loss:.4f}\n"
        f" Accuracy: {final_acc:.4f}\n"
        f" Accepted Length: {final_accepted_length:.4f} / {args.block_size - 1}\n"
        f" Num batches: {total_steps}\n"
    )

    destroy_distributed()


if __name__ == "__main__":
    main()