| |
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| |
|
| |
|
def normalize(in_channels):
    """Return the group-normalization layer used for ``in_channels`` channels.

    The layer is always configured with 32 groups, ``eps=1e-6`` and learnable
    per-channel affine parameters.
    """
    return torch.nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
|
| |
|
def swish(x):
    """Apply the swish activation ``x * sigmoid(x)`` element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
| |
|
class VectorQuantizer(nn.Module):
    """Discretization bottleneck part of the VQ-VAE.

    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py

    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    """

    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        # Codebook of n_e vectors, initialised uniformly in [-1/n_e, 1/n_e].
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def _squared_distances(self, z_flattened):
        """Return squared L2 distances between rows of ``z_flattened`` and the codebook.

        Uses the expansion ``(z - e)^2 = z^2 + e^2 - 2 * e * z`` so that the
        result is a single ``(N, n_e)`` matrix for an ``(N, e_dim)`` input.
        """
        codebook = self.embedding.weight
        return (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, codebook.t())
        )

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook vectors.

        Args:
            z: tensor whose last axis holds the ``e_dim`` features. The final
                ``permute(0, 2, 1)`` requires it to be 3-D, i.e.
                ``(B, T, e_dim)``. Note that this is channel-last, unlike
                ``get_distance`` which takes channel-first input.

        Returns:
            A tuple ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices))``:
            ``z_q`` is the quantized tensor with axes 1 and 2 swapped relative
            to ``z`` (``(B, e_dim, T)``) and straight-through gradients;
            ``loss`` is the scalar codebook + ``beta``-weighted commitment
            loss; ``perplexity`` is the exponentiated entropy of code usage in
            this call; ``min_encodings`` is the ``(B*T, n_e)`` one-hot
            assignment matrix; ``min_encoding_indices`` is the ``(B*T, 1)``
            index tensor.
        """
        # reshape (rather than view) so non-contiguous inputs are accepted.
        z_flattened = z.reshape(-1, self.e_dim)
        d = self._squared_distances(z_flattened)

        # Nearest code per row, plus its one-hot encoding.
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # Quantized latent vectors, restored to the input layout.
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)

        # Commitment term (gradient to the encoder, scaled by beta) plus
        # codebook term (gradient to the embedding).
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
            torch.mean((z_q - z.detach()) ** 2)

        # Straight-through estimator: forward value is z_q, gradient flows to z.
        z_q = z + (z_q - z).detach()

        # Perplexity of the code-usage distribution for this call.
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # Hand the result back channel-first.
        z_q = z_q.permute(0, 2, 1).contiguous()
        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_distance(self, z):
        """Return squared distances from every position of ``z`` to every code.

        Args:
            z: channel-first tensor ``(B, e_dim, T)``; it is permuted to
                channel-last internally before flattening.

        Returns:
            Tensor of shape ``(B, n_e, T)`` where entry ``[b, k, t]`` is the
            squared L2 distance between ``z[b, :, t]`` and codebook vector
            ``k``.
        """
        z = z.permute(0, 2, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        d = self._squared_distances(z_flattened)
        # The distance matrix has n_e columns (one per code), so the trailing
        # axis must be n_e; reshaping with e_dim is only correct when the two
        # happen to be equal.
        d = d.view(z.shape[0], -1, self.n_e).permute(0, 2, 1).contiguous()
        return d

    def get_codebook_entry(self, indices, shape):
        """Look up the codebook vectors for ``indices``.

        Args:
            indices: integer tensor of code indices. A 1-D tensor of length
                ``N`` yields an ``(N, e_dim)`` result; tensors of other shapes
                yield ``indices.shape + (e_dim,)``.
            shape: optional target shape; when not ``None`` the result is
                viewed as this shape.

        Returns:
            The selected embedding vectors, optionally reshaped.
        """
        # Direct lookup selects the same rows as a one-hot matmul without
        # materialising an (N, n_e) matrix.
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)

        return z_q
|
| |
|