"""Sequence Prediction Evaluation with QwenImageEditPlusPipeline / Flux2KleinPipeline.

Evaluates the model's ability to predict the next number in a sequence
by generating images and extracting answers via OCR.
"""

import json
import re
from pathlib import Path
from dataclasses import dataclass, field
from enum import Enum

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm


class ModelType(str, Enum):
    """Supported image-editing model families."""

    QWEN_IMAGE_EDIT = "qwen"
    FLUX2_KLEIN = "flux2-klein"


@dataclass
class EvalConfig:
    """Evaluation configuration."""

    # Dataset and output locations.
    dataset_dir: str = "sequence_dataset"
    output_dir: str = "eval_results"

    # Model selection; an empty model_id is resolved in __post_init__.
    model_type: ModelType = ModelType.QWEN_IMAGE_EDIT
    model_id: str = ""

    # Prompts.
    prompt: str = (
        "Based on the number patterns shown in the previous images, "
        "fill in the missing number in the empty cell of the last image."
    )
    negative_prompt: str = ""

    # Generation parameters. Note the CLI in main() defaults to 50 steps.
    num_inference_steps: int = 5
    guidance_scale: float = 1.0  # used by Flux2KleinPipeline
    true_cfg_scale: float = 4.0  # used by QwenImageEditPlusPipeline
    height: int = 210
    width: int = 750

    # Runtime settings.
    seed: int = 42
    device: str = "cuda"
    dtype: torch.dtype = field(default_factory=lambda: torch.bfloat16)

    def __post_init__(self):
        """Set a default model_id based on model_type."""
        if not self.model_id:
            if self.model_type == ModelType.QWEN_IMAGE_EDIT:
                self.model_id = "Qwen/Qwen-Image-Edit-2509"
            elif self.model_type == ModelType.FLUX2_KLEIN:
                self.model_id = "black-forest-labs/FLUX.2-klein-9B"
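
# A minimal usage sketch (illustrative): __post_init__ fills in the default
# checkpoint for the chosen model family, so this resolves model_id to
# "Qwen/Qwen-Image-Edit-2509":
#
#   config = EvalConfig(model_type=ModelType.QWEN_IMAGE_EDIT,
#                       num_inference_steps=50)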


class OCRExtractor:
    """Extract numbers from grid images using OCR."""

    def __init__(self, backend: str = "easyocr"):
        """
        Args:
            backend: OCR backend ("easyocr" or "pytesseract").
        """
        self.backend = backend
        if backend == "easyocr":
            # Imported lazily so each OCR dependency stays optional.
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
        elif backend == "pytesseract":
            import pytesseract
            self.pytesseract = pytesseract
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def extract_last_number(self, image: Image.Image) -> int | None:
        """
        Extract the last (rightmost) number from a grid image.

        Args:
            image: PIL Image of the number grid.

        Returns:
            Extracted number, or None if extraction fails.
        """
        # The answer cell is the rightmost quarter of the strip, matching the
        # four-cell layout assumed by extract_all_numbers below.
        w, h = image.size
        cell_crop = image.crop((w * 3 // 4, 0, w, h))
        cell_array = np.array(cell_crop)

        if self.backend == "easyocr":
            results = self.reader.readtext(cell_array)
            for _, text, _conf in results:
                digits = re.findall(r'-?\d+', text)
                if digits:
                    return int(digits[0])

        elif self.backend == "pytesseract":
            # --psm 7 treats the crop as a single text line; the whitelist
            # restricts output to digits and the minus sign.
            text = self.pytesseract.image_to_string(
                cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
            )
            digits = re.findall(r'-?\d+', text)
            if digits:
                return int(digits[0])

        return None

    def extract_all_numbers(self, image: Image.Image, num_cells: int = 4) -> list[int | None]:
        """Extract all numbers from a grid image."""
        w, h = image.size
        cell_width = w // num_cells
        numbers = []

        for i in range(num_cells):
            cell_crop = image.crop((i * cell_width, 0, (i + 1) * cell_width, h))
            cell_array = np.array(cell_crop)

            if self.backend == "easyocr":
                results = self.reader.readtext(cell_array)
                num = None
                for _, text, _conf in results:
                    digits = re.findall(r'-?\d+', text)
                    if digits:
                        num = int(digits[0])
                        break
                numbers.append(num)

            elif self.backend == "pytesseract":
                text = self.pytesseract.image_to_string(
                    cell_crop, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
                )
                digits = re.findall(r'-?\d+', text)
                numbers.append(int(digits[0]) if digits else None)

        return numbers
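
    # Illustrative usage (requires the easyocr package; file name hypothetical):
    #
    #   ocr = OCRExtractor(backend="easyocr")
    #   ocr.extract_all_numbers(Image.open("grid.png"), num_cells=4)
    #   # -> e.g. [2, 4, 6, 8], with None for any cell that fails OCR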


class SequenceEvaluator:
    """Evaluator for the sequence prediction task."""

    def __init__(self, config: EvalConfig):
        self.config = config
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load the diffusion pipeline once; it is reused for every sample.
        self.pipeline = self._load_pipeline()

        # OCR reads the model's answer back out of the generated image.
        self.ocr = OCRExtractor(backend="easyocr")

    def _load_pipeline(self):
        """Load pipeline based on model type."""
        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            return self._load_qwen_pipeline()
        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            return self._load_flux2_klein_pipeline()
        else:
            raise ValueError(f"Unknown model type: {self.config.model_type}")

    def _load_qwen_pipeline(self):
        """Load QwenImageEditPlusPipeline."""
        from diffusers import QwenImageEditPlusPipeline

        pipeline = QwenImageEditPlusPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        pipeline.to(self.config.device)
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_flux2_klein_pipeline(self):
        """Load Flux2KleinPipeline."""
        from diffusers import Flux2KleinPipeline

        pipeline = Flux2KleinPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        # Unlike the Qwen loader, this offloads submodules to CPU between
        # forward passes to reduce peak VRAM usage.
        pipeline.enable_model_cpu_offload()
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_images(self, image_paths: list[str], image_dir: Path) -> list[Image.Image]:
        """Load images from paths relative to image_dir."""
        return [Image.open(image_dir / p).convert("RGB") for p in image_paths]

    def predict(self, images: list[Image.Image]) -> Image.Image:
        """
        Generate prediction image given input images.

        Args:
            images: List of input images (context + query).

        Returns:
            Generated image with predicted number.
        """
        generator = torch.Generator(device=self.config.device).manual_seed(self.config.seed)

        # The two pipelines take different guidance arguments: QwenImageEditPlus
        # uses true CFG with a negative prompt, while Flux2Klein takes a plain
        # guidance_scale plus explicit output dimensions.
        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "true_cfg_scale": self.config.true_cfg_scale,
                "negative_prompt": self.config.negative_prompt,
                "num_inference_steps": self.config.num_inference_steps,
            }
        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "guidance_scale": self.config.guidance_scale,
                "num_inference_steps": self.config.num_inference_steps,
                "height": self.config.height,
                "width": self.config.width,
            }
        else:
            raise ValueError(f"Unknown model type: {self.config.model_type}")

        with torch.inference_mode():
            output = self.pipeline(**inputs)

        return output.images[0]
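
    # Illustrative one-off use (image paths hypothetical):
    #
    #   ev = SequenceEvaluator(EvalConfig())
    #   img = ev.predict([Image.open("ctx0.png"), Image.open("query.png")])
    #   ev.ocr.extract_last_number(img)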

    def evaluate_sample(self, sample: dict, image_dir: Path) -> dict:
        """
        Evaluate a single sample.

        Args:
            sample: Sample metadata dict.
            image_dir: Directory containing images.

        Returns:
            Evaluation result dict.
        """
        images = self._load_images(sample["images"], image_dir)

        pred_image = self.predict(images)

        # Save the generated image for inspection.
        pred_path = self.output_dir / f"{sample['id']:05d}_pred.png"
        pred_image.save(pred_path)

        # Read the predicted number out of the answer cell and compare it to
        # the ground truth.
        pred_number = self.ocr.extract_last_number(pred_image)
        gt_number = sample["answer"]
        correct = pred_number == gt_number

        return {
            "id": sample["id"],
            "seq_type": sample["seq_type"],
            "gt_answer": gt_number,
            "pred_answer": pred_number,
            "correct": correct,
            "pred_image": str(pred_path),
        }
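
    # Expected sample schema, inferred from the fields read above (values
    # are illustrative):
    #
    #   {"id": 0, "seq_type": "arithmetic", "answer": 8,
    #    "images": ["ctx_0.png", "ctx_1.png", "query.png"]}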

    def evaluate(self, split: str = "test") -> dict:
        """
        Evaluate on an entire dataset split.

        Args:
            split: Dataset split ("train" or "test").

        Returns:
            Evaluation results summary.
        """
        dataset_dir = Path(self.config.dataset_dir)

        with open(dataset_dir / f"{split}.json") as f:
            samples = json.load(f)

        image_dir = dataset_dir / split / "images"

        results = []
        for sample in tqdm(samples, desc=f"Evaluating {split}"):
            result = self.evaluate_sample(sample, image_dir)
            results.append(result)

        # Overall accuracy.
        total = len(results)
        correct = sum(r["correct"] for r in results)
        accuracy = correct / total if total > 0 else 0.0

        # Per-sequence-type accuracy.
        type_stats = {}
        for r in results:
            seq_type = r["seq_type"]
            if seq_type not in type_stats:
                type_stats[seq_type] = {"correct": 0, "total": 0}
            type_stats[seq_type]["total"] += 1
            if r["correct"]:
                type_stats[seq_type]["correct"] += 1

        type_accuracy = {
            k: v["correct"] / v["total"] for k, v in type_stats.items()
        }

        summary = {
            "split": split,
            "model_type": self.config.model_type.value,
            "model_id": self.config.model_id,
            "total": total,
            "correct": correct,
            "accuracy": accuracy,
            "type_accuracy": type_accuracy,
            "results": results,
        }

        with open(self.output_dir / f"{split}_results.json", "w") as f:
            json.dump(summary, f, indent=2)

        return summary


def main():
    """Run evaluation."""
    import argparse

    parser = argparse.ArgumentParser(description="Sequence Prediction Evaluation")
    parser.add_argument("--model", type=str, default="flux2-klein",
                        choices=["qwen", "flux2-klein"],
                        help="Model type to use")
    parser.add_argument("--model-id", type=str, default="",
                        help="Custom model ID (optional)")
    parser.add_argument("--dataset-dir", type=str, default="sequence_dataset",
                        help="Dataset directory")
    parser.add_argument("--output-dir", type=str, default="eval_results",
                        help="Output directory")
    parser.add_argument("--steps", type=int, default=50,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    config = EvalConfig(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_type=ModelType(args.model),
        model_id=args.model_id,
        num_inference_steps=args.steps,
        seed=args.seed,
    )

    print(f"Model: {config.model_type.value} ({config.model_id})")

    evaluator = SequenceEvaluator(config)
    results = evaluator.evaluate("test")

    print(f"\n{'='*50}")
    print(f"Evaluation Results ({config.model_type.value})")
    print(f"{'='*50}")
    print(f"Total samples: {results['total']}")
    print(f"Correct: {results['correct']}")
    print(f"Accuracy: {results['accuracy']:.2%}")
    print("\nPer-type accuracy:")
    for seq_type, acc in sorted(results["type_accuracy"].items()):
        print(f"  {seq_type}: {acc:.2%}")


if __name__ == "__main__":
    main()