# RobustMNIST-v1.0 / model.py
# Uploaded by MultivexAI ("Upload 2 files"), commit 0f5b4cb (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
class InputPreparer(nn.Module):
    """Stem that augments a single-channel image with an edge-magnitude map.

    ``forward`` performs these steps:

    1. Blur the input with a fixed 3x3 binomial kernel
       ``[[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16``.
    2. Apply fixed horizontal and vertical Sobel kernels to the blurred image.
       Combine the two responses into ``sqrt(gx**2 + gy**2 + 1e-5)``.
    3. Concatenate the raw input with that magnitude (2 channels). Rescale
       each channel by a learned sigmoid gate computed from a global average
       pool.
    4. Project to 32 channels with a 3x3 conv, then apply BatchNorm and SiLU.

    The three fixed kernels are registered as buffers, so they are stored in
    the ``state_dict`` but are not parameters. They have shape ``(1, 1, 3, 3)``,
    so the input must have exactly one channel: ``(N, 1, H, W)`` maps to
    ``(N, 32, H, W)``. Padding 1 on every 3x3 conv keeps ``H x W`` unchanged.
    """

    def __init__(self):
        super().__init__()

        def _kernel(col: torch.Tensor, row: torch.Tensor) -> torch.Tensor:
            # Outer product of two 1-D taps, shaped as a conv2d weight.
            return (col[:, None] * row[None, :]).reshape(1, 1, 3, 3)

        binomial = torch.tensor([1., 2., 1.], dtype=torch.float32)
        derivative = torch.tensor([-1., 0., 1.], dtype=torch.float32)

        # Blur kernel [[1,2,1],[2,4,2],[1,2,1]] / 16; the weights sum to 1.
        self.register_buffer('filter_pattern_a', _kernel(binomial, binomial) / 16.0)
        # Sobel, derivative along width: [[-1,0,1],[-2,0,2],[-1,0,1]].
        self.register_buffer('filter_pattern_b', _kernel(binomial, derivative))
        # Sobel, derivative along height: [[-1,-2,-1],[0,0,0],[1,2,1]].
        self.register_buffer('filter_pattern_c', _kernel(derivative, binomial))

        # Per-channel gate: global average -> 1x1 conv -> sigmoid, giving (N, 2, 1, 1).
        self.gating_network = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(2, 2, kernel_size=1),
            nn.Sigmoid(),
        )
        self.mapping_conv = nn.Conv2d(2, 32, kernel_size=3, padding=1, bias=False)
        self.normalization = nn.BatchNorm2d(32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(N, 1, H, W)`` image to ``(N, 32, H, W)`` stem features."""
        smoothed = F.conv2d(x, self.filter_pattern_a, padding=1)
        grad_x = F.conv2d(smoothed, self.filter_pattern_b, padding=1)
        grad_y = F.conv2d(smoothed, self.filter_pattern_c, padding=1)
        # The epsilon keeps sqrt differentiable where both responses are exactly zero.
        edge_strength = (grad_x.pow(2) + grad_y.pow(2) + 1e-5).sqrt()

        stacked = torch.cat((x, edge_strength), dim=1)
        gate = self.gating_network(stacked)  # broadcast over H x W
        projected = self.mapping_conv(stacked * gate)
        return F.silu(self.normalization(projected))
class MagnitudeScaler(nn.Module):
    """Downsample by the root-mean-square of the positive part of the input.

    For each pooling window the output is
    ``sqrt(mean(clamp(x, min=0) ** 2) + 1e-5)``. Negative activations
    contribute zero, and the result is never below ``sqrt(1e-5)``.

    ``F.avg_pool2d`` uses its default ``count_include_pad=True``. When
    ``padding > 0``, padded zeros therefore count toward the mean.

    Args:
        kernel_size: Pooling window size passed to ``F.avg_pool2d``.
        stride: Pooling stride passed to ``F.avg_pool2d``.
        padding: Implicit zero padding passed to ``F.avg_pool2d``.
    """

    def __init__(self, kernel_size=2, stride=2, padding=0):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` spatially; the channel count is unchanged."""
        energy = x.clamp(min=0.0).pow(2)
        mean_energy = F.avg_pool2d(
            energy,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        )
        # Epsilon keeps the sqrt gradient finite for all-non-positive windows.
        return (mean_energy + 1e-5).sqrt()
class FeatureWeighting(nn.Module):
    """Spatial attention: rescale every position of ``x`` by a learned gate in (0, 1).

    The gate is built in three steps:

    1. Compute two channel-wise summaries of the input, the per-pixel mean
       and the per-pixel max across channels.
    2. Mix them with a single ``kernel_size x kernel_size`` convolution that
       has no bias.
    3. Squash the result with a sigmoid.

    The same gate value is applied to every channel at a given position.

    Args:
        kernel_size: Side length of the gating convolution. It must be a
            positive odd integer so that ``padding = kernel_size // 2``
            preserves the spatial size. An even kernel would produce a gate
            one pixel larger than ``x``, and the final multiplication would
            fail.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # Fail at construction with a clear message rather than with a
        # broadcasting error inside forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.spatial_weighting = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
        )
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by its spatial attention map.

        Args:
            x: Feature map, expected to be ``(N, C, H, W)``. The output has
                the same shape.
        """
        mean_projection = torch.mean(x, dim=1, keepdim=True)
        max_projection, _ = torch.max(x, dim=1, keepdim=True)
        combined_projection = torch.cat([mean_projection, max_projection], dim=1)
        return x * self.activation(self.spatial_weighting(combined_projection))
class ProcessingBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> SiLU -> Dropout2d -> spatial attention.

    The 3x3 kernel with padding 1 preserves the spatial size. The channel
    count goes from ``in_c`` to ``out_c``.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        drop: Drop probability passed to ``nn.Dropout2d``.
    """

    def __init__(self, in_c: int, out_c: int, drop: float = 0.1) -> None:
        super().__init__()
        # Submodules are registered in the original order so state_dict keys
        # and RNG consumption during weight init stay the same.
        self.core_conv = nn.Conv2d(in_c, out_c, 3, padding=1, bias=False)
        self.core_norm = nn.BatchNorm2d(out_c)
        self.refinement = FeatureWeighting()
        self.nonlinearity = nn.SiLU()
        self.regularization = nn.Dropout2d(p=drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block and return the attention-weighted feature map."""
        h = x
        for layer in (self.core_conv, self.core_norm, self.nonlinearity, self.regularization):
            h = layer(h)
        return self.refinement(h)
class HierarchicalNetwork(nn.Module):
    """Three-stage convolutional classifier built from the blocks in this file.

    Pipeline for an ``(N, 1, H, W)`` input. For example, 28x28 becomes
    14x14 and then 7x7.

    1. ``pre_processor`` (``InputPreparer``): 1 -> 32 channels at ``H x W``.
    2. ``stage_a`` (32 -> 64), then ``downsampler_a`` (window 2, stride 2).
    3. ``stage_b`` (64 -> 128), then ``downsampler_b`` (window 2, stride 2).
    4. ``stage_c`` (128 -> 256).
    5. Global average pooling and global max pooling, concatenated into 512
       features.
    6. MLP head: 512 -> 128 -> SiLU -> Dropout(0.2) -> ``out_dims``.

    Args:
        out_dims: Number of output logits. NOTE(review): the default is 11
            while MNIST has 10 digit classes; confirm what the extra output
            represents.
    """

    def __init__(self, out_dims: int = 11):
        super().__init__()
        self.pre_processor = InputPreparer()
        self.stage_a = ProcessingBlock(32, 64, drop=0.1)
        self.downsampler_a = MagnitudeScaler(kernel_size=2, stride=2)
        self.stage_b = ProcessingBlock(64, 128, drop=0.1)
        self.downsampler_b = MagnitudeScaler(kernel_size=2, stride=2)
        self.stage_c = ProcessingBlock(128, 256, drop=0.1)
        self.global_reducer_a = nn.AdaptiveAvgPool2d(1)
        self.global_reducer_b = nn.AdaptiveMaxPool2d(1)
        self.decision_network = nn.Sequential(
            nn.Linear(256 * 2, 128),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_dims),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear in the module tree.

        This covers the layers inside the sub-blocks as well:

        * Conv2d: Kaiming normal (``fan_out``, ReLU gain); bias zeroed if present.
        * BatchNorm2d: weight 1, bias 0.
        * Linear: weight ~ N(0, 0.01), bias 0.

        NOTE(review): the ReLU gain is used although the activations here are
        SiLU (stages) and Sigmoid (gating convs). It is kept as-is so that
        training and initialisation stay reproducible.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(N, 1, H, W)`` batch to ``(N, out_dims)`` logits."""
        x = self.pre_processor(x)
        x = self.downsampler_a(self.stage_a(x))
        x = self.downsampler_b(self.stage_b(x))
        x = self.stage_c(x)
        # torch.flatten(..., 1) instead of .view(x.size(0), -1). The view call
        # raises for an empty batch because -1 cannot be inferred from zero
        # elements.
        reduced_a = torch.flatten(self.global_reducer_a(x), 1)
        reduced_b = torch.flatten(self.global_reducer_b(x), 1)
        return self.decision_network(torch.cat([reduced_a, reduced_b], dim=1))