ltuncay's picture
Submission to the Interspeech 2026 Audio Encoder Capability Challenge
eca55dc verified
Raw
History Blame Contribute Delete
1.36 kB
import torch
import torch.nn as nn
from timm.layers import PatchEmbed as TimmPatchEmbed
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding.

    Thin wrapper that builds a timm ``PatchEmbed`` layer with a fixed set of
    options (``flatten=True``, ``bias=True``, ``strict_img_size=False``) and
    delegates the forward pass to it.

    Args:
        img_size (tuple[int, int]): Input image size (H, W).
        patch_size (tuple[int, int]): Patch size (H, W).
        in_chans (int): Number of input channels.
        embed_dim (int): Embedding dimension.
    """

    def __init__(
        self,
        img_size: tuple[int, int] = (128, 256),
        patch_size: tuple[int, int] = (16, 16),
        in_chans: int = 1,
        embed_dim: int = 768,
    ):
        super().__init__()

        # Keep the construction arguments around so callers can inspect the
        # configuration without reaching into the wrapped timm module.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # NOTE(review): strict_img_size=False presumably relaxes timm's check
        # that the input H/W equal img_size -- confirm against the timm docs.
        self.patch_embed = TimmPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten=True,
            bias=True,
            strict_img_size=False,
        )

        # Snapshot taken once at construction time from the wrapped layer; it
        # is not recomputed per forward call.
        self.num_patches = self.patch_embed.num_patches

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor [B, C, H, W].

        Returns:
            torch.Tensor: Patch embeddings [B, N, D].
        """
        return self.patch_embed(x)