| """Shared construction and loading helpers for the project's tokenizer.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| import json |
| from pathlib import Path |
| import re |
| from typing import Any, Iterable |
|
|
|
|
# Reserved tokens handed to the BPE trainer. validate_tokenizer requires the
# vocabulary id of "<|endoftext|>" to equal its index in this list (EOT_ID).
SPECIAL_TOKENS = ["<|pad|>", "<|bos|>", "<|eos|>", "<|unk|>", "<|endoftext|>"]
EOT_ID = SPECIAL_TOKENS.index("<|endoftext|>")

# Single-character operators that must exist as atomic vocabulary entries.
ARITHMETIC_TOKENS = ("+", "-", "*", "/", "=", "(", ")")

# Place-id stream: 0 for non-digits, 1..MAX_PLACE_ID for the offset of a digit
# inside its (least-significant-first) run, PLACE_OVERFLOW_ID beyond that.
MAX_PLACE_ID = 64
PLACE_OVERFLOW_ID = MAX_PLACE_ID + 1
PLACE_VOCAB_SIZE = PLACE_OVERFLOW_ID + 1

# Role-id stream: 0 for "no role", 1..MAX_OPERAND_ROLES for operand digit runs,
# then one id for result digits and one for whitespace inside an equation.
MAX_OPERAND_ROLES = 9
RESULT_ROLE_ID = MAX_OPERAND_ROLES + 1
SPACE_ROLE_ID = RESULT_ROLE_ID + 1
ROLE_VOCAB_SIZE = SPACE_ROLE_ID + 1
|
|
|
|
@dataclass(frozen=True)
class FusionEncoding:
    """Token ids together with their aligned place-id and role-id streams.

    ``ids``, ``place_ids`` and ``role_ids`` are parallel lists. Construction
    fails with ``ValueError`` when their lengths differ. ``tokens`` is optional
    display text and its length is not checked.
    """

    ids: list[int]
    place_ids: list[int]
    role_ids: list[int]
    tokens: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        expected = len(self.ids)
        if len(self.place_ids) != expected or len(self.role_ids) != expected:
            raise ValueError("Fusion tokenizer streams must have equal length")

    @property
    def input_ids(self) -> list[int]:
        # Alias so callers expecting an ``input_ids`` attribute can use this object.
        return self.ids

    def __len__(self) -> int:
        return len(self.ids)

    def __iter__(self):
        return iter(self.ids)
|
|
|
|
def build_tokenizer() -> Any:
    """Build an untrained byte-level BPE tokenizer with explicit lossless boundaries.

    The text is first split into isolated pieces:

    * whitespace runs,
    * single digits,
    * single arithmetic symbols,
    * runs of any other characters.

    BPE merges never cross these piece boundaries, so every digit and
    operator stays its own token. Each piece is then mapped to byte-level
    symbols without a prefix space. The byte-level decoder reverses that
    mapping.
    """
    from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers

    tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>"))

    boundary_split = pre_tokenizers.Split(
        Regex(r"\s+|\d|[+\-*/=()]|[^\s\d+\-*/=()]+"),
        behavior="isolated",
    )
    byte_mapping = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([boundary_split, byte_mapping])
    tokenizer.decoder = decoders.ByteLevel()
    return tokenizer
|
|
|
|
class FusionTokenizer:
    """Runtime wrapper adding LSD-first digit streams to a trained BPE tokenizer.

    ``encode`` reverses every run of ASCII digits before tokenizing. Each
    number therefore reaches the wrapped tokenizer least-significant digit
    first. It returns a :class:`FusionEncoding` with two streams aligned to
    the ids:

    * ``place_ids``: for digit tokens, the 1-based offset inside the digit run
      (1 = least significant), capped at ``PLACE_OVERFLOW_ID``; 0 otherwise.
    * ``role_ids``: inside an arithmetic span with exactly one ``=``, operand
      digit runs get roles ``1..MAX_OPERAND_ROLES`` (left to right), result
      digits get ``RESULT_ROLE_ID`` and in-line whitespace ``SPACE_ROLE_ID``.
      Every other position is 0.

    ``decode`` restores the original digit order. Attribute lookups that fail
    on the wrapper are forwarded to the wrapped tokenizer.
    """

    # ASCII digits only. decode() can only un-reverse digits that map to the
    # atomic "0".."9" tokens. Python's ``\d`` also matches other Unicode digits
    # (e.g. Arabic-Indic), and their reversal could never be undone.
    _digit_span_re = re.compile(r"[0-9]+")

    def __init__(self, tokenizer: Any):
        """Wrap *tokenizer* and cache the ids of the tokens this class relies on.

        Raises:
            ValueError: the vocabulary lacks an atomic token for one of the
                digits 0-9 or for ``=``.
        """
        self.tokenizer = tokenizer
        # Look each digit up once and derive both digit lookups from the result.
        self._digit_id_to_text: dict[int, str] = {}
        for digit in "0123456789":
            token_id = self.tokenizer.token_to_id(digit)
            if token_id is not None:
                self._digit_id_to_text[int(token_id)] = digit
        self._digit_token_ids = frozenset(self._digit_id_to_text)
        self._equals_id = self.tokenizer.token_to_id("=")
        self._special_token_ids = frozenset(
            token_id
            for token in SPECIAL_TOKENS
            if (token_id := self.tokenizer.token_to_id(token)) is not None
        )
        if len(self._digit_token_ids) != 10:
            raise ValueError("Tokenizer vocabulary must contain atomic digit tokens 0-9")
        if self._equals_id is None:
            raise ValueError("Tokenizer vocabulary must contain an atomic '=' token")

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. An instance built without
        # __init__ (copy.copy, pickle.loads) has no ``tokenizer`` yet, so
        # forwarding that name would re-enter __getattr__ forever. Dunder
        # probes (e.g. __setstate__ or __getnewargs_ex__ issued by copy and
        # pickle) must not be answered by the wrapped object either.
        if name == "tokenizer" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.tokenizer, name)

    @property
    def digit_token_ids(self) -> frozenset[int]:
        """Ids of the atomic digit tokens "0".."9"."""
        return self._digit_token_ids

    @property
    def special_token_ids(self) -> frozenset[int]:
        """Ids of the entries of ``SPECIAL_TOKENS`` present in the vocabulary."""
        return self._special_token_ids

    def get_vocab_size(self, with_added_tokens: bool = True) -> int:
        return int(self.tokenizer.get_vocab_size(with_added_tokens=with_added_tokens))

    def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]:
        return self.tokenizer.get_vocab(with_added_tokens=with_added_tokens)

    def token_to_id(self, token: str) -> int | None:
        return self.tokenizer.token_to_id(token)

    def id_to_token(self, token_id: int) -> str | None:
        return self.tokenizer.id_to_token(int(token_id))

    @classmethod
    def _reverse_digit_spans(cls, text: str) -> str:
        """Return *text* with every maximal run of ASCII digits reversed."""
        return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)

    def _decode_token_piece(self, token_id: int) -> str:
        """Decode a single id to text, keeping special tokens visible."""
        return self.tokenizer.decode([int(token_id)], skip_special_tokens=False)

    @staticmethod
    def _is_equation_whitespace(piece: str) -> bool:
        """True for non-empty all-whitespace pieces that contain no line break."""
        return bool(piece) and piece.isspace() and "\n" not in piece and "\r" not in piece

    def _is_equation_piece(self, token_id: int, piece: str) -> bool:
        """True if this token may be part of an arithmetic span.

        Qualifying tokens are digits, in-line whitespace, and single-character
        arithmetic operators. Special tokens never qualify.
        """
        if token_id in self._special_token_ids:
            return False
        if token_id in self._digit_token_ids:
            return True
        if self._is_equation_whitespace(piece):
            return True
        return len(piece) == 1 and piece in ARITHMETIC_TOKENS

    def _annotate_equation_span(
        self,
        ids: list[int],
        pieces: list[str],
        start: int,
        end: int,
        role_ids: list[int],
    ) -> None:
        """Write role ids for one contiguous equation-like span ``[start, end)``.

        The span qualifies only when all of these hold:

        * it contains exactly one ``=``;
        * there is at least one digit run on each side of the ``=``;
        * there are at most ``MAX_OPERAND_ROLES`` digit runs before it.

        Otherwise *role_ids* is left untouched. For a qualifying span:

        * whitespace pieces get ``SPACE_ROLE_ID``;
        * the k-th digit run before ``=`` gets role k;
        * every digit run after ``=`` gets ``RESULT_ROLE_ID``.

        Mutates *role_ids* in place.
        """
        equals_positions = [
            index
            for index in range(start, end)
            if ids[index] == self._equals_id
        ]
        if len(equals_positions) != 1:
            return
        equals_position = equals_positions[0]

        # Collect maximal runs of digit tokens as half-open [run_start, run_end).
        digit_runs: list[tuple[int, int]] = []
        index = start
        while index < end:
            if ids[index] not in self._digit_token_ids:
                index += 1
                continue
            run_start = index
            while index < end and ids[index] in self._digit_token_ids:
                index += 1
            digit_runs.append((run_start, index))

        # "=" is never a digit, so every run lies entirely on one side of it.
        operand_runs = [(a, b) for a, b in digit_runs if b <= equals_position]
        result_runs = [(a, b) for a, b in digit_runs if a > equals_position]
        if not operand_runs or not result_runs or len(operand_runs) > MAX_OPERAND_ROLES:
            return

        for index in range(start, end):
            if self._is_equation_whitespace(pieces[index]):
                role_ids[index] = SPACE_ROLE_ID

        for role, (run_start, run_end) in enumerate(operand_runs, start=1):
            for index in range(run_start, run_end):
                role_ids[index] = role
        for run_start, run_end in result_runs:
            for index in range(run_start, run_end):
                role_ids[index] = RESULT_ROLE_ID

    def annotate_ids(self, ids: Iterable[int]) -> tuple[list[int], list[int]]:
        """Compute the ``(place_ids, role_ids)`` streams for already-encoded *ids*.

        *ids* are expected to come from :meth:`encode`, so digit runs are
        already least-significant digit first. Both returned lists have the
        same length as *ids*.
        """
        input_ids = [int(token_id) for token_id in ids]
        place_ids = [0] * len(input_ids)
        role_ids = [0] * len(input_ids)
        # Decode each distinct id only once: sequences repeat ids heavily and
        # every decode is a call into the wrapped tokenizer.
        piece_by_id = {
            token_id: self._decode_token_piece(token_id) for token_id in set(input_ids)
        }
        pieces = [piece_by_id[token_id] for token_id in input_ids]

        # Place ids: 1-based offset inside each maximal digit run, with all
        # offsets beyond MAX_PLACE_ID sharing PLACE_OVERFLOW_ID.
        index = 0
        while index < len(input_ids):
            if input_ids[index] not in self._digit_token_ids:
                index += 1
                continue
            run_start = index
            while index < len(input_ids) and input_ids[index] in self._digit_token_ids:
                offset = index - run_start + 1
                place_ids[index] = min(offset, PLACE_OVERFLOW_ID)
                index += 1

        # Role ids: annotate each maximal run of equation-like tokens.
        span_start: int | None = None
        for index, (token_id, piece) in enumerate(zip(input_ids, pieces, strict=True)):
            if self._is_equation_piece(token_id, piece):
                if span_start is None:
                    span_start = index
                continue
            if span_start is not None:
                self._annotate_equation_span(input_ids, pieces, span_start, index, role_ids)
                span_start = None
        if span_start is not None:
            self._annotate_equation_span(input_ids, pieces, span_start, len(input_ids), role_ids)

        return place_ids, role_ids

    def encode(self, text: str, *args, **kwargs) -> FusionEncoding:
        """Tokenize *text* with digit runs reversed and attach place/role streams.

        Extra arguments are passed through to the wrapped tokenizer's
        ``encode``. A string second sequence is digit-reversed as well, so its
        digits also round-trip through :meth:`decode`. It can be given as the
        first extra positional argument or as ``pair=``.
        """
        transformed = self._reverse_digit_spans(text)
        if args and isinstance(args[0], str):
            args = (self._reverse_digit_spans(args[0]), *args[1:])
        elif isinstance(kwargs.get("pair"), str):
            kwargs["pair"] = self._reverse_digit_spans(kwargs["pair"])
        encoding = self.tokenizer.encode(transformed, *args, **kwargs)
        ids = [int(token_id) for token_id in encoding.ids]
        place_ids, role_ids = self.annotate_ids(ids)
        return FusionEncoding(
            ids=ids,
            place_ids=place_ids,
            role_ids=role_ids,
            tokens=list(getattr(encoding, "tokens", [])),
        )

    def encode_batch(self, texts: list[str], *args, **kwargs) -> list[FusionEncoding]:
        """Encode each text independently with :meth:`encode`."""
        return [self.encode(text, *args, **kwargs) for text in texts]

    def decode(
        self,
        token_ids: Iterable[int],
        skip_special_tokens: bool = True,
    ) -> str:
        """Convert *token_ids* back to text, undoing the digit reversal of encode.

        Consecutive non-digit ids are decoded together by the wrapped
        tokenizer. Each run of consecutive digit tokens is emitted in reverse
        order, which restores most-significant-digit-first text.
        """
        pieces: list[str] = []
        text_ids: list[int] = []
        digit_buffer: list[str] = []

        def flush_text() -> None:
            if text_ids:
                pieces.append(
                    self.tokenizer.decode(
                        text_ids,
                        skip_special_tokens=skip_special_tokens,
                    )
                )
                text_ids.clear()

        def flush_digits() -> None:
            if digit_buffer:
                pieces.extend(reversed(digit_buffer))
                digit_buffer.clear()

        for raw_id in token_ids:
            token_id = int(raw_id)
            if token_id in self._digit_token_ids:
                flush_text()
                digit_buffer.append(self._digit_id_to_text[token_id])
                continue
            flush_digits()
            text_ids.append(token_id)

        flush_text()
        flush_digits()
        return "".join(pieces)
|
|
|
|
def build_trainer(vocab_size: int, min_frequency: int) -> Any:
    """Create the BPE trainer used to fit the tokenizer from ``build_tokenizer``.

    Args:
        vocab_size: target vocabulary size passed to the trainer.
        min_frequency: minimum pair frequency for a merge to be learned.

    The trainer registers ``SPECIAL_TOKENS``. It is seeded with the full
    byte-level alphabet, so every byte symbol exists in the vocabulary
    regardless of the training corpus.
    """
    from tokenizers import pre_tokenizers, trainers

    trainer_options = dict(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    return trainers.BpeTrainer(**trainer_options)
|
|
|
|
def tokenizer_files(tokenizer_dir: Path) -> tuple[Path, Path, Path]:
    """Return the ``(tokenizer.json, vocab.json, merges.txt)`` paths under *tokenizer_dir*."""
    file_names = ("tokenizer.json", "vocab.json", "merges.txt")
    return tuple(tokenizer_dir / file_name for file_name in file_names)
|
|
|
|
def validate_tokenizer(tokenizer_dir: Path) -> None:
    """Check that *tokenizer_dir* holds a tokenizer this project can use.

    The vocabulary is read from ``vocab.json`` when that file exists.
    Otherwise it comes from the ``model.vocab`` mapping embedded in
    ``tokenizer.json``.

    Raises:
        FileNotFoundError: ``tokenizer.json`` is missing, or there is no
            ``vocab.json`` and ``tokenizer.json`` embeds no vocabulary.
        ValueError: any of the following:

            * ``vocab.json`` is not a JSON object, or the vocabulary is empty;
            * an id exceeds the uint16 range;
            * ``<|endoftext|>`` does not have id ``EOT_ID``;
            * a digit or arithmetic token is not a vocabulary entry.
    """
    tokenizer_json, vocab_path, _ = tokenizer_files(tokenizer_dir)
    if not tokenizer_json.exists():
        raise FileNotFoundError(
            f"Missing {tokenizer_json}. Retrain with train_tokenizer.py so the "
            "whitespace and digit boundary rules are preserved."
        )
    if vocab_path.exists():
        with vocab_path.open("r", encoding="utf-8") as f:
            vocab = json.load(f)
        if not isinstance(vocab, dict):
            raise ValueError(f"{vocab_path} must contain a JSON object mapping tokens to ids")
    else:
        with tokenizer_json.open("r", encoding="utf-8") as f:
            tokenizer_data = json.load(f)
        # Guard every level: a null/non-object "model" must yield the
        # documented error rather than an AttributeError.
        model = tokenizer_data.get("model") if isinstance(tokenizer_data, dict) else None
        vocab = model.get("vocab") if isinstance(model, dict) else None
        if not isinstance(vocab, dict):
            raise FileNotFoundError(f"Missing vocab.json and no embedded vocab in {tokenizer_json}")
    if not vocab:
        # Without this, max() below fails with an unhelpful message.
        raise ValueError(f"Tokenizer vocabulary in {tokenizer_dir} is empty")

    max_id = max(vocab.values())
    if max_id > 65_535:
        raise ValueError(f"Tokenizer max id {max_id} does not fit in uint16")
    if vocab.get("<|endoftext|>") != EOT_ID:
        raise ValueError(
            f"Expected <|endoftext|> id {EOT_ID}, "
            f"got {vocab.get('<|endoftext|>')}"
        )
    missing = [
        token
        for token in (*[str(value) for value in range(10)], *ARITHMETIC_TOKENS)
        if token not in vocab
    ]
    if missing:
        raise ValueError(f"Tokenizer missing required atomic tokens: {missing}")
|
|
|
|
def load_tokenizer(tokenizer_dir: Path) -> Any:
    """Validate *tokenizer_dir* and return its ``tokenizer.json`` as a FusionTokenizer.

    Any error raised by ``validate_tokenizer`` (or by the FusionTokenizer
    constructor) propagates unchanged.
    """
    from tokenizers import Tokenizer

    validate_tokenizer(tokenizer_dir)
    tokenizer_json = tokenizer_files(tokenizer_dir)[0]
    base_tokenizer = Tokenizer.from_file(str(tokenizer_json))
    return FusionTokenizer(base_tokenizer)
|
|