Spaces:
Running
Running
| import html | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer | |
def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """
    Build alignment groups using a greedy substring-equality algorithm on decoded token pieces.
    Adapted from TRL's GoldTrainer._build_alignment_groups_from_ids.

    Both id sequences are assumed to encode the same underlying text. The
    algorithm walks the two streams in parallel, accumulating decoded text into
    a buffer per side and always extending the shorter buffer; whenever the two
    buffers hold identical text, the accumulated index groups are emitted as one
    aligned pair.

    Args:
        student_tokenizer / teacher_tokenizer: objects exposing a HF-style
            ``decode(ids, skip_special_tokens=..., clean_up_tokenization_spaces=...)``.
        student_token_ids / teacher_token_ids: token id sequences for the same text.

    Returns:
        ``(s_groups, t_groups)``: parallel lists of token-index groups; group k
        on each side covers the same span of decoded text. If the streams end
        with an unmatched remainder, it is appended as one final (possibly
        one-sided) pair.
    """
    def to_canonical_pieces(tok, ids):
        # Decode incrementally and diff against the previous prefix so each
        # piece is exactly the text the k-th token contributes *in context*
        # (robust to tokenizers whose lone-token decode differs from in-context decode).
        pieces = []
        prev = ""
        for k in range(len(ids)):
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(teacher_tokenizer, teacher_token_ids)
    i = j = 0
    s_buf = t_buf = ""
    s_group = []
    t_group = []
    s_groups = []
    t_groups = []

    def flush():
        # Emit the current aligned pair; only meaningful when both sides are non-empty.
        if s_group and t_group:
            s_groups.append(s_group.copy())
            t_groups.append(t_group.copy())

    while i < len(s_pieces) or j < len(t_pieces):
        if s_buf == t_buf and s_buf != "":
            # Buffers cover identical text: close out this aligned group.
            flush()
            s_buf = t_buf = ""
            s_group = []
            t_group = []
            continue
        # Seed an empty buffer before any length comparison.
        if s_buf == "" and i < len(s_pieces):
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
            continue
        if t_buf == "" and j < len(t_pieces):
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1
            continue
        # Greedy step: extend whichever buffer is shorter (ties go to the
        # student side), falling back to the other side once one stream is exhausted.
        if len(s_buf) <= len(t_buf):
            if i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1
            elif j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
        else:
            if j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
            elif i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1

    if s_buf == t_buf and s_group and t_group:
        flush()
    elif s_group or t_group:
        # Unmatched trailing remainder: emit as a final, possibly one-sided pair.
        # (The previous `if not s_group: s_group = []` reassignments and the
        # nested re-test of the same condition were dead code; `[].copy()` is
        # already `[]`, so plain copies preserve the exact original behavior.)
        s_groups.append(s_group.copy())
        t_groups.append(t_group.copy())
    return s_groups, t_groups
| def _decode_pieces(tokenizer, token_ids, indices): | |
| """Decode individual token pieces for a group of token indices.""" | |
| return [ | |
| tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False) | |
| for idx in indices | |
| ] | |
| def _format_pieces(pieces): | |
| """Format token pieces as a list, e.g. '["hel", "lo"]'.""" | |
| inner = ", ".join(f'"{p}"' for p in pieces) | |
| return f"[{inner}]" | |
def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups):
    """Build an HTML string with highlighted misalignment regions.

    Orange marks spans where only the student splits into multiple tokens,
    blue where only the teacher does, purple where both do; fully aligned
    single-token spans are emitted as plain escaped text. Only the first
    purple span carries a hover tooltip.
    """
    chunks = []
    pending_purple_tooltip = True
    for group_idx, s_indices in enumerate(s_groups):
        t_indices = t_groups[group_idx]
        group_ids = [student_token_ids[pos] for pos in s_indices]
        span_text = student_tokenizer.decode(group_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        safe_text = html.escape(span_text)
        student_split = len(s_indices) > 1
        teacher_split = len(t_indices) > 1
        if student_split and teacher_split:
            if pending_purple_tooltip:
                s_parts = _decode_pieces(student_tokenizer, student_token_ids, s_indices)
                t_parts = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_indices)
                tip = html.escape(f'Student: {_format_pieces(s_parts)} / Teacher: {_format_pieces(t_parts)}')
                chunks.append(f'<span style="background-color: #b388ff;" title="{tip}">{safe_text}</span>')
                pending_purple_tooltip = False
            else:
                chunks.append(f'<span style="background-color: #b388ff;">{safe_text}</span>')
        elif student_split:
            s_parts = _decode_pieces(student_tokenizer, student_token_ids, s_indices)
            tip = html.escape(f'Student: {_format_pieces(s_parts)}')
            chunks.append(f'<span style="background-color: #ffcc80;" title="{tip}">{safe_text}</span>')
        elif teacher_split:
            t_parts = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_indices)
            tip = html.escape(f'Teacher: {_format_pieces(t_parts)}')
            chunks.append(f'<span style="background-color: #90caf9;" title="{tip}">{safe_text}</span>')
        else:
            chunks.append(safe_text)
    return "".join(chunks)
def make_html_block(student_tokenizer, teacher_tokenizer, text, idx):
    """Process a single text and return its highlighted HTML block.

    Encodes `text` with both tokenizers (no special tokens), aligns the two id
    sequences via `build_alignment_groups_from_ids`, and returns one HTML <div>
    containing a collapsible side-by-side token view plus the text with
    misaligned regions highlighted by `highlight_groups`. `idx` is the
    zero-based text index, used only for the "Text N" heading.
    """
    s_ids = student_tokenizer.encode(text, add_special_tokens=False)
    t_ids = teacher_tokenizer.encode(text, add_special_tokens=False)
    s_groups, t_groups = build_alignment_groups_from_ids(
        student_tokenizer, teacher_tokenizer, s_ids, t_ids
    )
    highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, t_ids, s_groups, t_groups)
    # Build tokenized views with alternating colors
    # (each token decoded individually so token boundaries are visible).
    s_tokens = [student_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False) for tid in s_ids]
    t_tokens = [teacher_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False) for tid in t_ids]
    color1 = "#fff9c4"  # background for even-indexed tokens
    color2 = "#b2ebf2"  # background for odd-indexed tokens
    s_tokens_html = "".join(
        f'<span style="background-color:{color1 if i % 2 == 0 else color2};">{html.escape(t)}</span>'
        for i, t in enumerate(s_tokens)
    )
    t_tokens_html = "".join(
        f'<span style="background-color:{color1 if i % 2 == 0 else color2};">{html.escape(t)}</span>'
        for i, t in enumerate(t_tokens)
    )
    # Collapsible two-column token view (student left, teacher right).
    # NOTE(review): whitespace inside this literal is insignificant to the rendered HTML.
    tokenized_section = f'''
<div style="margin-bottom:15px;">
<details style="margin-bottom:10px;">
<summary style="cursor:pointer; font-weight:bold; user-select:none;">Show tokenization details</summary>
<div style="display:grid; grid-template-columns:1fr 1fr; gap:15px; margin-top:10px;">
<div style="border:1px solid #ddd; padding:10px; border-radius:5px;">
<strong style="color:#f57c00;">Student Tokens ({len(s_ids)})</strong>
<div style="margin-top:8px; font-size:12px; word-break:break-word;">{s_tokens_html}</div>
</div>
<div style="border:1px solid #ddd; padding:10px; border-radius:5px;">
<strong style="color:#1976d2;">Teacher Tokens ({len(t_ids)})</strong>
<div style="margin-top:8px; font-size:12px; word-break:break-word;">{t_tokens_html}</div>
</div>
</div>
</details>
</div>
'''
    # Outer card: heading with token counts, then the token view, then the
    # alignment-highlighted text (pre-wrap so original whitespace is shown).
    return (
        f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
        f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
        f"<strong>Text {idx + 1}</strong> "
        f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
        f"{tokenized_section}"
        f"{highlighted}"
        f"</div>"
    )
def process_texts(student_model_id, teacher_model_id, dataset_id, dataset_config, progress=gr.Progress()):
    """Load tokenizers and dataset, compute first row only.

    Eagerly renders only the first text; the remaining texts (up to 10) are
    rendered lazily by `go_next` and memoized in `cache`.

    Returns the tuple consumed by the Gradio click handler:
    (student_tokenizer, teacher_tokenizer, texts, cache, current_idx, page_html).
    """
    progress(0, desc="Loading tokenizers...")
    student_tokenizer = AutoTokenizer.from_pretrained(student_model_id)
    teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_id)
    progress(0.5, desc="Loading dataset...")
    # A blank/whitespace config textbox means "no named config".
    config = dataset_config.strip() if dataset_config and dataset_config.strip() else None
    ds = load_dataset(dataset_id, name=config, split="train")
    # Only the first 10 rows are ever shown.
    rows = ds.select(range(min(10, len(ds))))
    # NOTE(review): assumes each row has a "messages" list of dicts with a
    # "content" key (chat format) — verify against the dataset schema.
    texts = ["".join(msg["content"] for msg in row["messages"]) for row in rows]
    progress(0.8, desc="Processing first text...")
    # NOTE(review): raises IndexError when the train split is empty — confirm acceptable.
    first_block = make_html_block(student_tokenizer, teacher_tokenizer, texts[0], 0)
    cache = {0: first_block}
    progress(1, desc="Done!")
    return student_tokenizer, teacher_tokenizer, texts, cache, 0, render_page(cache, 0, len(texts))
# Static color legend shown above every rendered page.
LEGEND = (
    '<div style="margin-bottom:15px; font-family:sans-serif;">'
    "<strong>Legend:</strong> "
    '<span style="background-color:#ffcc80; padding:2px 8px; margin-right:8px;">Student token split (orange)</span>'
    '<span style="background-color:#90caf9; padding:2px 8px; margin-right:8px;">Teacher token split (blue)</span>'
    '<span style="background-color:#b388ff; padding:2px 8px;">Both (purple)</span>'
    "</div>"
)


def render_page(cache, idx, total):
    """Render the legend, a "Text N of M" counter, and the cached HTML for `idx`.

    Returns "" when nothing has been processed yet (empty cache).
    """
    if not cache:
        return ""
    header = f'<div style="font-family:sans-serif; margin-bottom:10px;">Text {idx + 1} of {total}</div>'
    return "".join((LEGEND, header, cache[idx]))
def go_prev(cache, idx, texts):
    """Step to the previous text (clamped at the first one) and re-render."""
    new_idx = idx - 1 if idx > 0 else 0
    return cache, new_idx, render_page(cache, new_idx, len(texts))
def go_next(student_tokenizer, teacher_tokenizer, texts, cache, idx):
    """Step to the next text (clamped at the last one), rendering and caching it lazily."""
    new_idx = min(idx + 1, len(texts) - 1)
    if new_idx not in cache:
        cache[new_idx] = make_html_block(student_tokenizer, teacher_tokenizer, texts[new_idx], new_idx)
    return cache, new_idx, render_page(cache, new_idx, len(texts))
# --- Gradio UI wiring ---
with gr.Blocks(title="Tokenization Diff") as demo:
    gr.Markdown("# Tokenization Diff\nVisualize where two tokenizers differ in how they tokenize text.")
    with gr.Row():
        student_model = gr.Textbox(label="Student Model", value="Qwen/Qwen3-8B")
        teacher_model = gr.Textbox(label="Teacher Model", value="deepseek-ai/DeepSeek-Math-V2")
        dataset_id = gr.Textbox(label="Dataset ID", value="lm-provers/FineProofs-SFT")
        dataset_config = gr.Textbox(label="Dataset Config", value="default")
    submit_btn = gr.Button("Submit", variant="primary")
    # Per-session state: loaded tokenizers, the extracted texts, the
    # rendered-HTML cache, and the index of the currently displayed text.
    student_tok_state = gr.State(None)
    teacher_tok_state = gr.State(None)
    texts_state = gr.State([])
    cache_state = gr.State({})
    idx_state = gr.State(0)
    output = gr.HTML(label="Tokenization Diff Output")
    with gr.Row():
        prev_btn = gr.Button("Previous")
        next_btn = gr.Button("Next")
    # Submit loads tokenizers + dataset and renders the first text.
    submit_btn.click(
        fn=process_texts,
        inputs=[student_model, teacher_model, dataset_id, dataset_config],
        outputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state, output],
    )
    prev_btn.click(
        fn=go_prev,
        inputs=[cache_state, idx_state, texts_state],
        outputs=[cache_state, idx_state, output],
    )
    # "Next" also needs the tokenizers/texts so it can render uncached pages lazily.
    next_btn.click(
        fn=go_next,
        inputs=[student_tok_state, teacher_tok_state, texts_state, cache_state, idx_state],
        outputs=[cache_state, idx_state, output],
    )

if __name__ == "__main__":
    demo.launch()