import html

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer


def build_alignment_groups_from_ids(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids):
    """
    Build alignment groups using a greedy substring-equality algorithm on decoded token pieces.

    Adapted from TRL's GoldTrainer._build_alignment_groups_from_ids.

    Args:
        student_tokenizer: Tokenizer whose ``decode`` renders student token ids to text.
        teacher_tokenizer: Tokenizer whose ``decode`` renders teacher token ids to text.
        student_token_ids: Sequence of token ids produced by the student tokenizer.
        teacher_token_ids: Sequence of token ids produced by the teacher tokenizer.

    Returns:
        A pair ``(s_groups, t_groups)`` of lists of equal length; ``s_groups[k]``
        and ``t_groups[k]`` are lists of indices into the respective id sequences
        whose decoded texts cover the same span of the underlying string. A
        trailing group may be one-sided when the inputs do not decode to the
        same text.
    """

    def to_canonical_pieces(tok, ids):
        # Decode cumulative prefixes and diff consecutive decodes so each piece
        # is exactly the text contributed by one token (robust to tokens that
        # only render correctly in context, e.g. multi-token characters).
        pieces = []
        prev = ""
        for k in range(len(ids)):
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(teacher_tokenizer, teacher_token_ids)

    i = j = 0
    s_buf = t_buf = ""
    s_group = []
    t_group = []
    s_groups = []
    t_groups = []

    def flush():
        # Emit the current group pair; both sides must be non-empty.
        if s_group and t_group:
            s_groups.append(s_group.copy())
            t_groups.append(t_group.copy())

    while i < len(s_pieces) or j < len(t_pieces):
        # Both buffers cover the same text: close this group and start fresh.
        if s_buf == t_buf and s_buf != "":
            flush()
            s_buf = t_buf = ""
            s_group = []
            t_group = []
            continue
        # Seed an empty buffer before comparing lengths.
        if s_buf == "" and i < len(s_pieces):
            s_buf += s_pieces[i]
            s_group.append(i)
            i += 1
            continue
        if t_buf == "" and j < len(t_pieces):
            t_buf += t_pieces[j]
            t_group.append(j)
            j += 1
            continue
        # Greedily extend the shorter buffer so the two sides can converge on
        # a common substring boundary; fall back to the other side when one
        # sequence is exhausted.
        if len(s_buf) <= len(t_buf):
            if i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1
            elif j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
        else:
            if j < len(t_pieces):
                t_buf += t_pieces[j]
                t_group.append(j)
                j += 1
            elif i < len(s_pieces):
                s_buf += s_pieces[i]
                s_group.append(i)
                i += 1

    # Trailing tokens: flush a matched remainder; otherwise emit the leftover
    # (possibly one-sided) group so every consumed token index is accounted
    # for. (The original also contained no-op `if not s_group: s_group = []`
    # resets and a redundant re-check of `s_group or t_group` — removed, since
    # `[].copy()` is already `[]` and the elif condition guarantees the check.)
    if s_buf == t_buf and s_group and t_group:
        flush()
    elif s_group or t_group:
        s_groups.append(s_group.copy())
        t_groups.append(t_group.copy())

    return s_groups, t_groups


def _decode_pieces(tokenizer, token_ids, indices):
    """Decode individual token pieces for a group of token indices."""
    return [
        tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False)
        for idx in indices
    ]
tokenizer.decode([token_ids[idx]], skip_special_tokens=False, clean_up_tokenization_spaces=False) for idx in indices ] def _format_pieces(pieces): """Format token pieces as a list, e.g. '["hel", "lo"]'.""" inner = ", ".join(f'"{p}"' for p in pieces) return f"[{inner}]" def highlight_groups(student_tokenizer, teacher_tokenizer, student_token_ids, teacher_token_ids, s_groups, t_groups): """Build an HTML string with highlighted misalignment regions.""" parts = [] first_purple = True for k in range(len(s_groups)): s_ids = [student_token_ids[idx] for idx in s_groups[k]] text = student_tokenizer.decode(s_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) escaped = html.escape(text) s_multi = len(s_groups[k]) > 1 t_multi = len(t_groups[k]) > 1 if s_multi and t_multi: if first_purple: s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k]) t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k]) tooltip = html.escape(f'Student: {_format_pieces(s_pieces)} / Teacher: {_format_pieces(t_pieces)}') parts.append(f'{escaped}') first_purple = False else: parts.append(f'{escaped}') elif s_multi: s_pieces = _decode_pieces(student_tokenizer, student_token_ids, s_groups[k]) tooltip = html.escape(f'Student: {_format_pieces(s_pieces)}') parts.append(f'{escaped}') elif t_multi: t_pieces = _decode_pieces(teacher_tokenizer, teacher_token_ids, t_groups[k]) tooltip = html.escape(f'Teacher: {_format_pieces(t_pieces)}') parts.append(f'{escaped}') else: parts.append(escaped) return "".join(parts) def make_html_block(student_tokenizer, teacher_tokenizer, text, idx): """Process a single text and return its highlighted HTML block.""" s_ids = student_tokenizer.encode(text, add_special_tokens=False) t_ids = teacher_tokenizer.encode(text, add_special_tokens=False) s_groups, t_groups = build_alignment_groups_from_ids( student_tokenizer, teacher_tokenizer, s_ids, t_ids ) highlighted = highlight_groups(student_tokenizer, teacher_tokenizer, s_ids, 
t_ids, s_groups, t_groups) # Build tokenized views with alternating colors s_tokens = [student_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False) for tid in s_ids] t_tokens = [teacher_tokenizer.decode([tid], skip_special_tokens=False, clean_up_tokenization_spaces=False) for tid in t_ids] color1 = "#fff9c4" color2 = "#b2ebf2" s_tokens_html = "".join( f'{html.escape(t)}' for i, t in enumerate(s_tokens) ) t_tokens_html = "".join( f'{html.escape(t)}' for i, t in enumerate(t_tokens) ) tokenized_section = f'''