| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch |
| |
|
class ImageEncoder(nn.Module):
    """Convolutional encoder mapping an RGB image batch to 1024-d feature vectors.

    Three stages of (3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 2x2 max-pool),
    followed by two fully connected ReLU layers. The convolutions use no
    padding, so ``fc1``'s ``128 * 28 * 28`` input size corresponds to a
    256x256 input image (256 -> 252 -> 126 -> 122 -> 61 -> 57 -> 28).
    """

    def __init__(self):
        super(ImageEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 32, 3)
        self.conv3 = nn.Conv2d(32, 64, 3)
        self.conv4 = nn.Conv2d(64, 64, 3)
        self.conv5 = nn.Conv2d(64, 128, 3)
        self.conv6 = nn.Conv2d(128, 128, 3)
        self.fc1 = nn.Linear(in_features=128*28*28, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=1024)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (batch, 3, H, W). ``fc1`` is dimensioned for
                H = W = 256.

        Returns:
            Tensor of shape (batch, 1024) (post-ReLU).

        Raises:
            RuntimeError: if the input size does not produce 128x28x28
                feature maps, so the flattened features do not match ``fc1``.
        """
        # Stage 1: 3 -> 32 channels, spatial size halved by pooling.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        # Stage 2: 32 -> 64 channels.
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2)

        # Stage 3: 64 -> 128 channels.
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.max_pool2d(x, 2)

        # Flatten per sample, keeping the batch dim, rather than
        # view(-1, 128*28*28). With an inferred batch dim, a mis-sized input
        # could be silently regrouped into a different batch size; this way
        # fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
| |
|
class ContextEncoder(nn.Module):
    """Two-layer Elman RNN encoding a batch of token sequences.

    Input features have size 19 (presumably a one-hot token vocabulary --
    TODO confirm against the data pipeline); each time step is encoded into
    a 128-d hidden vector.
    """

    def __init__(self):
        super(ContextEncoder, self).__init__()
        self.rnn = nn.RNN(input_size=19, hidden_size=128, num_layers=2, batch_first=True)

    def forward(self, x, h=None):
        """Run the RNN over ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, 19) (``batch_first=True``).
            h: Optional initial hidden state of shape (2, batch, 128). When
                omitted, a zero state is created on ``x``'s device and dtype.

        Returns:
            Top-layer outputs of shape (batch, seq_len, 128). The final
            hidden state is discarded.
        """
        # Compare against None explicitly: `not h` on a multi-element tensor
        # raises "Boolean value of Tensor ... is ambiguous".
        if h is None:
            # new_zeros follows x's device/dtype instead of forcing CUDA.
            h = x.new_zeros((self.rnn.num_layers, x.size(0), self.rnn.hidden_size))

        x, _ = self.rnn(x, h)
        return x
| |
|
class Decoder(nn.Module):
    """Two-layer RNN decoder conditioned on an image feature and a context sequence.

    The 1024-d image feature is concatenated to every step of the 128-d
    context features (1152 inputs). These pass through a 512-unit RNN and a
    linear projection to 19 outputs per step (raw scores; no softmax is
    applied here).
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = nn.RNN(input_size=1024+128, hidden_size=512, num_layers=2, batch_first=True)
        self.l1 = nn.Linear(512, 19)

    def forward(self, image_feature, context_feature, on_cuda=False, h=None):
        """Decode one score vector per context time step.

        Args:
            image_feature: Tensor of shape (batch, 1024).
            context_feature: Tensor of shape (batch, seq_len, 128).
            on_cuda: Unused; kept for backward compatibility. The initial
                hidden state is created on the inputs' device.
            h: Optional initial hidden state of shape (2, batch, 512). When
                omitted, a zero state is used.

        Returns:
            Tensor of shape (batch, seq_len, 19).
        """
        seq_len = context_feature.size(1)
        # Broadcast the per-sample image vector over every time step. expand
        # avoids an extra copy; torch.cat below materializes the result anyway.
        image_feature = image_feature.unsqueeze(1).expand(-1, seq_len, -1)

        x = torch.cat((image_feature, context_feature), 2)

        # Explicit None check: `not h` fails on multi-element tensors.
        if h is None:
            h = x.new_zeros((self.rnn.num_layers, x.size(0), self.rnn.hidden_size))

        x, _ = self.rnn(x, h)
        return self.l1(x)
| |
|
class Pix2Code(nn.Module):
    """End-to-end model wiring an image encoder, a context encoder and a decoder.

    The image and the context sequence are encoded independently. The
    decoder then combines both encodings into per-step outputs.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable (it determines state_dict order).
        self.image_encoder = ImageEncoder()
        self.context_encoder = ContextEncoder()
        self.decoder = Decoder()

    def forward(self, image, context):
        """Encode ``image`` and ``context`` (in that order), then decode."""
        return self.decoder(
            self.image_encoder(image),
            self.context_encoder(context),
        )
| |
|