|
|
""" |
|
|
Maze Video Dataset Generator — generates maze puzzle images and solution videos |
|
|
with checkpoint/resume support, train/test splitting, and JSONL metadata. |
|
|
|
|
|
Includes an ``eval`` subcommand that takes a directory of result videos, |
|
|
extracts the last frame from each, parses the red path, and verifies it |
|
|
against the ground-truth maze text files. |
|
|
|
|
|
Usage: |
|
|
# Generate |
|
|
python maze_video_gen.py generate --output-dir maze --sizes 8 16 32 \ |
|
|
--num-per-size 100 500 1000 --min-path-ratio 0.3 \ |
|
|
--n-start 5 --m-end 5 --frames 50 --fps 10 --seed 42 |
|
|
|
|
|
# Evaluate result videos |
|
|
python maze_video_gen.py eval result_videos/ --text-dir maze/texts |
|
|
|
|
|
# Verify a pre-extracted JSON |
|
|
python maze_video_gen.py verify results.json --text-dir maze/texts |
|
|
""" |
|
|
import json |
|
|
import csv |
|
|
import hashlib |
|
|
import random |
|
|
import re |
|
|
import argparse |
|
|
from dataclasses import dataclass, asdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
from maze_processor import MazeProcessor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GenerationState:
    """Progress snapshot used to checkpoint and resume dataset generation.

    The state is serialised with :meth:`to_dict` and restored with
    :meth:`from_dict`.
    """

    # Hash of the generation parameters this state was produced with.
    params_hash: str
    # Number of mazes already generated, keyed by maze size.
    size_progress: Dict[int, int]
    # Fingerprints of the mazes generated so far.
    seen_fingerprints: List[str]
    # Metadata records of every generated sample.
    all_samples: List[Dict]
    # True once the whole dataset has been written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Return the state as a plain (deep-copied) dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from a dict such as the one returned by ``to_dict``.

        JSON object keys are always strings, so after a save/load round trip
        ``size_progress`` arrives as ``{"8": 3}``. The keys are converted back
        to ``int`` so the field matches its ``Dict[int, int]`` annotation.
        The input dict is not modified.
        """
        fields = dict(d)
        progress = fields.get("size_progress")
        if isinstance(progress, dict):
            fields["size_progress"] = {int(k): v for k, v in progress.items()}
        return cls(**fields)
|
|
|
|
|
|
|
|
def _params_hash(params: Dict) -> str:
    """Return a 12-character hex digest identifying the generation parameters.

    The ``output_dir`` entry is ignored, so the same settings hash identically
    regardless of where the dataset is written. Keys are sorted before
    hashing, which makes the digest independent of dict ordering.
    """
    hashed = dict(params)
    hashed.pop("output_dir", None)
    canonical = json.dumps(hashed, sort_keys=True)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()[:12]
|
|
|
|
|
|
|
|
def load_checkpoint(output_dir: Path, params: Dict) -> Optional["GenerationState"]:
    """Load the checkpoint stored in ``output_dir/metadata.json``.

    Args:
        output_dir: Dataset root directory.
        params: Current generation parameters; their hash must equal the one
            stored in the checkpoint.

    Returns:
        The saved state (completed or resumable), or None when there is no
        checkpoint, the file cannot be read, or it was produced with
        different parameters.
    """
    meta = output_dir / "metadata.json"
    if not meta.exists():
        return None

    try:
        with open(meta, encoding="utf-8") as f:
            data = json.load(f)
        state = GenerationState.from_dict(data["state"])
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # A truncated, hand-edited or differently-shaped file should not abort
        # the run; treat it like a parameter mismatch and start over.
        # (json.JSONDecodeError is a ValueError.)
        print(f"⚠️ Could not read checkpoint {meta} ({exc}), starting fresh")
        return None

    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Parameters changed ({state.params_hash} → {expected}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
    else:
        done = sum(state.size_progress.values())
        print(f"✓ Resuming from checkpoint: {done} mazes generated")
    return state
|
|
|
|
|
|
|
|
def save_checkpoint(output_dir: Path, state: "GenerationState", params: Dict):
    """Atomically write the checkpoint to ``output_dir/metadata.json``.

    The payload ``{"params": ..., "state": ...}`` is first written to
    ``metadata.tmp`` and then moved over the real file, so ``metadata.json``
    is only ever replaced by a fully written file.

    Args:
        output_dir: Dataset root directory (must already exist).
        state: Object exposing ``to_dict()``; its result is stored under "state".
        params: Generation parameters, stored under "params".
    """
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises FileExistsError on Windows once metadata.json exists,
    # which would break every checkpoint after the first one.
    tmp.replace(meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Save a list of RGB frames (e.g. PIL Images) as an mp4 video.

    Args:
        frames: Non-empty sequence of frames convertible with ``np.array``.
            The first frame determines the video width and height.
        path: Destination file path.
        fps: Frames per second of the written video.

    Raises:
        ValueError: If *frames* is empty.
        OSError: If OpenCV could not open the video writer for *path*.
    """
    if len(frames) == 0:
        raise ValueError("save_video_cv2() needs at least one frame")

    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
    )
    try:
        # An unopened writer ignores write() calls, which would silently
        # produce a dataset with missing videos.
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for {path}")
        for frame in frames:
            # Frames are RGB; OpenCV expects BGR.
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer even when a frame fails to convert.
        writer.release()
|
|
|
|
|
|
|
|
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """
    Extract the last frame from a video file as an RGB numpy array.

    The fast path seeks to ``frame_count - 1``. The frame count reported by
    the container is not always usable (it may be 0 or the seek may fail), so
    when the fast path yields nothing the video is reopened and read
    sequentially, keeping the final frame that decodes.

    Returns:
        (H, W, 3) uint8 RGB array, or None on failure.
    """
    source = str(video_path)
    frame = None

    cap = cv2.VideoCapture(source)
    try:
        if not cap.isOpened():
            return None
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
            ok, candidate = cap.read()
            if ok and candidate is not None:
                frame = candidate
    finally:
        cap.release()

    if frame is None:
        # Fallback: a fresh capture guarantees decoding starts at frame 0
        # regardless of where the failed seek left the previous one.
        cap = cv2.VideoCapture(source)
        try:
            while True:
                ok, candidate = cap.read()
                if not ok or candidate is None:
                    break
                frame = candidate
        finally:
            cap.release()

    if frame is None:
        return None
    # OpenCV decodes to BGR; callers expect RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _normalise_list(val, sizes, name="parameter"):
    """Return one value per entry of *sizes*.

    An int is repeated ``len(sizes)`` times. Any other value must be a
    sequence with the same length as *sizes* and is returned as a new list.

    Raises:
        ValueError: If a sequence's length differs from ``len(sizes)``;
            *name* is used in the message.
    """
    if not isinstance(val, int):
        got, want = len(val), len(sizes)
        if got == want:
            return list(val)
        raise ValueError(f"{name} length ({got}) != sizes length ({want})")
    return [val for _ in range(len(sizes))]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_dataset(
    output_dir: str = "maze",
    sizes: Optional[List[int]] = None,
    num_per_size: Optional[list] = None,
    min_path_ratio: float = 0.3,
    img_size: int = 1024,
    prompt: str = "Draw a continuous red line from the yellow dot to the blue dot, avoiding all walls.",
    train_ratio: float = 0.9,
    n_start: int = 5,
    m_end: int = 5,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    checkpoint_interval: int = 50,
):
    """
    Generate maze video dataset with checkpoint/resume support.

    The *frames* parameter controls content frames per video:
    - None → one content frame per path step (variable length)
    - N > 0 → exactly N content frames (slow-mo / fast-fwd as needed)

    Directory layout::

        output_dir/
            images/ — puzzle PNG (no solution line)
            videos/ — solution MP4 (progressive red line)
            texts/ — maze text files (bitmask format)
            train.jsonl / test.jsonl
            train.csv / test.csv
            path.json — UDRL answer key
            metadata.json — checkpoint state

    Args:
        output_dir: Dataset root directory (created if missing).
        sizes: Maze sizes to generate; defaults to ``[8, 16, 32]``.
        num_per_size: Number of mazes per size. Either one count per entry of
            *sizes*, a single-element list, or a plain int (the latter two are
            applied to every size); defaults to ``[100, 500, 1000]``.
        min_path_ratio: Minimum solution length as a fraction of ``size * size``.
        img_size: Passed to ``MazeProcessor(img_size=...)``.
        prompt: Text stored with every sample.
        train_ratio: Fraction of samples written to the train split.
        n_start, m_end, frames: Forwarded to ``generate_video_frames``.
        fps: Frame rate of the written videos.
        seed: Seed for the ``random`` module; ``seed + 1`` seeds the shuffle.
        checkpoint_interval: Number of mazes between checkpoint writes.

    Returns:
        None. Returns early when a matching checkpoint is already completed.
    """
    # Defaults are built per call instead of sharing mutable default lists.
    sizes = [8, 16, 32] if sizes is None else list(sizes)
    if num_per_size is None:
        num_per_size = [100, 500, 1000]

    # num_per_size is stored as given so hashes of existing checkpoints stay valid.
    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed,
    }

    out = Path(output_dir)
    img_dir = out / "images"
    vid_dir = out / "videos"
    txt_dir = out / "texts"
    for d in (img_dir, vid_dir, txt_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A plain int or a one-element list is broadcast to every size.
    # (len() must not be called on an int.)
    if isinstance(num_per_size, int) or len(num_per_size) != 1:
        counts = num_per_size
    else:
        counts = num_per_size[0]
    num_list = _normalise_list(counts, sizes, "num_per_size")

    # Width of the zero-padded sample index in file names.
    num_w = len(str(max(num_list, default=0)))
    proc = MazeProcessor(img_size=img_size)

    random.seed(seed)
    if state is None:
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[],
            all_samples=[],
        )
        print(f"Starting fresh generation: sizes={sizes}, counts={num_list}")
        print(f" frames={'auto (1 per step)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # NOTE(review): advances the RNG by a fixed 10 draws per finished maze.
        # Presumably meant to avoid replaying the mazes generated before the
        # interruption; it does not restore the exact RNG state — confirm intent.
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # Keys may be strings when the state was loaded from JSON.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    def _sync_state():
        """Copy the working progress into *state* before it is saved."""
        state.size_progress = progress
        state.seen_fingerprints = list(seen)
        state.all_samples = all_samples

    total_target = sum(num_list)
    total_done = sum(progress.values())

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="maze") as pbar:
        for maze_size, target in zip(sizes, num_list):
            generated = progress.get(maze_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(maze_size * maze_size * min_path_ratio))
            # Bounded retries so rejected or duplicate mazes cannot loop forever.
            max_attempts = (target - generated) * 20

            with tqdm(
                total=target, initial=generated, desc=f"Size {maze_size:3d}",
                unit="maze", leave=False,
            ) as pbar_sz:
                for _ in range(max_attempts):
                    if generated >= target:
                        break

                    try:
                        grid, start, end, path = proc.generate(
                            maze_size, min_path_len=min_len
                        )
                    except RuntimeError:
                        # This attempt failed; try again.
                        continue

                    # Skip mazes that were already produced.
                    fp = proc.fingerprint(grid, start, end)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{maze_size}_{generated:0{num_w}d}"
                    img_name = f"{base}.png"
                    vid_name = f"{base}.mp4"
                    txt_name = f"{base}.txt"

                    puzzle_img = proc.render(grid, start, end)
                    puzzle_img.save(str(img_dir / img_name))

                    vid_frames = proc.generate_video_frames(
                        grid, start, end, path,
                        n_start=n_start, m_end=m_end, frames=frames,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)

                    proc.save_text(str(txt_dir / txt_name), grid, start, end)

                    udrl = proc.path_to_udrl(path)

                    all_samples.append({
                        "prompt": prompt,
                        "image": img_name,
                        "video": vid_name,
                        "text": txt_name,
                        "maze_size": maze_size,
                        "start": list(start),
                        "end": list(end),
                        "path_udrl": udrl,
                        "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[maze_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        _sync_state()
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            size_samples = sum(1 for s in all_samples if s["maze_size"] == maze_size)
            tqdm.write(f"Size {maze_size}: {generated} mazes, {size_samples} samples")
            if generated < target:
                # Make a shortfall visible instead of finishing with fewer mazes silently.
                tqdm.write(
                    f"⚠️ Size {maze_size}: only {generated}/{target} mazes "
                    f"after {max_attempts} attempts"
                )

    # Answer key, written before shuffling and sorted by image name.
    path_answers = {s["image"]: s["path_udrl"] for s in all_samples}
    with open(out / "path.json", "w", encoding="utf-8") as f:
        json.dump(dict(sorted(path_answers.items())), f, indent=4)

    # Separate seed so the split does not depend on how many draws generation used.
    random.seed(seed + 1)
    random.shuffle(all_samples)
    split = int(len(all_samples) * train_ratio)
    splits = [("train", all_samples[:split]), ("test", all_samples[split:])]

    for name, samples in splits:
        with open(out / f"{name}.jsonl", "w", encoding="utf-8") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    for name, samples in splits:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["input_image", "video", "prompt"])
            for s in samples:
                writer.writerow([
                    f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]
                ])

    _sync_state()
    state.completed = True
    save_checkpoint(out, state, params)

    print(f"\n✓ Dataset complete: {out}/")
    print(f" Sizes: {sizes}")
    print(f" Mazes: {len(all_samples)}")
    print(f" Train: {split}, Test: {len(all_samples) - split}")
    if all_samples:
        # min()/max() raise on an empty sequence, so only report stats when
        # at least one maze was generated.
        lengths = [s["path_length"] for s in all_samples]
        fcounts = [s["frame_count"] for s in all_samples]
        print(f" Path lengths: avg={np.mean(lengths):.1f}, "
              f"min={min(lengths)}, max={max(lengths)}")
        print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
              f"min={min(fcounts)}, max={max(fcounts)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def eval_videos(
    video_dir: str,
    text_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
):
    """
    Evaluate a directory of result videos against ground-truth mazes.

    Pipeline per video:
    1. Extract last frame from .mp4
    2. Detect red path via pixel analysis
    3. Convert to UDRL action string
    4. Verify against maze .txt (wall-respecting walk from start to end)

    Matching convention:
    Video ``<stem>.mp4`` → Text ``<stem>.txt`` in *text_dir*.
    Common stems: ``size8_000``, ``size16_042``, etc.

    Args:
        video_dir: Directory containing result .mp4 files.
        text_dir: Directory containing ground-truth maze .txt files.
        output_json: Path to save extracted paths as JSON (default: video_dir/0_result.json).
        gt_json: Optional ground-truth answer JSON for accuracy by path length.
    """
    proc = MazeProcessor()
    vid_root = Path(video_dir)
    txt_root = Path(text_dir)

    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Natural sort: digit runs compare numerically (size8_2 before size8_10).
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )

    if not videos:
        print(f"No .mp4 files found in {vid_root}")
        return

    print(f"Found {len(videos)} result videos in {vid_root}")
    print(f"Text dir: {txt_root}")

    # --- Pass 1: extract a UDRL string from the last frame of every video ---
    extracted: Dict[str, str] = {}
    missing_txt = 0
    missing_frame = 0

    for vpath in tqdm(videos, desc="Extracting paths"):
        stem = vpath.stem
        txt_path = txt_root / f"{stem}.txt"

        if not txt_path.exists():
            missing_txt += 1
            continue

        maze = proc.load_text(str(txt_path))
        if maze is None:
            missing_txt += 1
            continue

        last_frame = extract_last_frame(str(vpath))
        if last_frame is None:
            missing_frame += 1
            continue

        udrl = proc.extract_path_from_pixels(
            last_frame,
            grid_raw=maze["grid_raw"],
            size=maze["size"],
            start=maze["start"],
        )
        # Keyed by image name, the same key format used by path.json.
        extracted[f"{stem}.png"] = udrl

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(extracted, f, indent=4)
    print(f"\nExtracted paths saved to: {output_json}")

    # --- Pass 2: verify each extracted path once, collecting the overall and
    # per-size statistics together ---
    correct = 0
    total_valid = 0
    correctly_solved: List[Dict] = []
    size_stats: Dict[int, Dict[str, int]] = {}
    suffix_len = len(".png")

    for name, udrl in extracted.items():
        # Every key was built as f"{stem}.png"; strip only that suffix so a
        # stem that itself contains ".png" still maps to the right .txt file.
        stem = name[:-suffix_len]
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze is None:
            continue

        total_valid += 1
        per_size = size_stats.setdefault(maze["size"], {"total": 0, "correct": 0})
        per_size["total"] += 1

        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl):
            correct += 1
            per_size["correct"] += 1
            correctly_solved.append({"name": name, "length": len(udrl)})

    acc = (correct / total_valid * 100) if total_valid else 0

    print(f"\n{'=' * 50}")
    print("Evaluation Summary")
    print(f"{'=' * 50}")
    print(f"Total Videos : {len(videos)}")
    print(f"Missing .txt : {missing_txt}")
    print(f"Failed Frame Read : {missing_frame}")
    print(f"Evaluated : {total_valid}")
    print(f"Correctly Solved : {correct}")
    print(f"Accuracy : {acc:.2f}%")
    print(f"{'-' * 50}")

    if size_stats:
        print("\nAccuracy by maze size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            sz_acc = s["correct"] / s["total"] * 100 if s["total"] else 0
            print(f" Size {sz:3d}: {s['correct']:4d}/{s['total']:4d} ({sz_acc:.2f}%)")

    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    if correctly_solved:
        print("\nTop 3 Longest Correct Paths:")
        for i, item in enumerate(correctly_solved[:3]):
            print(f" {i+1}. {item['name']} (length: {item['length']})")

    if gt_json:
        _compare_with_gt(extracted, gt_json, txt_root, proc)

    print(f"{'=' * 50}")
|
|
|
|
|
|
|
|
def _compare_with_gt(
    extracted: Dict[str, str],
    gt_json_path: str,
    txt_root: Path,
    proc: "MazeProcessor",
):
    """Print accuracy binned by ground-truth path length.

    Predictions are grouped into bins of width 10 (0-9, 10-19, ...) according
    to the length of the ground-truth UDRL string. A prediction counts as
    correct when ``proc.verify_path`` accepts it for the maze loaded from
    ``txt_root/<stem>.txt``.

    Args:
        extracted: Mapping ``"<stem>.png"`` → predicted UDRL string.
        gt_json_path: JSON file mapping the same names to ground-truth UDRL strings.
        txt_root: Directory containing the maze .txt files.
        proc: Object providing ``load_text`` and ``verify_path``.
    """
    try:
        with open(gt_json_path, encoding="utf-8") as f:
            gt = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or invalid JSON: this report is optional, so
        # warn and carry on. (json.JSONDecodeError is a ValueError.)
        print(f" Warning: could not load ground-truth JSON: {gt_json_path}")
        return

    # Keyed by the integer lower bound so bins sort numerically; sorting the
    # formatted labels would place "1000-1009" before "990-999".
    bins: Dict[int, Dict[str, int]] = {}
    for name, pred_udrl in extracted.items():
        if name not in gt:
            continue

        lo = (len(gt[name]) // 10) * 10
        bucket = bins.setdefault(lo, {"total": 0, "correct": 0})
        bucket["total"] += 1

        # Strip only a trailing ".png"; str.replace would also remove
        # occurrences inside the stem.
        stem = name[:-len(".png")] if name.endswith(".png") else name
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze and proc.verify_path(maze["grid"], maze["start"], maze["end"], pred_udrl):
            bucket["correct"] += 1

    if bins:
        print("\nAccuracy by GT path length:")
        for lo in sorted(bins):
            b = bins[lo]
            label = f"{lo:3d}-{lo + 9:3d}"
            b_acc = b["correct"] / b["total"] * 100 if b["total"] else 0
            print(f" Length {label}: {b['correct']:4d}/{b['total']:4d} ({b_acc:.2f}%)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_results(json_file: str, text_dir: str):
    """
    Verify pre-extracted UDRL paths (from a JSON file) against maze .txt files.

    Entries whose maze file is missing or cannot be loaded are counted as
    skipped and excluded from the accuracy.

    Args:
        json_file: Path to JSON with {name: udrl_string} predictions.
        text_dir: Directory containing maze .txt files.
    """
    proc = MazeProcessor()
    json_path = Path(json_file)
    txt_root = Path(text_dir)

    with open(json_path, encoding="utf-8") as f:
        solutions = json.load(f)

    correct = skipped = valid = 0

    for name, udrl in solutions.items():
        # Strip only a trailing ".png"; str.replace would also remove
        # occurrences inside the stem and point at the wrong .txt file.
        stem = name[:-len(".png")] if name.endswith(".png") else name
        txt_path = txt_root / f"{stem}.txt"

        # Same guard as eval_videos: only load maze files that exist.
        maze = proc.load_text(str(txt_path)) if txt_path.exists() else None
        if maze is None:
            skipped += 1
            continue

        valid += 1
        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl):
            correct += 1

    acc = (correct / valid * 100) if valid else 0
    print(f"\n{'='*40}")
    print(f"Verification: {correct}/{valid} correct ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(f"{'='*40}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command line for the ``generate``, ``eval`` and ``verify`` sub-commands.

    Returns:
        The ``argparse.Namespace`` for ``sys.argv``. Its ``command`` attribute
        names the selected sub-command and is None when none was given.
    """
    parser = argparse.ArgumentParser(
        description="Maze video dataset: generate, eval, verify"
    )
    commands = parser.add_subparsers(dest="command", help="Sub-command")

    def register(target, options):
        # Add the arguments in table order so --help lists them the same way.
        for flag, spec in options:
            target.add_argument(flag, **spec)

    # generate: build the dataset
    register(
        commands.add_parser("generate", help="Generate dataset"),
        [
            ("--output-dir", {"type": str, "default": "maze"}),
            ("--sizes", {"type": int, "nargs": "+", "default": [8, 16, 24, 32]}),
            ("--num-per-size", {"type": int, "nargs": "+",
                                "default": [100, 500, 1000, 2000]}),
            ("--min-path-ratio", {"type": float, "default": 0.3,
                                  "help": "Min path length as fraction of size²"}),
            ("--img-size", {"type": int, "default": 1024}),
            ("--prompt", {
                "type": str,
                "default": "Draw a continuous red line from the yellow dot to the blue dot, avoiding all walls.",
            }),
            ("--train-ratio", {"type": float, "default": 0.9}),
            ("--n-start", {"type": int, "default": 2,
                           "help": "Hold frames at video start (blank puzzle)"}),
            ("--m-end", {"type": int, "default": 3,
                         "help": "Hold frames at video end (completed solution)"}),
            ("--frames", {"type": int, "default": None,
                          "help": "Content frames per video (None=auto 1 per step)"}),
            ("--fps", {"type": int, "default": 10}),
            ("--seed", {"type": int, "default": 42}),
            ("--checkpoint-interval", {"type": int, "default": 50}),
        ],
    )

    # eval: score a directory of result videos
    register(
        commands.add_parser(
            "eval", help="Evaluate result videos (last frame → extract → verify)"
        ),
        [
            ("video_dir", {"type": str,
                           "help": "Directory containing result .mp4 files"}),
            ("--text-dir", {"type": str, "required": True,
                            "help": "Directory with ground-truth maze .txt files"}),
            ("--output-json", {
                "type": str, "default": None,
                "help": "Output JSON for extracted paths (default: video_dir/0_result.json)",
            }),
            ("--gt-json", {
                "type": str, "default": None,
                "help": "Optional ground-truth path.json for length-binned accuracy",
            }),
        ],
    )

    # verify: check an already extracted JSON of paths
    register(
        commands.add_parser("verify", help="Verify a pre-extracted JSON of UDRL paths"),
        [
            ("json_file", {"type": str}),
            ("--text-dir", {"type": str, "required": True,
                            "help": "Directory with maze .txt files"}),
        ],
    )

    return parser.parse_args()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: dispatch to the function matching the sub-command.
    args = parse_args()
    command = args.command

    if command == "generate":
        # Every parsed option except the sub-command name maps onto a
        # generate_dataset() keyword argument.
        options = dict(vars(args))
        options.pop("command", None)
        generate_dataset(**options)
    elif command == "eval":
        eval_videos(
            video_dir=args.video_dir,
            text_dir=args.text_dir,
            output_json=args.output_json,
            gt_json=args.gt_json,
        )
    elif command == "verify":
        verify_results(args.json_file, args.text_dir)
    else:
        # No sub-command given: print a short usage reminder.
        for usage_line in (
            "Usage: python maze_video_gen.py {generate|eval|verify} [options]",
            " python maze_video_gen.py generate --help",
            " python maze_video_gen.py eval --help",
            " python maze_video_gen.py verify --help",
        ):
            print(usage_line)