File size: 1,008 Bytes
81d7389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn as nn
import torchvision.models.video as models

class TwoStreamNetwork(nn.Module):
    """Late-fusion two-stream video classifier built from two R3D-18 backbones.

    One backbone encodes the RGB clip and the other encodes the optical-flow
    clip. Each backbone's classification head is replaced with an identity so
    it emits its pooled feature vector; the two vectors are concatenated and
    passed through a small MLP head that produces class logits. Both
    backbones are randomly initialised (``weights=None``).

    Args:
        num_classes: Number of output logits. Defaults to 2, matching the
            original hard-coded head.
        flow_in_channels: Channel count of the flow input. The stock R3D-18
            stem convolution expects 3 input channels; optical flow is often
            stored with 2 (dx, dy). Any value other than 3 replaces the flow
            stem's first convolution with one of identical geometry but the
            requested input width. Defaults to 3, which leaves the backbone
            unmodified.
        dropout: Dropout probability used in the fusion head. Defaults to 0.5.
    """

    def __init__(self, num_classes=2, flow_in_channels=3, dropout=0.5):
        super().__init__()

        # Stream 1: RGB
        self.rgb_backbone = models.r3d_18(weights=None)
        # Read the feature width before discarding the head so the fusion
        # layer does not depend on a hard-coded 512.
        rgb_dim = self.rgb_backbone.fc.in_features
        self.rgb_backbone.fc = nn.Identity()  # Remove classification head

        # Stream 2: Optical Flow
        self.flow_backbone = models.r3d_18(weights=None)
        if flow_in_channels != 3:
            # The default stem conv only accepts 3-channel input; rebuild it
            # with the same kernel/stride/padding for the flow channel count.
            stem_conv = self.flow_backbone.stem[0]
            self.flow_backbone.stem[0] = nn.Conv3d(
                flow_in_channels,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )
        flow_dim = self.flow_backbone.fc.in_features
        self.flow_backbone.fc = nn.Identity()

        # Fusion head over the concatenated per-stream features.
        self.fusion_fc = nn.Sequential(
            nn.Linear(rgb_dim + flow_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, rgb, flow):
        """Return fusion-head logits for a batch of paired RGB / flow clips.

        Args:
            rgb: Input to the RGB backbone. Presumably (N, 3, T, H, W), as
                torchvision video ResNets expect; confirm against the loader.
            flow: Input to the flow backbone, presumably the same layout with
                ``flow_in_channels`` channels; confirm against the loader.

        Returns:
            The output of the fusion head: one logit per class per sample.
        """
        rgb_feat = self.rgb_backbone(rgb)
        flow_feat = self.flow_backbone(flow)

        # Concatenate the two streams' feature vectors along the feature axis.
        combined = torch.cat((rgb_feat, flow_feat), dim=1)
        return self.fusion_fc(combined)