BEST-RQ-2 / audio-embeddings /src /callbacks /visualization_callback.py
ltuncay's picture
Submission to the Interspeech 2026 Audio Encoder Capability Challenge
eca55dc verified
Raw
History Blame Contribute Delete
7.6 kB
from collections.abc import Mapping
from typing import Any, Dict, Optional

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import WandbLogger
class VisualizationCallback(Callback):
    """
    Callback to visualize spectrograms, patch grids, and masks in WandB.

    For each of the first ``num_batches`` training batches (by ``batch_idx``),
    logs a ``wandb.Table`` with up to ``num_samples`` rows. Each row holds:
    the raw audio; the spectrogram with its patch grid; the spectrogram with
    masked patches darkened (context); and the spectrogram with unmasked
    patches darkened (targets). Nothing is logged unless the trainer's logger
    is a ``WandbLogger``.

    Args:
        num_samples: Maximum number of samples to visualize from each batch.
        num_batches: Number of leading training batches to visualize.
    """

    def __init__(self, num_samples: int = 4, num_batches: int = 2):
        super().__init__()
        self.num_samples = num_samples
        self.num_batches = num_batches
        # Number of batches handled so far; logging stops once it reaches
        # ``num_batches``.
        self.batches_logged = 0

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Visualize ``batch`` if it is among the first ``num_batches`` batches."""
        if self.batches_logged >= self.num_batches:
            return
        if batch_idx < self.num_batches:
            self._log_visualizations(trainer, pl_module, batch, batch_idx)
            self.batches_logged += 1

    def _log_visualizations(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        batch: Dict[str, Any],
        batch_idx: int,
    ) -> None:
        """
        Build and log the visualization table for one batch.

        Args:
            trainer: Trainer whose logger receives the table.
            pl_module: Module providing ``spectrogram``, ``patch_embed`` and
                ``mask_generator``.
            batch: Batch dict; only the ``"waveform"`` entry is used.
            batch_idx: Index of the batch, used in captions and the log key.
        """
        logger = trainer.logger
        if not isinstance(logger, WandbLogger):
            return

        # The batch may hold fewer than ``num_samples`` items (e.g. a short
        # final batch), so everything below is sized by what was sliced.
        waveform = batch["waveform"][: self.num_samples]  # [B, 1, T]
        num_rows = waveform.shape[0]
        if num_rows == 0:
            return
        sample_rate = self._resolve_sample_rate(trainer, pl_module)

        with torch.no_grad():
            spec = pl_module.spectrogram(waveform.to(pl_module.device))  # [B, 1, F, T]

        patch_size = pl_module.patch_embed.patch_embed.patch_size
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        patch_size = tuple(patch_size)
        grid_size = (spec.shape[2] // patch_size[0], spec.shape[3] // patch_size[1])

        # Draw a fresh mask the same way the training step does (one mask
        # shared across the batch), so each logged batch shows a new draw.
        mask = pl_module.mask_generator(
            1, device=pl_module.device, grid_size=grid_size
        )  # [1, N]

        # Only rank zero writes to WandB. The check sits after the
        # spectrogram/mask calls so every rank still consumes the same random
        # draws; only the costly plotting is skipped elsewhere.
        if not trainer.is_global_zero:
            return

        import wandb

        columns = [
            "Batch Idx",
            "Sample Idx",
            "Audio",
            "Spectrogram",
            "Masked Spectrogram (Context)",
            "Inverse Masked Spectrogram (Targets)",
        ]

        # Cast to float32 before .numpy(): bfloat16 tensors cannot be converted.
        mask_data = mask.reshape(-1).detach().cpu().numpy()
        spec_np = spec[:, 0].detach().float().cpu().numpy()  # [B, F, T]
        wave_np = waveform.detach().float().cpu().numpy()

        data = []
        for i in range(num_rows):
            tag = f"B{batch_idx}_S{i}"
            audio = wandb.Audio(
                wave_np[i].squeeze(), sample_rate=sample_rate, caption=tag
            )
            spec_data = spec_np[i]

            img_orig = self._figure_to_image(
                self._plot_spectrogram(spec_data, patch_size, grid_size),
                f"Spec B{batch_idx}_S{i}",
            )
            # Context view: masked patches are darkened.
            img_masked = self._figure_to_image(
                self._plot_spectrogram_with_mask(
                    spec_data,
                    mask_data,
                    patch_size,
                    grid_size,
                    invert_mask=False,
                    title="Masked Spectrogram (Context)",
                ),
                f"Masked B{batch_idx}_S{i}",
            )
            # Target view: unmasked (context) patches are darkened.
            img_inv_masked = self._figure_to_image(
                self._plot_spectrogram_with_mask(
                    spec_data,
                    mask_data,
                    patch_size,
                    grid_size,
                    invert_mask=True,
                    title="Inverse Masked Spectrogram (Targets)",
                ),
                f"InvMasked B{batch_idx}_S{i}",
            )
            data.append([batch_idx, i, audio, img_orig, img_masked, img_inv_masked])

        table = wandb.Table(columns=columns, data=data)
        logger.experiment.log({f"train/visualizations_batch_{batch_idx}": table})

    @staticmethod
    def _figure_to_image(fig: plt.Figure, caption: str) -> Any:
        """Convert ``fig`` to a ``wandb.Image``, always closing the figure."""
        import wandb

        try:
            return wandb.Image(fig, caption=caption)
        finally:
            plt.close(fig)

    @staticmethod
    def _resolve_sample_rate(trainer: L.Trainer, pl_module: L.LightningModule) -> int:
        """
        Resolve the sample rate used when logging audio.

        Lookup order:
        1. ``trainer.datamodule.target_sample_rate``.
        2. ``datamodule.hparams`` entry ``target_sample_rate``.
        3. ``pl_module.spectrogram.sample_rate``.
        4. ``pl_module.hparams["net"]["spectrogram"]["sample_rate"]``.
        5. Otherwise 32000.

        Any ``Mapping`` (e.g. an OmegaConf config) is accepted where a dict is
        expected.
        """
        default_sample_rate = 32000

        datamodule = getattr(trainer, "datamodule", None)
        if datamodule is not None:
            dm_sr = getattr(datamodule, "target_sample_rate", None)
            if dm_sr is None:
                dm_hparams = getattr(datamodule, "hparams", None)
                if isinstance(dm_hparams, Mapping):
                    dm_sr = dm_hparams.get("target_sample_rate")
                else:
                    dm_sr = getattr(dm_hparams, "target_sample_rate", None)
            if dm_sr is not None:
                return int(dm_sr)

        spectrogram = getattr(pl_module, "spectrogram", None)
        module_sr = getattr(spectrogram, "sample_rate", None)
        if module_sr is not None:
            return int(module_sr)

        hparams = getattr(pl_module, "hparams", None)
        if isinstance(hparams, Mapping):
            net_cfg = hparams.get("net")
            if isinstance(net_cfg, Mapping):
                spectrogram_cfg = net_cfg.get("spectrogram")
                if isinstance(spectrogram_cfg, Mapping):
                    config_sr = spectrogram_cfg.get("sample_rate")
                    if config_sr is not None:
                        return int(config_sr)

        return default_sample_rate

    def _plot_spectrogram(
        self, spec: np.ndarray, patch_size: tuple[int, int], grid_size: tuple[int, int]
    ) -> plt.Figure:
        """Plot a 2-D spectrogram with its patch grid and no mask overlay."""
        return self._plot_spectrogram_with_mask(spec, None, patch_size, grid_size)

    def _plot_spectrogram_with_mask(
        self,
        spec: np.ndarray,
        mask: Optional[np.ndarray],
        patch_size: tuple[int, int],
        grid_size: tuple[int, int],
        invert_mask: bool = False,
        title: str = "Spectrogram",
    ) -> plt.Figure:
        """
        Plot a spectrogram with dashed patch-grid lines and darkened patches.

        Args:
            spec: 2-D array ``[F, T]`` (frequency bins x time frames).
            mask: Flat per-patch mask with ``grid_size[0] * grid_size[1]``
                entries; truthy entries are darkened. ``None`` draws only the
                spectrogram and grid.
            patch_size: Patch size ``(freq, time)`` in pixels.
            grid_size: Number of patches ``(freq, time)``.
            invert_mask: If True, darken the patches where ``mask`` is falsy.
            title: Axes title.

        Returns:
            The matplotlib figure; the caller is responsible for closing it.
        """
        h_grid, w_grid = grid_size
        patch_h, patch_w = patch_size
        height, width = spec.shape

        fig, ax = plt.subplots(figsize=(10, 4))
        ax.imshow(spec, origin="lower", aspect="auto", cmap="viridis")

        # Offset by -0.5 so the lines fall on pixel edges rather than centers.
        for h in range(0, height + 1, patch_h):
            ax.axhline(h - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5)
        for w in range(0, width + 1, patch_w):
            ax.axvline(w - 0.5, color="white", linestyle="--", linewidth=0.5, alpha=0.5)

        if mask is not None:
            # Cast to bool first. On an integer 0/1 mask, ``~`` yields -1/-2
            # (both truthy), and it is unsupported for floats.
            mask_grid = np.asarray(mask).astype(bool).reshape(h_grid, w_grid)
            if invert_mask:
                mask_grid = ~mask_grid
            # Upsample the patch-level mask to pixel resolution. Pixels past the
            # last full patch (F or T not divisible by the patch size) stay
            # uncovered.
            pixel_mask = np.repeat(np.repeat(mask_grid, patch_h, axis=0), patch_w, axis=1)
            pixel_mask = pixel_mask[:height, :width]
            overlay = np.zeros((height, width, 4))  # RGBA, black
            overlay[: pixel_mask.shape[0], : pixel_mask.shape[1], 3] = np.where(
                pixel_mask, 0.7, 0.0
            )
            ax.imshow(overlay, origin="lower", aspect="auto")

        ax.set_title(title)
        ax.set_xlabel("Time Frames")
        ax.set_ylabel("Frequency Bins")
        plt.tight_layout()
        return fig