|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
|
|
|
# Temporal stride used by the Fast -> Slow lateral convolutions, and the
# default value of `alpha` in LateralConnection.
# NOTE(review): presumably the Fast:Slow frame-rate ratio of the input clips;
# confirm against the data pipeline.
ALPHA = 8
|
|
|
|
|
|
class LateralConnection(nn.Module):
    """Project Fast-pathway features for fusion into the Slow pathway.

    The module is a single bias-free 3D convolution that only mixes along
    the temporal axis: temporal kernel 5, temporal stride ``alpha``,
    temporal padding 2, and a 1x1 spatial footprint with stride 1.

    For an input with ``T`` frames the output has
    ``floor((T - 1) / alpha) + 1`` frames; height and width are unchanged.

    Args:
        fast_channels: Number of channels of the incoming Fast features.
        slow_channels: Channel base for the output; the convolution emits
            ``slow_channels * 2`` channels.
        alpha: Temporal stride of the convolution.
    """

    def __init__(self, fast_channels, slow_channels, alpha=ALPHA):
        super().__init__()

        # Time-only convolution: the spatial kernel/stride/padding are
        # the identity configuration (1, 1, 0).
        self.conv = nn.Conv3d(
            in_channels=fast_channels,
            out_channels=slow_channels * 2,
            kernel_size=(5, 1, 1),
            stride=(alpha, 1, 1),
            padding=(2, 0, 0),
            bias=False,
        )

    def forward(self, x_fast):
        """Return the temporally strided projection of ``x_fast``."""
        projected = self.conv(x_fast)
        return projected
|
|
|
|
|
|
class SlowFastNetwork(nn.Module):
    """Two-pathway 3D-convolutional video classifier with lateral fusion.

    The Fast pathway is narrow (8 -> 16 -> 32 channels) and convolves over
    time (temporal kernels 5, 3, 3). The Slow pathway is wide
    (64 -> 128 -> 256 channels) and uses spatial-only kernels (temporal
    kernel 1). Before Slow stages 2 and 3, the current Fast features are
    passed through a time-strided lateral convolution and concatenated onto
    the Slow features along the channel axis. Both pathways are then
    globally average-pooled, concatenated, regularised with dropout and
    classified by one linear layer.

    Args:
        num_classes: Number of scores produced by the final linear layer.
        dropout: Drop probability applied to the pooled feature vector.
        alpha: Temporal stride of both lateral convolutions (int). ``None``
            selects the module-level ``ALPHA``, read at construction time.
    """

    def __init__(self, num_classes=2, *, dropout=0.5, alpha=None):
        super().__init__()

        if alpha is None:
            # Looked up lazily so the default follows the module constant.
            alpha = ALPHA

        # NOTE: submodule names and creation order are part of the
        # checkpoint/optimizer contract (state_dict keys, parameter order,
        # seeded initialisation) and must not be changed.

        # ---- Fast pathway: few channels, temporal kernels > 1 ----
        self.fast_conv1 = nn.Conv3d(
            3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2),
            padding=(2, 3, 3), bias=False,
        )
        self.fast_bn1 = nn.BatchNorm3d(8)
        self.fast_pool1 = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1),
        )

        self.fast_conv2 = nn.Conv3d(
            8, 16, kernel_size=(3, 3, 3), stride=(1, 2, 2),
            padding=(1, 1, 1), bias=False,
        )
        self.fast_bn2 = nn.BatchNorm3d(16)

        self.fast_conv3 = nn.Conv3d(
            16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1),
            padding=(1, 1, 1), bias=False,
        )
        self.fast_bn3 = nn.BatchNorm3d(32)

        # ---- Slow pathway: many channels, spatial-only kernels ----
        self.slow_conv1 = nn.Conv3d(
            3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2),
            padding=(0, 3, 3), bias=False,
        )
        self.slow_bn1 = nn.BatchNorm3d(64)
        self.slow_pool1 = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1),
        )

        # Input = 64 Slow channels + 16 channels from lateral1.
        self.slow_conv2 = nn.Conv3d(
            64 + 16, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2),
            padding=(0, 1, 1), bias=False,
        )
        self.slow_bn2 = nn.BatchNorm3d(128)

        # Input = 128 Slow channels + 64 channels from lateral2.
        self.slow_conv3 = nn.Conv3d(
            128 + 64, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1),
            padding=(0, 1, 1), bias=False,
        )
        self.slow_bn3 = nn.BatchNorm3d(256)

        # ---- Lateral connections (Fast -> Slow), time-strided by alpha ----
        self.lateral1 = nn.Conv3d(
            8, 16, kernel_size=(5, 1, 1), stride=(alpha, 1, 1),
            padding=(2, 0, 0), bias=False,
        )
        self.lateral2 = nn.Conv3d(
            16, 64, kernel_size=(5, 1, 1), stride=(alpha, 1, 1),
            padding=(2, 0, 0), bias=False,
        )

        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # ---- Classification head on 256 Slow + 32 Fast pooled features ----
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(32 + 256, num_classes)

    def forward(self, slow_input, fast_input):
        """Classify a clip given its Slow and Fast frame samplings.

        Args:
            slow_input: Tensor of shape ``(N, 3, T_slow, H, W)``.
            fast_input: Tensor of shape ``(N, 3, T_fast, H, W)`` with the
                same ``N``, ``H`` and ``W``. The channel concatenations
                require ``floor((T_fast - 1) / alpha) + 1 == T_slow``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the raw outputs of
            the final linear layer (no softmax is applied).
        """
        # Stage 1 of each pathway.
        f1 = self.relu(self.fast_bn1(self.fast_conv1(fast_input)))
        f1_p = self.fast_pool1(f1)

        s1 = self.relu(self.slow_bn1(self.slow_conv1(slow_input)))
        s1_p = self.slow_pool1(s1)

        # First fusion: stride Fast features in time, append as channels.
        l1 = self.lateral1(f1_p)
        s2_input = torch.cat([s1_p, l1], dim=1)

        # Stage 2.
        f2 = self.relu(self.fast_bn2(self.fast_conv2(f1_p)))
        s2 = self.relu(self.slow_bn2(self.slow_conv2(s2_input)))

        # Second fusion.
        l2 = self.lateral2(f2)
        s3_input = torch.cat([s2, l2], dim=1)

        # Stage 3.
        f3 = self.relu(self.fast_bn3(self.fast_conv3(f2)))
        s3 = self.relu(self.slow_bn3(self.slow_conv3(s3_input)))

        # Global pooling -> (N, C) per pathway.
        f_out = torch.flatten(self.avg_pool(f3), 1)
        s_out = torch.flatten(self.avg_pool(s3), 1)

        # Head: Slow features first, then Fast (matches fc's 256 + 32 input).
        x = torch.cat([s_out, f_out], dim=1)
        x = self.dropout(x)
        x = self.fc(x)

        return x
|
|
|
|