| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
class InputPreparer(nn.Module):
    """Stem that augments a single-channel image with its edge magnitude.

    The input is smoothed with a 3x3 binomial kernel. Horizontal and vertical
    Sobel responses of the smoothed image are then combined into a gradient
    magnitude map. The raw input and the magnitude map are stacked into two
    channels, re-weighted by a learned per-channel gate, and projected to 32
    channels with a 3x3 conv -> BatchNorm -> SiLU.

    Input shape: (N, 1, H, W). Output shape: (N, 32, H, W).
    """

    def __init__(self):
        super().__init__()

        # 3x3 binomial smoothing kernel. Its entries sum to 16, so dividing
        # by 16 preserves the local mean intensity.
        matrix_a = torch.tensor([[1., 2., 1.],
                                 [2., 4., 2.],
                                 [1., 2., 1.]], dtype=torch.float32) / 16.0
        self.register_buffer('filter_pattern_a', matrix_a.view(1, 1, 3, 3))

        # Sobel kernels: pattern_b responds to changes along the width axis,
        # pattern_c to changes along the height axis.
        matrix_b = torch.tensor([[-1., 0., 1.],
                                 [-2., 0., 2.],
                                 [-1., 0., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
        matrix_c = torch.tensor([[-1., -2., -1.],
                                 [0., 0., 0.],
                                 [1., 2., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer('filter_pattern_b', matrix_b)
        self.register_buffer('filter_pattern_c', matrix_c)

        # Per-channel gate over the two stacked channels (raw input, edge
        # magnitude). It is computed from their global spatial averages.
        self.gating_network = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(2, 2, kernel_size=1),
            nn.Sigmoid(),
        )
        self.mapping_conv = nn.Conv2d(2, 32, kernel_size=3, padding=1, bias=False)
        self.normalization = nn.BatchNorm2d(32)

    def _edge_magnitude(self, x: torch.Tensor) -> torch.Tensor:
        """Return sqrt(Gx^2 + Gy^2 + 1e-5) of the smoothed input; same shape as x.

        Each 3x3 filter is applied after replicate-padding by one pixel. Zero
        padding would compare border pixels against an implicit frame of zeros,
        which shows up as a strong, spurious edge along the image border.
        """
        blurred = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='replicate'), self.filter_pattern_a)
        padded = F.pad(blurred, (1, 1, 1, 1), mode='replicate')
        grad_w = F.conv2d(padded, self.filter_pattern_b)
        grad_h = F.conv2d(padded, self.filter_pattern_c)
        # The epsilon keeps sqrt differentiable where the gradient vanishes.
        return torch.sqrt(grad_w ** 2 + grad_h ** 2 + 1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an (N, 1, H, W) image to (N, 32, H, W) features.

        Raises:
            ValueError: if x is not a 4-D tensor with exactly one channel
                (the fixed filters are single-channel).
        """
        if x.dim() != 4 or x.size(1) != 1:
            raise ValueError(
                f"InputPreparer expects input of shape (N, 1, H, W), got {tuple(x.shape)}"
            )
        integrated_features = torch.cat([x, self._edge_magnitude(x)], dim=1)
        modulated_features = integrated_features * self.gating_network(integrated_features)
        return F.silu(self.normalization(self.mapping_conv(modulated_features)))
|
|
|
|
|
class MagnitudeScaler(nn.Module):
    """Pool each window to the root-mean-square of its non-negative values.

    Negative activations are zeroed before squaring. The squares are
    average-pooled with the configured window, and the square root of
    (mean + 1e-5) is returned, so an all-non-positive window maps to
    sqrt(1e-5) rather than 0.
    """

    def __init__(self, kernel_size=2, stride=2, padding=0):
        super().__init__()
        # Stored as plain attributes and forwarded to F.avg_pool2d.
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        energy = x.clamp(min=0.0).pow(2)
        window_mean = F.avg_pool2d(
            energy,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        )
        return (window_mean + 1e-5).sqrt()
|
|
|
|
|
class FeatureWeighting(nn.Module):
    """Spatial attention: rescale every position by a learned sigmoid weight.

    The channel-wise mean and max of the input are stacked into a 2-channel
    map. That map is convolved to 1 channel and passed through a sigmoid. The
    result is broadcast-multiplied over all channels of the input.

    Args:
        kernel_size: side of the square attention convolution. Must be a
            positive odd integer so that ``padding = kernel_size // 2``
            preserves the spatial size.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        # An even kernel with padding k // 2 yields an (H+1, W+1) map, which
        # would only fail later, at the multiplication in forward().
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.spatial_weighting = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by a per-position weight in (0, 1); same shape as ``x``."""
        mean_projection = torch.mean(x, dim=1, keepdim=True)
        max_projection, _ = torch.max(x, dim=1, keepdim=True)
        combined_projection = torch.cat([mean_projection, max_projection], dim=1)
        return x * self.activation(self.spatial_weighting(combined_projection))
|
|
|
|
|
class ProcessingBlock(nn.Module):
    """3x3 conv -> BatchNorm -> SiLU -> Dropout2d -> spatial re-weighting.

    Args:
        in_c: number of channels in the incoming feature map.
        out_c: number of channels produced by the 3x3 convolution. Spatial
            size is kept because the convolution uses padding=1.
        drop: probability handed to ``nn.Dropout2d``.
    """

    def __init__(self, in_c: int, out_c: int, drop: float = 0.1) -> None:
        super().__init__()
        # Submodules are registered in this exact order. It fixes the
        # state_dict layout and the order in which default initialisers draw
        # from the global RNG.
        self.core_conv = nn.Conv2d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=3,
            padding=1,
            bias=False,  # BatchNorm's shift makes a conv bias redundant
        )
        self.core_norm = nn.BatchNorm2d(num_features=out_c)
        self.refinement = FeatureWeighting()
        self.nonlinearity = nn.SiLU()
        self.regularization = nn.Dropout2d(p=drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.core_conv(x)
        h = self.core_norm(h)
        h = self.nonlinearity(h)
        # Channel dropout happens before the spatial attention, not after.
        h = self.regularization(h)
        return self.refinement(h)
|
|
|
|
|
class HierarchicalNetwork(nn.Module):
    """Three-stage convolutional classifier over single-channel images.

    Pipeline:
    1. ``InputPreparer`` lifts the input to 32 channels.
    2. ``ProcessingBlock`` stages widen the features to 64 -> 128 -> 256
       channels. The first two stages are each followed by a 2x2/stride-2
       ``MagnitudeScaler`` downsample.
    3. Global average- and max-pooled 256-d summaries are concatenated and
       fed to a small MLP head.

    Args:
        out_dims: size of the final linear layer's output (no activation is
            applied to it).
    """

    def __init__(self, out_dims: int = 11):
        super().__init__()
        # Registration order is intentional: it determines state_dict layout
        # and the RNG draw order of both the default and the custom init.
        self.pre_processor = InputPreparer()

        self.stage_a = ProcessingBlock(32, 64, drop=0.1)
        self.downsampler_a = MagnitudeScaler(kernel_size=2, stride=2)

        self.stage_b = ProcessingBlock(64, 128, drop=0.1)
        self.downsampler_b = MagnitudeScaler(kernel_size=2, stride=2)

        self.stage_c = ProcessingBlock(128, 256, drop=0.1)
        self.global_reducer_a = nn.AdaptiveAvgPool2d(1)
        self.global_reducer_b = nn.AdaptiveMaxPool2d(1)

        # Input is the concatenated (average, max) summary: 2 * 256 features.
        self.decision_network = nn.Sequential(
            nn.Linear(256 * 2, 128),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_dims),
        )
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Re-initialise every Conv2d, BatchNorm2d and Linear in the model.

        Conv weights use Kaiming-normal (fan_out) with the ReLU gain. PyTorch's
        gain table has no SiLU entry, so this is presumably an approximation.
        Linear weights use N(0, 0.01). Biases are zeroed, and BatchNorm is
        reset to identity scale/shift. Every optional parameter is
        None-guarded so that layers built with ``bias=False`` or
        ``affine=False`` do not crash the init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (N, out_dims) tensor for a batch of images.

        Assumes x is (N, 1, H, W), as required by InputPreparer's
        single-channel filters.
        """
        x = self.pre_processor(x)
        x = self.downsampler_a(self.stage_a(x))
        x = self.downsampler_b(self.stage_b(x))
        x = self.stage_c(x)

        # flatten(start_dim=1) instead of view(x.size(0), -1): view cannot
        # infer the -1 dimension when the batch is empty (0 elements).
        reduced_a = torch.flatten(self.global_reducer_a(x), 1)
        reduced_b = torch.flatten(self.global_reducer_b(x), 1)

        return self.decision_network(torch.cat([reduced_a, reduced_b], dim=1))