Custom consistent dropout modules
For our BNNs, we are going to use MC dropout.
To be able to compute BatchBALD scores, we need consistent MC dropout, which uses consistent masks during inference: we draw $K$ dropout masks and keep them fixed while drawing the $K$ inference samples for each input in the test set.
During training, masks are redrawn for every sample, as in regular dropout.
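As a rough, self-contained sketch of what "consistent" means here (a toy example with made-up sizes, not the actual modules defined below): the $K$ masks are drawn once and then the same $K$ masks are applied to every test input.

import torch

K, num_features, p = 3, 6, 0.5
# Draw K keep-masks once and keep them fixed for the whole evaluation.
masks = torch.bernoulli(torch.full((K, num_features), 1 - p))

def mc_samples(x_B):
    # x_B: B x num_features -> B x K x num_features, one sample per fixed mask.
    return x_B.unsqueeze(1) * masks / (1 - p)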
Bayesian Module
To make this work efficiently, we are going to define an abstract wrapper module that takes a batch input_B and outputs results_B_K. Internally, it will blow up the input batch to $(B \cdot K) \times \cdots$ and then pass it to mc_forward_impl, which should be overridden. ConsistentMCDropout layers will know to reshape the inputs to $B \times K \times \cdots$ and apply consistent masks.
import torch
from torch.nn import Module


class BayesianModule(Module):
    """A module that we can sample multiple times from given a single input batch.

    To be efficient, the module allows for a part of the forward pass to be deterministic.
    """

    k = None

    def __init__(self):
        super().__init__()

    # Returns B x K x output shape.
    def forward(self, input_B: torch.Tensor, k: int):
        BayesianModule.k = k

        mc_input_BK = BayesianModule.mc_tensor(input_B, k)
        mc_output_BK = self.mc_forward_impl(mc_input_BK)
        mc_output_B_K = BayesianModule.unflatten_tensor(mc_output_BK, k)
        return mc_output_B_K

    def mc_forward_impl(self, mc_input_BK: torch.Tensor):
        return mc_input_BK

    @staticmethod
    def unflatten_tensor(input: torch.Tensor, k: int):
        # (B*K) x ... -> B x K x ...
        input = input.view([-1, k] + list(input.shape[1:]))
        return input

    @staticmethod
    def flatten_tensor(mc_input: torch.Tensor):
        # B x K x ... -> (B*K) x ...
        return mc_input.flatten(0, 1)

    @staticmethod
    def mc_tensor(input: torch.Tensor, k: int):
        # Repeat each batch element K times: B x ... -> (B*K) x ...
        mc_shape = [input.shape[0], k] + list(input.shape[1:])
        return input.unsqueeze(1).expand(mc_shape).flatten(0, 1)
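The shape conventions are easiest to see with a tiny example (a sketch using the helpers above with made-up sizes): mc_tensor repeats each batch element $K$ times along a flattened $(B \cdot K)$ leading dimension, and unflatten_tensor recovers the $B \times K$ layout.

x = torch.randn(3, 5)                                      # B=3 inputs with 5 features each
mc_x = BayesianModule.mc_tensor(x, 4)                      # K=4 copies of each input
print(mc_x.shape)                                          # torch.Size([12, 5])
print(BayesianModule.unflatten_tensor(mc_x, 4).shape)      # torch.Size([3, 4, 5])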
class _ConsistentMCDropout(Module):
    def __init__(self, p=0.5):
        super().__init__()

        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, but got {}".format(p))

        self.p = p
        self.mask = None

    def extra_repr(self):
        return "p={}".format(self.p)

    def reset_mask(self):
        self.mask = None

    def train(self, mode=True):
        super().train(mode)
        if not mode:
            self.reset_mask()

    def _get_sample_mask_shape(self, sample_shape):
        return sample_shape

    def _create_mask(self, input, k):
        mask_shape = [1, k] + list(self._get_sample_mask_shape(input.shape[1:]))
        # True entries are dropped with probability p.
        mask = torch.empty(mask_shape, dtype=torch.bool, device=input.device).bernoulli_(self.p)
        return mask

    def forward(self, input: torch.Tensor):
        if self.p == 0.0:
            return input

        k = BayesianModule.k

        if self.training:
            # Create a new mask on each call and for each batch element.
            k = input.shape[0]
            mask = self._create_mask(input, k)
        else:
            if self.mask is None:
                # Create the mask once and keep it fixed for all subsequent calls.
                self.mask = self._create_mask(input, k)

            mask = self.mask

        mc_input = BayesianModule.unflatten_tensor(input, k)
        mc_output = mc_input.masked_fill(mask, 0) / (1 - self.p)

        # Flatten the MC dimension and the batch dimension into one dimension again.
        return BayesianModule.flatten_tensor(mc_output)


class ConsistentMCDropout(_ConsistentMCDropout):
    pass


class ConsistentMCDropout2d(_ConsistentMCDropout):
    # Drop whole channels: the mask broadcasts over the spatial dimensions.
    def _get_sample_mask_shape(self, sample_shape):
        return [sample_shape[0]] + [1] * (len(sample_shape) - 1)
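A quick shape check (a sketch with made-up sizes): in evaluation mode the 2d variant drops whole channels, so the cached mask has a single entry per channel and broadcasts over the spatial dimensions. BayesianModule.k is normally set by BayesianModule.forward; we set it by hand here only for illustration.

layer = ConsistentMCDropout2d(p=0.5)
layer.eval()                                # switch to evaluation mode and reset the mask
BayesianModule.k = 4                        # normally set by BayesianModule.forward
x = torch.randn(2 * 4, 8, 12, 12)           # (B*K) x C x H x W with B=2, K=4
y = layer(x)
print(layer.mask.shape)                     # torch.Size([1, 4, 8, 1, 1]): one entry per channel
print(y.shape)                              # torch.Size([8, 8, 12, 12]): same (B*K) x C x H x W layout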
import torch
from torch import nn as nn
from torch.nn import functional as F
class BayesianCNN(BayesianModule):
    def __init__(self, num_classes=10):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv1_drop = ConsistentMCDropout2d()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv2_drop = ConsistentMCDropout2d()
        self.fc1 = nn.Linear(1024, 128)
        self.fc1_drop = ConsistentMCDropout()
        self.fc2 = nn.Linear(128, num_classes)

    def mc_forward_impl(self, input: torch.Tensor):
        input = F.relu(F.max_pool2d(self.conv1_drop(self.conv1(input)), 2))
        input = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(input)), 2))
        input = input.view(-1, 1024)
        input = F.relu(self.fc1_drop(self.fc1(input)))
        input = self.fc2(input)
        input = F.log_softmax(input, dim=1)

        return input
BayesianCNN()
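As a small usage sketch (with a dummy batch of MNIST-sized 1 x 28 x 28 inputs): in evaluation mode we draw $K$ MC samples per input and get log-probabilities of shape B x K x num_classes; because the masks are cached, a second call with the same inputs returns identical samples.

model = BayesianCNN()
model.eval()                                # masks are drawn once and then cached

x = torch.randn(16, 1, 28, 28)              # dummy batch of MNIST-sized inputs
log_probs_B_K_C = model(x, 20)              # B=16, K=20, C=10
print(log_probs_B_K_C.shape)                # torch.Size([16, 20, 10])

# Consistent masks: repeating the forward pass reuses the same K masks.
print(torch.allclose(model(x, 20), log_probs_B_K_C))  # True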