|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torchvision.models.video as models
|
|
|
|
|
|
class TwoStreamNetwork(nn.Module):
|
|
|
def __init__(self):
|
|
|
super(TwoStreamNetwork, self).__init__()
|
|
|
|
|
|
|
|
|
self.rgb_backbone = models.r3d_18(weights=None)
|
|
|
self.rgb_backbone.fc = nn.Identity()
|
|
|
|
|
|
|
|
|
self.flow_backbone = models.r3d_18(weights=None)
|
|
|
self.flow_backbone.fc = nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
self.fusion_fc = nn.Sequential(
|
|
|
nn.Linear(512 * 2, 512),
|
|
|
nn.ReLU(),
|
|
|
nn.Dropout(0.5),
|
|
|
nn.Linear(512, 2)
|
|
|
)
|
|
|
|
|
|
def forward(self, rgb, flow):
|
|
|
idx_rgb = self.rgb_backbone(rgb)
|
|
|
idx_flow = self.flow_backbone(flow)
|
|
|
|
|
|
combined = torch.cat((idx_rgb, idx_flow), dim=1)
|
|
|
out = self.fusion_fc(combined)
|
|
|
return out
|
|
|
|