|
|
""" |
|
|
Sequence Prediction Dataset Generator.

Generates image pairs for sequence prediction tasks with various
mathematical sequences (arithmetic, geometric, fibonacci, etc.).
|
|
""" |
|
|
|
|
|
import json |
|
|
import random |
|
|
from pathlib import Path |
|
|
from typing import Callable |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as patches |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def arithmetic_seq(start: int, diff: int, length: int = 4) -> list[int]:
    """Build an arithmetic progression.

    Args:
        start: First term of the progression.
        diff: Constant difference between consecutive terms.
        length: Number of terms to produce.

    Returns:
        The terms ``start, start + diff, start + 2 * diff, ...``.
    """
    terms = []
    for step in range(length):
        terms.append(start + step * diff)
    return terms
|
|
|
|
|
|
|
|
def geometric_seq(start: int, ratio: int, length: int = 4) -> list[int]:
    """Build a geometric progression.

    Args:
        start: First term of the progression.
        ratio: Constant factor between consecutive terms.
        length: Number of terms to produce.

    Returns:
        The terms ``start, start * ratio, start * ratio**2, ...``.
    """
    powers = (ratio ** exponent for exponent in range(length))
    return [start * power for power in powers]
|
|
|
|
|
|
|
|
def square_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive perfect squares.

    Args:
        start: Base of the first square.
        length: Number of terms to produce.

    Returns:
        ``start**2, (start + 1)**2, ...`` with ``length`` terms.
    """
    bases = range(start, start + length)
    return [base * base for base in bases]
|
|
|
|
|
|
|
|
def cube_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive perfect cubes.

    Args:
        start: Base of the first cube.
        length: Number of terms to produce.

    Returns:
        ``start**3, (start + 1)**3, ...`` with ``length`` terms.
    """
    cubes = []
    for offset in range(length):
        base = start + offset
        cubes.append(base ** 3)
    return cubes
|
|
|
|
|
|
|
|
def triangular_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive triangular numbers ``n * (n + 1) / 2``.

    Args:
        start: Value of ``n`` for the first term.
        length: Number of terms to produce.

    Returns:
        The triangular numbers for ``n = start, start + 1, ...``.
    """

    def nth_triangular(offset: int) -> int:
        n = start + offset
        return n * (n + 1) // 2

    return list(map(nth_triangular, range(length)))
|
|
|
|
|
|
|
|
def fibonacci_like_seq(a: int, b: int, length: int = 4) -> list[int]:
    """Build a Fibonacci-style sequence from two seed values.

    Every term after the first two is the sum of its two predecessors:
    ``a, b, a + b, a + 2b, ...``.

    Args:
        a: First seed term.
        b: Second seed term.
        length: Number of terms to return.

    Returns:
        The first ``length`` terms; the two-seed list is sliced, so a
        ``length`` below 2 returns only a prefix of ``[a, b]``.
    """
    terms = [a, b]
    while len(terms) < length:
        terms.append(terms[-1] + terms[-2])
    # Slice so that requests shorter than the two seeds are honoured too.
    return terms[:length]
|
|
|
|
|
|
|
|
def prime_seq(start_idx: int, length: int = 4) -> list[int]:
    """Return ``length`` consecutive primes starting at a 0-based index.

    Index 0 is the prime 2, index 1 is 3, and so on. Primes are generated on
    demand, so any non-negative ``start_idx`` yields a full-length result
    (a fixed lookup table would silently truncate near its end).

    Args:
        start_idx: 0-based position of the first prime to return.
        length: Number of primes to return.

    Returns:
        A list of ``length`` consecutive primes, or an empty list when
        ``length`` is not positive.

    Raises:
        ValueError: If ``start_idx`` is negative.
    """
    if start_idx < 0:
        raise ValueError(f"start_idx must be non-negative, got {start_idx}")
    if length <= 0:
        return []

    needed = start_idx + length
    primes: list[int] = []
    candidate = 2
    while len(primes) < needed:
        # Trial division by known primes; stop once p*p exceeds the candidate
        # because any composite has a prime factor no larger than its root.
        for p in primes:
            if p * p > candidate:
                primes.append(candidate)
                break
            if candidate % p == 0:
                break
        else:
            # Loop exhausted without finding a divisor (covers candidate == 2).
            primes.append(candidate)
        candidate += 1
    return primes[start_idx:]
|
|
|
|
|
|
|
|
def power_of_two_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive powers of two.

    Args:
        start: Exponent of the first term.
        length: Number of terms to produce.

    Returns:
        ``2**start, 2**(start + 1), ...`` with ``length`` terms.
    """
    powers = []
    for offset in range(length):
        exponent = start + offset
        powers.append(2 ** exponent)
    return powers
|
|
|
|
|
|
|
|
def factorial_seq(start: int, length: int = 4) -> list[int]:
    """Return consecutive factorials.

    Args:
        start: Argument of the first factorial.
        length: Number of terms to produce.

    Returns:
        ``start!, (start + 1)!, ...`` with ``length`` terms.
    """
    import math

    results = []
    for offset in range(length):
        results.append(math.factorial(start + offset))
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry of sequence families. Each entry maps a type name to a factory that
# takes a ``random.Random`` and returns a freshly parameterised 4-term
# sequence (the helpers' default length). Insertion order matters: consumers
# that pick a type with ``rng.choice`` over the keys depend on it.
SEQUENCE_TYPES: dict[str, Callable[[random.Random], list[int]]] = {
    # Increasing arithmetic progression: start in [1, 20], step in [1, 10].
    "arithmetic": lambda rng: arithmetic_seq(
        start=rng.randint(1, 20),
        diff=rng.randint(1, 10),
    ),
    # Decreasing arithmetic progression: start in [20, 50], step in [-5, -1].
    "arithmetic_neg": lambda rng: arithmetic_seq(
        start=rng.randint(20, 50),
        diff=-rng.randint(1, 5),
    ),
    # Doubling sequence with start in [1, 5].
    "geometric_2": lambda rng: geometric_seq(start=rng.randint(1, 5), ratio=2),
    # Tripling sequence with start in [1, 3].
    "geometric_3": lambda rng: geometric_seq(start=rng.randint(1, 3), ratio=3),
    "square": lambda rng: square_seq(start=rng.randint(1, 10)),
    "cube": lambda rng: cube_seq(start=rng.randint(1, 5)),
    "triangular": lambda rng: triangular_seq(start=rng.randint(1, 10)),
    # Fibonacci-like sequence seeded with two independent draws from [1, 5].
    "fibonacci": lambda rng: fibonacci_like_seq(
        a=rng.randint(1, 5),
        b=rng.randint(1, 5),
    ),
    # Four consecutive primes starting at a 0-based index in [0, 10].
    "prime": lambda rng: prime_seq(start_idx=rng.randint(0, 10)),
    "power_of_2": lambda rng: power_of_two_seq(start=rng.randint(0, 6)),
}
|
|
|
|
|
|
|
|
def generate_sequence_pair(seq: list[int]) -> tuple[list, list]:
    """
    Split a sequence into a (question, solution) pair.

    Args:
        seq: The full sequence.

    Returns:
        (partial, complete): ``partial`` is a copy of ``seq`` whose last
        element is replaced by ``""``; ``complete`` is ``seq`` itself.
    """
    # Slicing copies, so the in-place extension below never touches ``seq``.
    masked = seq[:-1]
    masked += [""]
    return masked, seq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def round_to_multiple(x: int, multiple: int = 16) -> int:
    """Round ``x`` up to the nearest multiple of ``multiple``.

    Args:
        x: Value to round.
        multiple: Step size; values already on a multiple are unchanged.

    Returns:
        The smallest multiple of ``multiple`` that is >= ``x``.
    """
    # Bump x just short of the next multiple, then drop the remainder.
    padded = x + multiple - 1
    return padded - padded % multiple
|
|
|
|
|
|
|
|
def create_number_grid(
    numbers: list,
    save_path: str | Path,
    height: int = 224,
    width: int = 896,
    fontsize: int = 48,
    size_multiple: int = 16,
) -> None:
    """
    Create a 1xN grid image with one value drawn in each cell.

    The figure is rendered with matplotlib and written to ``save_path``; the
    saved file is then re-opened and, if its pixel size differs from the
    rounded target size, resized and written back to the same path.

    Args:
        numbers: List of numbers/strings to display (each is passed to ``str``).
        save_path: Output file path.
        height: Target height in pixels (will be rounded up to size_multiple).
        width: Target width in pixels (will be rounded up to size_multiple).
        fontsize: Font size for the numbers.
        size_multiple: Ensure dimensions are multiples of this (default 16).
    """
    from PIL import Image

    n = len(numbers)

    width = round_to_multiple(width, size_multiple)
    height = round_to_multiple(height, size_multiple)

    # figsize is expressed in inches, so pixels / dpi gives the figure size.
    dpi = 100
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        # Let the axes fill the whole canvas (no margins).
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

        for i, num in enumerate(numbers):
            # Unit-square cell at column i with the value centred inside it.
            rect = patches.Rectangle(
                (i, 0), 1, 1, linewidth=2,
                edgecolor='black', facecolor='white'
            )
            ax.add_patch(rect)
            ax.text(
                i + 0.5, 0.5, str(num), fontsize=fontsize,
                ha='center', va='center', fontweight='bold'
            )

        ax.set_xlim(0, n)
        ax.set_ylim(0, 1)
        ax.set_aspect('equal')
        ax.axis('off')

        fig.savefig(save_path, dpi=dpi, facecolor='white', edgecolor='none')
    finally:
        # Always release the figure, even if drawing or saving failed;
        # pyplot keeps figures alive until they are explicitly closed.
        plt.close(fig)

    # Enforce the exact pixel size. The source file handle is closed before
    # the same path is overwritten, so no open handle lingers on the target.
    resized = None
    with Image.open(save_path) as img:
        if img.size != (width, height):
            resized = img.resize((width, height), Image.Resampling.LANCZOS)
    if resized is not None:
        resized.save(save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequenceDatasetGenerator:
    """Generate sequence prediction dataset with train/test splits.

    Each sample holds several (partial, complete) sequence pairs of a single
    sequence type. Every sequence is rendered as a PNG grid under
    ``<output_dir>/<split>/images`` and per-split metadata is written to
    ``<output_dir>/train.json`` and ``<output_dir>/test.json``.
    """

    def __init__(
        self,
        output_dir: str,
        seed: int = 42,
        num_pairs: tuple[int, int] = (2, 3),
        seq_types: list[str] | None = None,
        image_height: int = 224,
        image_width: int = 896,
        fontsize: int = 48,
    ):
        """
        Args:
            output_dir: Directory to save the dataset.
            seed: Random seed for reproducibility.
            num_pairs: Range of pairs per sample (min, max inclusive).
            seq_types: List of sequence types to use (None or empty = all).
            image_height: Image height in pixels (rounded up to 16).
            image_width: Image width in pixels (rounded up to 16).
            fontsize: Font size for numbers.

        Raises:
            ValueError: If ``num_pairs`` is not ``1 <= min <= max`` or if
                ``seq_types`` names a type missing from ``SEQUENCE_TYPES``.
        """
        # Fail fast on bad configuration instead of crashing mid-generation
        # (a sample with zero pairs has no final answer to report).
        min_pairs, max_pairs = num_pairs
        if min_pairs < 1 or max_pairs < min_pairs:
            raise ValueError(
                f"num_pairs must satisfy 1 <= min <= max, got {num_pairs!r}"
            )

        selected_types = seq_types or list(SEQUENCE_TYPES.keys())
        unknown = [name for name in selected_types if name not in SEQUENCE_TYPES]
        if unknown:
            raise ValueError(
                f"Unknown sequence types {unknown}; "
                f"valid types are {list(SEQUENCE_TYPES)}"
            )

        self.output_dir = Path(output_dir)
        self.rng = random.Random(seed)
        self.num_pairs = num_pairs
        self.seq_types = selected_types
        self.image_height = round_to_multiple(image_height, 16)
        self.image_width = round_to_multiple(image_width, 16)
        self.fontsize = fontsize

        # Create the output layout up front: <output_dir>/<split>/images.
        for split in ["train", "test"]:
            (self.output_dir / split / "images").mkdir(parents=True, exist_ok=True)

    def _generate_sample(self, sample_id: int) -> dict:
        """Generate a single sample with multiple sequence pairs.

        Args:
            sample_id: Identifier stored in the sample's ``"id"`` field.

        Returns:
            A dict with keys ``"id"``, ``"seq_type"``, ``"num_pairs"`` and
            ``"pairs"``; each pair has ``"partial"``, ``"complete"`` and
            ``"answer"`` (the last element of the complete sequence).
        """
        # The draw order (pair count, type, base sequence) is kept fixed so
        # that a given seed always reproduces the same dataset.
        num_pairs = self.rng.randint(*self.num_pairs)
        seq_type = self.rng.choice(self.seq_types)
        base_seq = SEQUENCE_TYPES[seq_type](self.rng)

        # Derive one variation of the base sequence per pair; variation 0 is
        # the base sequence itself. The step/ratio is computed once, not per
        # pair.
        if seq_type.startswith("arithmetic"):
            # Shift the progression forward by i steps (same difference).
            diff = base_seq[1] - base_seq[0]
            variations = [
                [x + i * diff for x in base_seq] for i in range(num_pairs)
            ]
        elif seq_type.startswith("geometric"):
            # Scale by ratio**i, which keeps the same common ratio.
            ratio = base_seq[1] // base_seq[0] if base_seq[0] != 0 else 2
            variations = [
                [x * (ratio ** i) for x in base_seq] for i in range(num_pairs)
            ]
        else:
            # NOTE(review): other types are offset by a constant i, so later
            # pairs are shifted copies (e.g. squares + 1) rather than members
            # of the named family -- confirm this is intended.
            variations = [[x + i for x in base_seq] for i in range(num_pairs)]

        pairs = []
        for seq in variations:
            partial, complete = generate_sequence_pair(seq)
            pairs.append({
                "partial": partial,
                "complete": complete,
                "answer": complete[-1],
            })

        return {
            "id": sample_id,
            "seq_type": seq_type,
            "num_pairs": num_pairs,
            "pairs": pairs,
        }

    def _render_grid(self, numbers: list, image_dir: Path, filename: str) -> None:
        """Render one sequence to ``image_dir / filename`` at the configured size."""
        create_number_grid(
            numbers, image_dir / filename,
            height=self.image_height, width=self.image_width,
            fontsize=self.fontsize,
        )

    def _save_sample_images(
        self, sample: dict, split: str, include_last_answer: bool = True
    ) -> dict:
        """Save images for a sample and return metadata.

        Images are named ``<id:05d>_<index>.png``, with the index counting the
        images of the sample in order (partial first, then complete, for each
        pair).

        Args:
            sample: Sample dict as produced by ``_generate_sample``.
            split: Split sub-directory name (``"train"`` or ``"test"``).
            include_last_answer: When False, the complete image of the final
                pair is not rendered, leaving it as the question to answer.

        Returns:
            Metadata dict with ``"id"``, ``"seq_type"``, ``"num_pairs"``,
            ``"images"`` (file names in order), ``"answer"`` (answer of the
            last pair) and ``"sequences"`` (all complete sequences).
        """
        sample_id = sample["id"]
        image_dir = self.output_dir / split / "images"
        last_index = sample["num_pairs"] - 1

        images: list[str] = []
        for i, pair in enumerate(sample["pairs"]):
            to_render = [pair["partial"]]
            if include_last_answer or i != last_index:
                to_render.append(pair["complete"])

            for numbers in to_render:
                # The running image count doubles as the per-sample index.
                filename = f"{sample_id:05d}_{len(images)}.png"
                self._render_grid(numbers, image_dir, filename)
                images.append(filename)

        return {
            "id": sample_id,
            "seq_type": sample["seq_type"],
            "num_pairs": sample["num_pairs"],
            "images": images,
            "answer": sample["pairs"][-1]["answer"],
            "sequences": [p["complete"] for p in sample["pairs"]],
        }

    def generate(self, num_train: int, num_test: int) -> None:
        """
        Generate the full dataset.

        Training samples include the answer image of their last pair; test
        samples omit it. Test sample ids continue after the training ids.
        Progress is printed every 50 samples.

        Args:
            num_train: Number of training samples.
            num_test: Number of test samples.
        """
        train_meta, test_meta = [], []

        print(f"Generating {num_train} training samples...")
        for i in range(num_train):
            sample = self._generate_sample(i)
            meta = self._save_sample_images(sample, "train", include_last_answer=True)
            train_meta.append(meta)
            if (i + 1) % 50 == 0:
                print(f" Train: {i + 1}/{num_train}")

        print(f"Generating {num_test} test samples...")
        for i in range(num_test):
            # Offset by num_train so ids are unique across both splits.
            sample = self._generate_sample(num_train + i)
            meta = self._save_sample_images(sample, "test", include_last_answer=False)
            test_meta.append(meta)
            if (i + 1) % 50 == 0:
                print(f" Test: {i + 1}/{num_test}")

        # Explicit encoding keeps the output independent of the platform default.
        with open(self.output_dir / "train.json", "w", encoding="utf-8") as f:
            json.dump(train_meta, f, indent=2)
        with open(self.output_dir / "test.json", "w", encoding="utf-8") as f:
            json.dump(test_meta, f, indent=2)

        print(f"\nDataset saved to {self.output_dir}")
        print(f" Train: {num_train} samples")
        print(f" Test: {num_test} samples")
        print(f" Image size: {self.image_width}x{self.image_height}")
        print(f" Sequence types: {self.seq_types}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build a small demo dataset (100 train / 20 test
    # samples) with a fixed seed so repeated runs produce the same data.
    dataset_builder = SequenceDatasetGenerator(
        "/home/claude/sequence_dataset",
        seed=42,
        num_pairs=(2, 3),
    )
    dataset_builder.generate(100, 20)