ltuncay's picture
Submission to the Interspeech 2026 Audio Encoder Capability Challenge
eca55dc verified
Raw
History Blame Contribute Delete
2.53 kB
import torch
import torch.nn as nn
import torchaudio
from typing import Optional
class Spectrogram(nn.Module):
"""
Mel Spectrogram module with AmplitudeToDB conversion.
Args:
sample_rate (int): Sample rate of the audio.
n_fft (int): Size of FFT.
win_length (Optional[int]): Window length. Defaults to n_fft.
win_length_ms (Optional[float]): Window length in milliseconds. Overrides win_length if provided.
hop_length (Optional[int]): Hop length. Defaults to win_length // 2.
hop_length_ms (Optional[float]): Hop length in milliseconds. Overrides hop_length if provided.
n_mels (int): Number of mel filterbanks.
f_min (float): Minimum frequency.
f_max (Optional[float]): Maximum frequency.
power (float): Power of the magnitude.
"""
def __init__(
self,
sample_rate: int = 32000,
n_fft: int = 4096,
win_length: Optional[int] = None,
win_length_ms: Optional[float] = None,
hop_length: Optional[int] = None,
hop_length_ms: Optional[float] = None,
n_mels: int = 128,
f_min: float = 0.0,
f_max: Optional[float] = None,
power: float = 2.0,
):
super().__init__()
if win_length is None:
if win_length_ms is None:
win_length = n_fft
else:
win_length = int(sample_rate * win_length_ms / 1000)
if hop_length is None:
if hop_length_ms is None:
hop_length = win_length // 2
else:
hop_length = int(sample_rate * hop_length_ms / 1000)
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
power=power,
normalized=True,
)
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
Args:
x (torch.Tensor): Input waveform [B, C, T] or [B, T].
Returns:
torch.Tensor: Log-Mel Spectrogram [B, C, F, T].
"""
# x: [B, C, T]
# MelSpectrogram expects [..., T]
# Output will be [..., n_mels, time]
spec = self.mel_spec(x)
spec = self.amplitude_to_db(spec)
return spec