| """Forced alignment for word-level timestamps using Wav2Vec2.""" |
|
|
| import numpy as np |
| import torch |
|
|
| |
| |
# Word-boundary padding in seconds, applied when converting frame peaks to
# timestamps: START_OFFSET is subtracted from the word start (pulling it 40 ms
# earlier), and END_OFFSET is subtracted from the word end — being negative, it
# pushes the end 40 ms later. Compensates for the aligner clipping word edges.
START_OFFSET = 0.04
END_OFFSET = -0.04
|
|
|
|
| def _get_device() -> str: |
| """Get best available device for non-transformers models.""" |
| if torch.cuda.is_available(): |
| return "cuda" |
| if torch.backends.mps.is_available(): |
| return "mps" |
| return "cpu" |
|
|
|
|
class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses Viterbi trellis algorithm for optimal alignment path finding.
    """

    # Singleton state, populated on the first get_instance() call.
    _bundle = None      # torchaudio pipeline bundle (WAV2VEC2_ASR_BASE_960H)
    _model = None       # acoustic model moved to the chosen device
    _labels = None      # character labels, indexed by emission class
    _dictionary = None  # char -> emission-class index

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        The model is loaded once and cached at class level; subsequent calls
        return the cached instance regardless of the requested device.

        Args:
            device: Device to run model on ("cuda" or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            # Deferred import keeps torchaudio optional until alignment is used.
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build trellis for forced alignment using forward algorithm.

        The trellis[t, j] represents the log probability of the best path that
        aligns the first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        # Everything except the origin starts unreachable (-inf). Any cell that
        # is still -inf after the forward pass cannot be reached — in particular
        # trellis[-1, -1] stays -inf when num_tokens > num_frames, which the
        # backtracking step uses to detect infeasible alignments.
        # (A previous revision also wrote +inf into the tail of column 0, copied
        # from the torchaudio tutorial's differently-shaped trellis; here those
        # rows are overwritten by the recurrence before being read, so the write
        # was dead code and has been removed.)
        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        token_ids = torch.tensor(tokens, dtype=torch.long)

        # Forward recurrence, vectorized over the token axis:
        #   stay (emit blank):    trellis[t, j]     + emission[t, blank_id]
        #   move (emit token j):  trellis[t, j - 1] + emission[t, tokens[j - 1]]
        # Column 0 has no "move" predecessor, so it only accumulates blanks.
        for t in range(num_frames):
            blank_score = emission[t, blank_id]
            trellis[t + 1, 0] = trellis[t, 0] + blank_score
            trellis[t + 1, 1:] = torch.maximum(
                trellis[t, 1:] + blank_score,
                trellis[t, :-1] + emission[t, token_ids],
            )

        return trellis

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float, float]]:
        """Backtrack through trellis to find optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after previous token's
        - No frame skipping or token teleporting

        Returns list of (token_id, start_frame, end_frame, peak_frame) for each token.
        The peak_frame is the frame with highest emission probability for that token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # Unreachable final cell means no monotonic path exists (e.g. more
        # tokens than frames): fall back to a uniform split over the frames.
        if trellis[num_frames, num_tokens] == -float("inf"):
            frames_per_token = num_frames / num_tokens
            return [
                (
                    tokens[i],
                    i * frames_per_token,
                    (i + 1) * frames_per_token,
                    (i + 0.5) * frames_per_token,
                )
                for i in range(num_tokens)
            ]

        # For each token, collect the (frame, emit_probability) pairs where the
        # best path emits that token.
        token_frames: list[list[tuple[int, float]]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Recompute the two candidate predecessors of trellis[t, j].
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Frame t-1 emitted token j-1; record it with its probability.
                emit_prob = emission[t - 1, tokens[j - 1]].exp().item()
                token_frames[j - 1].insert(0, (t - 1, emit_prob))
                j -= 1

            t -= 1

        # Defensive: any tokens not consumed by the walk are pinned to frame 0.
        while j > 0:
            token_frames[j - 1].insert(0, (0, 0.0))
            j -= 1

        # Convert per-token frame lists into (id, start, end, peak) spans.
        token_spans: list[tuple[int, float, float, float]] = []
        for token_idx, frames_with_scores in enumerate(token_frames):
            if not frames_with_scores:
                # No frames recorded: anchor to the previous span's end (or 0).
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames_with_scores = [(int(prev_end), 0.0)]
                else:
                    frames_with_scores = [(0, 0.0)]

            token_id = tokens[token_idx]
            frames = [f for f, _ in frames_with_scores]
            start_frame = float(min(frames))
            # End is exclusive: the last emitting frame plus one.
            end_frame = float(max(frames)) + 1.0

            # Peak = frame with the highest emission probability for this token.
            peak_frame, _ = max(frames_with_scores, key=lambda x: x[1])

            token_spans.append((token_id, start_frame, end_frame, float(peak_frame)))

        return token_spans

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        _language: str = "eng",
        _batch_size: int = 16,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Uses Viterbi trellis algorithm for optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None

        # Copy/clone so the caller's buffer is never mutated.
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        # Model expects a (channels, samples) tensor.
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # wav2vec2 is trained at a fixed rate; resample if the input differs.
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Per-frame log-probabilities over the character label set.
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # The label set is uppercase characters.
        transcript = text.upper()

        # Word separator is "|" in the wav2vec2 label set; fall back to a
        # literal space, then blank. Looked up once and reused below.
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        # Map transcript characters to token ids; characters outside the label
        # set (punctuation etc.) are silently dropped, spaces become separators.
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(separator_id)

        if not tokens:
            return []

        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # wav2vec2 has a 320-sample hop, so each emission frame spans
        # 320 / sample_rate seconds (20 ms at 16 kHz).
        frame_duration = 320 / cls._bundle.sample_rate

        # Boundary padding; END_OFFSET is negative, so subtracting it extends
        # the word end (see module-level constants).
        start_offset = START_OFFSET
        end_offset = END_OFFSET

        # Walk the character-level alignment and cut it into words at
        # separator tokens, timing each word by its characters' peak frames.
        words = text.split()
        word_timestamps = []
        first_char_peak = None
        last_char_peak = None
        word_idx = 0

        for token_id, _start_frame, _end_frame, peak_frame in alignment_path:
            if token_id == separator_id:
                if (
                    first_char_peak is not None
                    and last_char_peak is not None
                    and word_idx < len(words)
                ):
                    # Close the current word; clamp to non-negative times.
                    start_time = max(0.0, first_char_peak * frame_duration - start_offset)
                    end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                first_char_peak = None
                last_char_peak = None
            else:
                if first_char_peak is None:
                    first_char_peak = peak_frame
                last_char_peak = peak_frame

        # Flush the final word — the token stream has no trailing separator.
        if first_char_peak is not None and last_char_peak is not None and word_idx < len(words):
            start_time = max(0.0, first_char_peak * frame_duration - start_offset)
            end_time = max(0.0, (last_char_peak + 1) * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
|
|