# LEGIONM36's picture
# Upload 4 files
# 81d7389 verified
import torch
import torch.nn as nn
import torchvision.models.video as models
class TwoStreamNetwork(nn.Module):
    """Two-stream video classifier that fuses RGB and optical-flow features.

    Each stream is an independent, randomly initialised ``r3d_18`` backbone
    whose classification head is replaced by ``nn.Identity`` so it emits a
    feature vector instead of logits. The two feature vectors are
    concatenated and passed through a small MLP that produces the logits.

    Args:
        num_classes: Number of output logits. Defaults to 2, the value that
            was previously hard-coded.
        dropout: Dropout probability applied inside the fusion head.
            Defaults to 0.5, the value that was previously hard-coded.
    """

    def __init__(self, num_classes: int = 2, dropout: float = 0.5) -> None:
        super().__init__()

        # Stream 1: RGB.
        self.rgb_backbone = models.r3d_18(weights=None)
        # Read the feature width from the stock head before discarding it, so
        # the fusion layer stays correct without relying on a magic number.
        rgb_dim = self.rgb_backbone.fc.in_features
        self.rgb_backbone.fc = nn.Identity()  # Remove classification head

        # Stream 2: optical flow.
        # NOTE(review): the stock r3d_18 stem presumably expects 3 input
        # channels, so flow clips would need to be 3-channel - confirm
        # against the data pipeline.
        self.flow_backbone = models.r3d_18(weights=None)
        flow_dim = self.flow_backbone.fc.in_features
        self.flow_backbone.fc = nn.Identity()

        # Fusion head. The layer order (Linear, ReLU, Dropout, Linear) is kept
        # as-is so the state_dict keys (fusion_fc.0.*, fusion_fc.3.*) still
        # match checkpoints saved with the previous definition.
        self.fusion_fc = nn.Sequential(
            nn.Linear(rgb_dim + flow_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, rgb: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of paired RGB and flow clips.

        Args:
            rgb: Input clips for the RGB stream.
            flow: Input clips for the optical-flow stream; must have the same
                batch size as ``rgb`` so the features can be concatenated.

        Returns:
            The output of the fusion head, one row of ``num_classes`` logits
            per sample (no softmax is applied).
        """
        rgb_features = self.rgb_backbone(rgb)
        flow_features = self.flow_backbone(flow)
        # Concatenate along the feature axis (dim 1), keeping the batch axis.
        fused = torch.cat((rgb_features, flow_features), dim=1)
        return self.fusion_fc(fused)