| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
class UpSampleBN(nn.Module):
    """Decoder stage: upsample ``x`` to the skip tensor's spatial size, concatenate, refine.

    The refinement is two (3x3 conv -> BatchNorm2d -> LeakyReLU) layers.

    Args:
        skip_input: channel count of the concatenation, i.e.
            channels(x) + channels(concat_with).
        output_features: channel count produced by both conv layers.
    """

    def __init__(self, skip_input, output_features):
        super().__init__()
        layers = []
        for in_ch in (skip_input, output_features):
            layers += [
                nn.Conv2d(in_ch, output_features, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(output_features),
                nn.LeakyReLU(),
            ]
        # The attribute name and layer order define the state_dict keys
        # (_net.0 ... _net.5), so keep both stable for checkpoint loading.
        self._net = nn.Sequential(*layers)

    def forward(self, x, concat_with):
        """Resize ``x`` bilinearly to dims 2 and 3 of ``concat_with``, concat on dim 1, refine."""
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        return self._net(torch.cat([upsampled, concat_with], dim=1))
| |
|
| |
|
| | |
class UpSampleGN(nn.Module):
    """Decoder stage: upsample, concatenate with a skip tensor, then two conv-GroupNorm-LeakyReLU layers.

    The convolutions are the module-level weight-standardized ``Conv2d``.
    ``nn.GroupNorm(8, ...)`` requires ``output_features`` to be divisible by 8.

    Args:
        skip_input: channel count of the concatenation, i.e.
            channels(x) + channels(concat_with).
        output_features: channel count produced by both conv layers.
    """

    def __init__(self, skip_input, output_features):
        super().__init__()
        stages = []
        for in_ch in (skip_input, output_features):
            stages.extend((
                Conv2d(in_ch, output_features, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, output_features),
                nn.LeakyReLU(),
            ))
        # Keep the attribute name and layer order: they define state_dict keys.
        self._net = nn.Sequential(*stages)

    def forward(self, x, concat_with):
        """Resize ``x`` bilinearly to dims 2 and 3 of ``concat_with``, concat on dim 1, refine."""
        out_size = [concat_with.size(2), concat_with.size(3)]
        merged = torch.cat(
            [F.interpolate(x, size=out_size, mode='bilinear', align_corners=True), concat_with],
            dim=1,
        )
        return self._net(merged)
| |
|
| |
|
| | |
class Conv2d(nn.Conv2d):
    """Drop-in ``nn.Conv2d`` that applies weight standardization in ``forward``.

    Before every convolution, each output filter ``weight[o]`` is centred to
    zero mean over (in_channels, kh, kw). It is then divided by its unbiased
    standard deviation plus ``1e-5``. Only the weight used for the current
    call is standardized; the stored ``self.weight`` parameter is unchanged.

    The constructor forwards its arguments positionally to ``nn.Conv2d``.
    Other ``nn.Conv2d`` options (e.g. ``padding_mode``) are not exposed.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Per-filter mean over every dimension except the output-channel one.
        weight = weight - weight.mean(dim=(1, 2, 3), keepdim=True)
        # flatten(1), unlike .view(), also works when the weight is not
        # contiguous, e.g. after model.to(memory_format=torch.channels_last).
        std = weight.flatten(1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std  # broadcasts over (in_channels, kh, kw)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
| |
|
| |
|
| | |
| | def norm_normalize(norm_out): |
| | min_kappa = 0.01 |
| | norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) |
| | norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 |
| | kappa = F.elu(kappa) + 1.0 + min_kappa |
| | final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) |
| | return final_out |
| |
|
| |
|
| | |
@torch.no_grad()
def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
    """Choose pixel locations at which to refine a prediction.

    ``N = int(sampling_ratio * H * W)`` pixels are selected per batch item:
    - ``int(beta * N)`` of them (clamped to [0, N]) are the most uncertain
      pixels, where uncertainty is the negation of channel 3 of
      ``init_normal``.
    - The rest are drawn uniformly at random, without replacement, from the
      remaining pixels.

    Args:
        init_normal: Tensor of shape (B, C, H, W) with C >= 4. Lower values in
            channel 3 rank as more uncertain (presumably a kappa/confidence
            such as the output of ``norm_normalize``; TODO confirm).
        gt_norm_mask: ``None``, or a 4-D tensor (B, C', h, w). It is resized
            to (H, W) with nearest-neighbour interpolation and its channel 0
            is read. Pixels where that value is < 0.5 get uncertainty -1e4,
            which pushes them behind any pixel whose channel-3 value is below
            1e4.
        sampling_ratio: Fraction of the H*W pixels to sample.
        beta: Fraction of samples chosen by importance. Values >= 1 mean
            importance-only and values <= 0 mean random-only.

    Returns:
        point_coords: float32 tensor (B, 1, N, 2) on ``init_normal``'s
            device, holding (x, y) = (column, row). Pixel index 0 maps to -1
            and index size-1 maps to 1 (the ``align_corners=True`` convention
            of ``grid_sample``).
        rows_int: integer tensor (B, N) of pixel row indices.
        cols_int: integer tensor (B, N) of pixel column indices.

    Raises:
        ValueError: if ``N`` is negative or exceeds ``H * W``.
    """
    device = init_normal.device
    B, _, H, W = init_normal.shape
    N = int(sampling_ratio * H * W)
    if not 0 <= N <= H * W:
        raise ValueError(
            f"sampling_ratio={sampling_ratio} requests {N} points but the map has {H * W} pixels"
        )

    # Negated so that a descending sort lists the least confident pixels first.
    # The multiplication also makes a copy, so init_normal is never mutated.
    uncertainty_map = -1 * init_normal[:, 3, :, :]

    if gt_norm_mask is not None:
        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
        uncertainty_map[gt_invalid_mask] = -1e4

    _, idx = uncertainty_map.reshape(B, -1).sort(1, descending=True)

    # Clamp to [0, N]: with beta > 1 the previous code selected more than N
    # importance indices and then crashed when filling point_coords.
    num_importance = min(max(int(beta * N), 0), N)
    importance = idx[:, :num_importance]
    remaining = idx[:, num_importance:]
    num_coverage = N - num_importance

    if num_coverage > 0:
        # One CPU randperm per batch item: the same RNG consumption as before,
        # so seeded runs reproduce the previous samples.
        coverage = torch.stack([
            row[torch.randperm(remaining.size(1))[:num_coverage]] for row in remaining
        ])
        samples = torch.cat((importance, coverage), dim=1)
    else:
        samples = importance

    rows_int = samples // W
    cols_int = samples % W
    # max(..., 1) avoids 0/0 -> NaN for a size-1 dimension; the only pixel
    # (index 0) then maps to -1.
    rows_float = rows_int / float(max(H - 1, 1)) * 2.0 - 1.0
    cols_float = cols_int / float(max(W - 1, 1)) * 2.0 - 1.0

    point_coords = torch.zeros(B, 1, N, 2, device=device)
    point_coords[:, 0, :, 0] = cols_float
    point_coords[:, 0, :, 1] = rows_float
    return point_coords, rows_int, cols_int