# Last change by cmpatino (HF Staff): "Change tokenization visualizer" (commit 8ef5720).
import html
import json

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer
def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """
    Build alignment groups using a greedy substring-equality algorithm on decoded token pieces.

    Adapted from TRL's GoldTrainer._build_alignment_groups_from_ids.

    Each id sequence is first turned into "canonical pieces": the text that
    token k appends to the decoding of the ids before it. The two piece streams
    are then consumed greedily, always extending the shorter buffer. Once both
    buffers hold identical, non-empty text, the token indices consumed on each
    side are emitted as one aligned group.

    Args:
        student_tokenizer: object with a ``decode(ids, skip_special_tokens=...,
            clean_up_tokenization_spaces=...)`` method, used for the student ids.
        teacher_tokenizer: same interface, used for the teacher ids.
        student_token_ids: student token ids for the text.
        teacher_token_ids: teacher token ids for the same text.

    Returns:
        ``(s_groups, t_groups)``: two lists of equal length, where ``s_groups[k]``
        and ``t_groups[k]`` are the student / teacher token indices covering the
        same text. If the streams do not re-synchronise at the end, the leftover
        indices form one final group, either side of which may be empty.

    Note:
        Pieces are computed by decoding every prefix ``ids[:k + 1]``, so the
        number of decoded tokens grows quadratically with sequence length.
    """

    def to_canonical_pieces(tok, ids):
        """Return the text each token contributes on top of the tokens before it."""
        pieces = []
        prev = ""
        last = len(ids) - 1
        for k in range(len(ids)):
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            if k < last and cur.endswith("\ufffd"):
                # A token carrying only part of a multi-byte UTF-8 character
                # decodes to U+FFFD, which is replaced once the remaining bytes
                # arrive. Attribute nothing to it yet, so that the pieces still
                # concatenate to the final text; the completing token receives
                # the whole character. The last token is never deferred, so no
                # text is dropped.
                pieces.append("")
                continue
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(teacher_tokenizer, teacher_token_ids)

    i = j = 0  # next unconsumed student / teacher piece
    s_buf = t_buf = ""  # text accumulated for the current group on each side
    s_group = []
    t_group = []
    s_groups = []
    t_groups = []

    def flush():
        # Reads the *current* s_group/t_group bindings of the enclosing scope.
        if s_group and t_group:
            s_groups.append(s_group.copy())
            t_groups.append(t_group.copy())

    while i < len(s_pieces) or j < len(t_pieces):
        # Both sides spell the same non-empty text: close the group.
        if s_buf == t_buf and s_buf != "":
            flush()
            s_buf = t_buf = ""
            s_group = []
            t_group = []
            continue

        # Seed an empty buffer first (student side before teacher side).
        if s_buf == "" and i < len(s_pieces):
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
            continue
        if t_buf == "" and j < len(t_pieces):
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1
            continue

        # Grow the shorter buffer; fall back to the other side once a stream is exhausted.
        if len(s_buf) <= len(t_buf):
            if i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1
            elif j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
        else:
            if j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
            elif i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1

    # Whatever is left still forms a final group, so that every token index is
    # covered; one side may be empty when only one stream had leftovers.
    if s_buf == t_buf and s_group and t_group:
        flush()
    elif s_group or t_group:
        s_groups.append(list(s_group))
        t_groups.append(list(t_group))

    return s_groups, t_groups
def _decode_pieces(tokenizer, token_ids, indices):
"""Decode individual token pieces for a group of token indices."""
return [
tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
for idx in indices
]
def _format_pieces(pieces):
"""Format token pieces as a list, e.g. '["hel", "lo"]'."""
inner = ", ".join(f'"{p}"' for p in pieces)
return f"[{inner}]"
def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
    """Build an HTML string with highlighted misalignment regions.

    Each aligned group is rendered as the HTML-escaped student decoding of its
    tokens. Colour coding:
      * orange (#ffcc80): only the student used several tokens;
      * blue (#90caf9): only the teacher used several tokens;
      * purple (#b388ff): both sides used several tokens;
      * no wrapper: at most one token on each side.
    Orange and blue spans carry a tooltip with the individual token pieces.
    Among purple spans, only the first one gets a tooltip.
    """

    def wrap(color, content, tooltip=None):
        # Tooltip text is HTML-escaped so it is safe inside title="...".
        if tooltip is None:
            return f'<span style="background-color: {color};">{content}</span>'
        return f'<span style="background-color: {color};" title="{html.escape(tooltip)}">{content}</span>'

    out = []
    purple_tooltip_used = False
    for k, s_idx in enumerate(s_groups):
        group_ids = [student_token_ids[p] for p in s_idx]
        shown = html.escape(
            student_tokenizer.decode(group_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        )
        t_idx = t_groups[k]
        student_split = len(s_idx) > 1
        teacher_split = len(t_idx) > 1

        if not student_split and not teacher_split:
            out.append(shown)
        elif student_split and teacher_split:
            if purple_tooltip_used:
                out.append(wrap("#b388ff", shown))
            else:
                s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_idx)
                t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_idx)
                label = f'Student: {_format_pieces(s_pieces)} / Teacher: {_format_pieces(t_pieces)}'
                out.append(wrap("#b388ff", shown, label))
                purple_tooltip_used = True
        elif student_split:
            s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_idx)
            out.append(wrap("#ffcc80", shown, f'Student: {_format_pieces(s_pieces)}'))
        else:
            t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_idx)
            out.append(wrap("#90caf9", shown, f'Teacher: {_format_pieces(t_pieces)}'))
    return "".join(out)
def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
    """Process a single text and return its highlighted HTML block.

    The block contains:
      * a header with the 1-based text number and both token counts;
      * a collapsible "Show tokenization details" section listing every
        student and teacher token with alternating background colours;
      * the student text with misaligned groups highlighted
        (see ``highlight_groups``).

    Args:
        student_tokenizer / teacher_tokenizer: tokenizers providing ``encode``
            and ``decode``.
        text: the raw text to tokenize (no special tokens are added).
        idx: 0-based position of the text, shown as ``idx + 1``.

    Returns:
        An HTML string.
    """
    s_ids = student_tokenizer.encode(text, add_special_tokens=False)
    t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
    s_groups, t_groups = build_alignment_groups_from_ids(
        student_tokenizer, teacher_tokenizer, s_ids, t_ids
    )
    highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)

    # Every token decoded on its own, for the side-by-side token listing.
    s_tokens = _decode_pieces(student_tokenizer, s_ids, range(len(s_ids)))
    t_tokens = _decode_pieces(teacher_tokenizer, t_ids, range(len(t_ids)))

    def token_strip(tokens):
        # Alternate two background colours so adjacent tokens are distinguishable.
        colors = ("#fff9c4", "#b2ebf2")
        return "".join(
            f'<span style="background-color:{colors[i % 2]};">{html.escape(tok)}</span>'
            for i, tok in enumerate(tokens)
        )

    def token_panel(title_color, title, count, tokens):
        return (
            '<div style="border:1px solid #ddd; padding:10px; border-radius:5px;">'
            f'<strong style="color:{title_color};">{title} Tokens ({count})</strong>'
            f'<div style="margin-top:8px; font-size:12px; word-break:break-word;">{token_strip(tokens)}</div>'
            "</div>"
        )

    # Deliberately no newlines/indentation between tags: the enclosing container
    # uses white-space:pre-wrap, which would render that whitespace as blank lines.
    tokenized_section = (
        '<div style="margin-bottom:15px;">'
        '<details style="margin-bottom:10px;">'
        '<summary style="cursor:pointer; font-weight:bold; user-select:none;">Show tokenization details</summary>'
        '<div style="display:grid; grid-template-columns:1fr 1fr; gap:15px; margin-top:10px;">'
        f"{token_panel('#f57c00', 'Student', len(s_ids), s_tokens)}"
        f"{token_panel('#1976d2', 'Teacher', len(t_ids), t_tokens)}"
        "</div>"
        "</details>"
        "</div>"
    )

    return (
        f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
        f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
        f"<strong>Text {idx + 1}</strong> "
        f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
        f"{tokenized_section}"
        f"{highlighted}"
        f"</div>"
    )
def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
    """Load tokenizers and dataset, render the first text only.

    Up to the first 10 rows of the dataset's ``train`` split are read. Each
    row's ``messages`` contents are concatenated (without a separator) into one
    text. Only the first text is rendered here; ``go_next`` renders the others
    lazily.

    Args:
        student_model_id / teacher_model_id: ids passed to
            ``AutoTokenizer.from_pretrained``.
        dataset_id: dataset id passed to ``load_dataset``.
        dataset_config: dataset config name; blank/whitespace means none.
        progress: Gradio progress tracker.

    Returns:
        ``(student_tokenizer, teacher_tokenizer, texts, cache, 0, page_html)``,
        where ``cache`` maps text index to its rendered HTML block.

    Raises:
        gr.Error: if the ``train`` split contains no rows.
    """
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)

    progress(0.5, desc="Loading dataset...")
    config = (dataset_config or "").strip() or None
    ds = load_dataset(dataset_id, name=config, split="train")
    if len(ds) == 0:
        # Without this, texts[0] below fails with an opaque IndexError.
        raise gr.Error(f"Dataset {dataset_id!r} has no rows in its 'train' split.")
    rows = ds.select(range(min(10, len(ds))))
    # A message whose content is None contributes nothing instead of crashing join().
    texts = ["".join(msg["content"] or "" for msg in row["messages"]) for row in rows]

    progress(0.8, desc="Processing first text...")
    first_block = make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)
    cache = {0: first_block}
    progress(1, desc="Done!")
    return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))
# Static legend explaining the highlight colours produced by highlight_groups.
LEGEND = "".join(
    [
        '<div style="margin-bottom:15px; font-family:sans-serif;">',
        "<strong>Legend:</strong> ",
        '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student token split (orange)</span>',
        '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher token split (blue)</span>',
        '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>',
        "</div>",
    ]
)
def render_page(cache, idx, total):
    """Compose the legend, a "Text i of n" counter and the cached block for ``idx``.

    Args:
        cache: mapping of text index to its rendered HTML block.
        idx: 0-based index of the text to show; must be a key of ``cache``
            when ``cache`` is non-empty.
        total: number of texts, shown in the counter.

    Returns:
        The page HTML, or ``""`` when nothing has been rendered yet.
    """
    if cache:
        position = f'<div style="font-family:sans-serif; margin-bottom:10px;">Text {idx + 1} of {total}</div>'
        return "".join((LEGEND, position, cache[idx]))
    return ""
def go_prev(cache, idx, texts):
    """Step back one text (never before the first) and re-render it from the cache.

    Returns:
        ``(cache, new_idx, page_html)`` for the cache, index and output components.
    """
    new_idx = max(0, idx - 1)
    page = render_page(cache, new_idx, len(texts))
    return cache, new_idx, page
def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
    """Advance one text (never past the last), rendering and caching it on first visit.

    Args:
        student_tokenizer / teacher_tokenizer: tokenizers loaded by ``process_texts``.
        texts: the loaded texts; empty before Submit has been run.
        cache: mapping of text index to rendered HTML block; updated in place.
        idx: 0-based index of the text currently shown.

    Returns:
        ``(cache, new_idx, page_html)``. Before any texts are loaded, the state
        is returned unchanged with an empty page.
    """
    if not texts:
        # Next clicked before Submit: min(len(texts) - 1, ...) would be -1, and
        # texts[-1] would raise IndexError.
        return cache, idx, ""
    idx = min(len(texts) - 1, idx + 1)
    if idx not in cache:
        cache[idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[idx], idx)
    return cache, idx, render_page(cache, idx, len(texts))
# Gradio UI: model/dataset inputs, a Submit button, state holders, the rendered
# HTML output and Previous/Next navigation.
# NOTE(review): this file's indentation was lost. The nesting below is
# reconstructed, in particular which textboxes sit inside the first gr.Row
# (all four are assumed here). Confirm against the intended layout.
with gr.Blocks(title="Tokenization Diff") as demo:
    gr.Markdown("# Tokenization Diff\nVisualize where two tokenizers differ in how they tokenize text.")
    with gr.Row():
        student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
        teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
        dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
        dataset_config = gr.Textbox(label="Dataset Config", value="default")
    submit_btn = gr.Button("Submit", variant="primary")
    # State: the two loaded tokenizers, the texts, the rendered-HTML cache
    # (text index -> HTML block, filled by process_texts/go_next) and the index
    # of the text currently shown.
    student_tok_state = gr.State(None)
    teacher_tok_state = gr.State(None)
    texts_state = gr.State([])
    cache_state = gr.State({})
    idx_state = gr.State(0)
    output = gr.HTML(label="Tokenization Diff Output")
    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
    # Submit loads tokenizers + dataset and renders the first text.
    submit_btn.click(
        fn=process_texts,
        inputs=[student_model, teacher_model, dataset_id, dataset_config],
        outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
    )
    # Previous only re-renders from the cache; no tokenizers needed.
    prev_btn.click(
        fn=go_prev,
        inputs=[cache_state, idx_state, texts_state],
        outputs=[cache_state, idx_state, output],
    )
    # Next may render a not-yet-cached text, so it needs the tokenizers.
    next_btn.click(
        fn=go_next,
        inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
        outputs=[cache_state, idx_state, output],
    )
if __name__ == "__main__":
    demo.launch()