"""Gradio app that visualizes where two tokenizers split the same text differently."""

import html
import json

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer

# Highlight colors, shared by the diff view and the legend so the two cannot drift apart.
STUDENT_SPLIT_COLOR = "#ffcc80"  # orange: only the student uses several tokens for the span
TEACHER_SPLIT_COLOR = "#90caf9"  # blue: only the teacher uses several tokens for the span
BOTH_SPLIT_COLOR = "#ce93d8"  # purple: both sides use several tokens for the span
# Alternating backgrounds for the per-token views so adjacent tokens are distinguishable.
TOKEN_COLORS = ("#fff9c4", "#b2ebf2")
# Preserve newlines/indentation of the source text when rendering it as HTML.
TEXT_STYLE = "white-space: pre-wrap; font-family: monospace; line-height: 1.6;"
# Number of dataset rows loaded per Submit.
MAX_ROWS = 10


def _canonical_pieces(tokenizer, token_ids):
    """Return, for each token, the text it adds to the running decode of ``token_ids``.

    Each piece is the difference between decoding the prefix ending at that token and
    the previous prefix, presumably so that characters spread over several byte-level
    tokens are attributed consistently. Note: this decodes n prefixes, i.e. O(n^2) work.
    """
    pieces = []
    prev = ""
    for k in range(1, len(token_ids) + 1):
        cur = tokenizer.decode(token_ids[:k], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        pieces.append(cur[len(prev):])
        prev = cur
    return pieces


def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """Pair up student and teacher token spans that decode to the same text.

    Greedy substring-equality algorithm on decoded token pieces, adapted from TRL's
    GoldTrainer._build_alignment_groups_from_ids: an empty side is seeded first, then
    whichever side currently covers less text consumes its next token, and a group is
    closed as soon as both sides have decoded exactly the same non-empty string.

    Returns:
        (s_groups, t_groups): two equally long lists; ``s_groups[k]`` / ``t_groups[k]``
        hold the student / teacher token indices of the k-th aligned span. The final
        pair may have one empty side when the leftover pieces never matched up.
    """
    s_pieces = _canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = _canonical_pieces(teacher_tokenizer, teacher_token_ids)

    i = j = 0
    s_buf = t_buf = ""
    s_group, t_group = [], []
    s_groups, t_groups = [], []

    while i < len(s_pieces) or j < len(t_pieces):
        if s_buf == t_buf and s_buf != "":
            # Both sides cover exactly the same text: close the group. A non-empty
            # buffer implies a non-empty group, so both lists are non-empty here.
            s_groups.append(s_group)
            t_groups.append(t_group)
            s_buf = t_buf = ""
            s_group, t_group = [], []
            continue

        if s_buf == "" and i < len(s_pieces):
            take_student = True
        elif t_buf == "" and j < len(t_pieces):
            take_student = False
        elif len(s_buf) <= len(t_buf):
            # Grow the shorter (student) side; fall back to the teacher once exhausted.
            take_student = i < len(s_pieces)
        else:
            # Grow the shorter (teacher) side; fall back to the student once exhausted.
            take_student = j >= len(t_pieces)

        if take_student:
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
        else:
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1

    # Emit whatever is left, even if the buffers never matched or one side is empty.
    if s_group or t_group:
        s_groups.append(s_group)
        t_groups.append(t_group)
    return s_groups, t_groups


def _decode_pieces(tokenizer, token_ids, indices):
    """Decode individual token pieces for a group of token indices."""
    return [
        tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        for idx in indices
    ]


def _format_pieces(pieces):
    """Format token pieces as a list, e.g. '["hel", "lo"]'.

    Pieces are JSON-quoted so whitespace tokens stay visible (a newline shows as \\n)
    and embedded quotes are escaped instead of breaking the list.
    """
    return json.dumps(list(pieces), ensure_ascii=False)


def _span(content, color, tooltip=None):
    """Wrap already-escaped ``content`` in a colored <span>, with an optional hover tooltip."""
    title = f' title="{html.escape(tooltip)}"' if tooltip is not None else ""
    return f'<span style="background-color: {color};"{title}>{content}</span>'


def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
    """Build an HTML string with highlighted misalignment regions.

    Groups where neither side uses more than one token are emitted as plain text; others
    are colored orange (student split), blue (teacher split) or purple (both), with a
    tooltip listing each multi-token side's pieces.
    """
    parts = []
    for s_group, t_group in zip(s_groups, t_groups):
        # The trailing group may be one-sided; fall back to the teacher's decode so
        # teacher-only text is not silently dropped.
        if s_group:
            text = student_tokenizer.decode(
                [student_token_ids[idx] for idx in s_group],
                skip_special_tokens=False,
                clean_up_tokenization_spaces=False,
            )
        else:
            text = teacher_tokenizer.decode(
                [teacher_token_ids[idx] for idx in t_group],
                skip_special_tokens=False,
                clean_up_tokenization_spaces=False,
            )
        escaped = html.escape(text)

        s_multi = len(s_group) > 1
        t_multi = len(t_group) > 1
        if not (s_multi or t_multi):
            parts.append(escaped)
            continue

        labels = []
        if s_multi:
            s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_group)
            labels.append(f"Student: {_format_pieces(s_pieces)}")
        if t_multi:
            t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_group)
            labels.append(f"Teacher: {_format_pieces(t_pieces)}")

        if s_multi and t_multi:
            color = BOTH_SPLIT_COLOR
        elif s_multi:
            color = STUDENT_SPLIT_COLOR
        else:
            color = TEACHER_SPLIT_COLOR
        parts.append(_span(escaped, color, " / ".join(labels)))
    return "".join(parts)


def _tokens_html(tokenizer, token_ids):
    """Render every token as its own span, alternating between the TOKEN_COLORS backgrounds."""
    tokens = _decode_pieces(tokenizer, token_ids, range(len(token_ids)))
    return "".join(_span(html.escape(tok), TOKEN_COLORS[i % 2]) for i, tok in enumerate(tokens))


def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
    """Process a single text and return its highlighted HTML block.

    ``idx`` is the 0-based position of ``text`` and is only used for the header.
    """
    s_ids = student_tokenizer.encode(text, add_special_tokens=False)
    t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
    s_groups, t_groups = build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, s_ids, t_ids)
    highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)

    # Collapsible per-token views with alternating colors.
    tokenized_section = (
        "<details>"
        "<summary>Show tokenization details</summary>"
        f"<h4>Student Tokens ({len(s_ids)})</h4>"
        f'<div style="{TEXT_STYLE}">{_tokens_html(student_tokenizer, s_ids)}</div>'
        f"<h4>Teacher Tokens ({len(t_ids)})</h4>"
        f'<div style="{TEXT_STYLE}">{_tokens_html(teacher_tokenizer, t_ids)}</div>'
        "</details>"
    )
    return (
        '<div style="border: 1px solid #ccc; border-radius: 6px; padding: 12px; margin-bottom: 12px;">'
        f"<h3>Text {idx + 1} "
        f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})</h3>"
        f"{tokenized_section}"
        f'<div style="{TEXT_STYLE}">{highlighted}</div>'
        "</div>"
    )


def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
    """Load tokenizers and dataset, compute first row only.

    Later rows are rendered lazily by ``go_next``. Returns values for
    (student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output).
    """
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)

    progress(0.5, desc="Loading dataset...")
    config = (dataset_config or "").strip() or None
    ds = load_dataset(dataset_id, name=config, split="train")
    rows = ds.select(range(min(MAX_ROWS, len(ds))))
    # Assumes a chat-style "messages" column of {"content": str} dicts — TODO confirm for
    # other datasets. All message contents are concatenated; a None content counts as "".
    texts = ["".join(msg["content"] or "" for msg in row["messages"]) for row in rows]

    if not texts:
        # texts[0] below would raise IndexError on an empty split.
        progress(1, desc="Done!")
        return student_tokenizer, teacher_tokenizer, texts, {}, 0, "<p>The dataset's train split has no rows.</p>"

    progress(0.8, desc="Processing first text...")
    cache = {0: make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)}

    progress(1, desc="Done!")
    return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))


LEGEND = (
    '<div style="margin-bottom: 8px;"><b>Legend:</b> '
    + " ".join(
        [
            _span("Student token split (orange)", STUDENT_SPLIT_COLOR),
            _span("Teacher token split (blue)", TEACHER_SPLIT_COLOR),
            _span("Both (purple)", BOTH_SPLIT_COLOR),
        ]
    )
    + "</div>"
)


def render_page(cache, idx, total):
    """Return legend + page counter + the cached HTML block for ``idx`` ("" if not rendered yet)."""
    if not cache or idx not in cache:
        return ""
    counter = f'<div style="text-align: center; margin-bottom: 8px;">Text {idx + 1} of {total}</div>'
    return LEGEND + counter + cache[idx]


def go_prev(cache, idx, texts):
    """Step back one text; earlier texts are always cached because pages are visited in order."""
    idx = max(0, idx - 1)
    return cache, idx, render_page(cache, idx, len(texts))


def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
    """Step forward one text, rendering and caching it on first visit."""
    if not texts:
        # Nothing loaded yet (Submit not pressed): idx would become -1 and texts[-1] raise.
        return cache, idx, render_page(cache, idx, 0)
    idx = min(len(texts) - 1, idx + 1)
    if idx not in cache:
        cache[idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[idx], idx)
    return cache, idx, render_page(cache, idx, len(texts))


with gr.Blocks(title="Tokenization Diff") as demo:
    gr.Markdown("# Tokenization Diff\nVisualize where two tokenizers differ in how they tokenize text.")

    with gr.Row():
        student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
        teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
    with gr.Row():
        dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
        dataset_config = gr.Textbox(label="Dataset Config", value="default")

    submit_btn = gr.Button("Submit", variant="primary")

    # Per-session state: loaded tokenizers, raw texts, rendered-HTML cache and current page.
    student_tok_state = gr.State(None)
    teacher_tok_state = gr.State(None)
    texts_state = gr.State([])
    cache_state = gr.State({})
    idx_state = gr.State(0)

    output = gr.HTML(label="Tokenization Diff Output")

    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")

    submit_btn.click(
        fn=process_texts,
        inputs=[student_model, teacher_model, dataset_id, dataset_config],
        outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
    )
    prev_btn.click(
        fn=go_prev,
        inputs=[cache_state, idx_state, texts_state],
        outputs=[cache_state, idx_state, output],
    )
    next_btn.click(
        fn=go_next,
        inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
        outputs=[cache_state, idx_state, output],
    )


if __name__ == "__main__":
    demo.launch()