You can find the paper here: https://arxiv.org/abs/2102.11582.
Please cite us using:
@article{mukhoti2021deterministic,
title={Deterministic Neural Networks with Appropriate Inductive Biases Capture Epistemic and Aleatoric Uncertainty},
author={Mukhoti, Jishnu and Kirsch, Andreas and van Amersfoort, Joost and Torr, Philip HS and Gal, Yarin},
journal={arXiv preprint arXiv:2102.11582},
year={2021}
}
Overview
We train a VAE on the original MNIST dataset to be able to synthesize new MNIST digits. Then, we train an ensemble of LeNet models on MNIST that we can use to select ambiguous samples (low epistemic uncertainty, high aleatoric uncertainty) from the VAE.
To calibrate the VAE's output to match MNIST's unwhitened outputs as much as possible, we synthesize unambiguous digits using the VAE first and then adjust the outputs slightly.
To cover a wider range of entropies, we employ stratified sampling of ambiguous samples while rejecting samples with high epistemic uncertainty (info gain).
We try two different means of sampling from the VAE. The first one samples randomly from the latent space (unit Gaussian). The other method encodes existing MNIST samples into the latent space and then interpolates between them using barycentric coordinates, i.e. by randomly picking a convex combination.
Software Engineering
To save progress and resources, the script caches most results locally using a helper function restore_or_create
. We did not upload the cache here—you can generate it yourself if you wish. We only upload the final AMNIST samples to this website.
import os
from functools import wraps
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm.auto import tqdm
# Seed the PyTorch and NumPy random number generators.
torch.manual_seed(1)
np.random.seed(1)
# Setup matplotlib to make it easier to copy outputs into Slack (seriously).
plt.rcParams["figure.facecolor"] = "white"
# Batch size used by the MNIST data loaders.
batch_size = 128
Load MNIST:
# Constants used to normalize (whiten) MNIST pixel values.
mnist_mean, mnist_std = 0.1307, 0.3081
# Convert images to tensors, then normalize them with the constants above.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mnist_mean,), (mnist_std,))])
# Train and test splits, stored under ./mnist_data/ (download=True).
mnist_train_dataset = datasets.MNIST(root="./mnist_data/", train=True, transform=transform, download=True)
mnist_test_dataset = datasets.MNIST(root="./mnist_data/", train=False, transform=transform, download=True)
# Only the training loader shuffles its samples.
mnist_train_loader = torch.utils.data.DataLoader(dataset=mnist_train_dataset, batch_size=batch_size, shuffle=True)
mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test_dataset, batch_size=batch_size, shuffle=False)
Two helper functions to normalize and unnormalize MNIST samples. MNIST samples are stored as grayscale bytes in a range 0..255, which is first converted to 0..1 floats by the MNIST dataset code in torchvision
. We then normalize the values further to have zero mean and standard deviation 1, which helps with training.
def unnormalize_mnist(x):
    """Undo the MNIST normalization: scale by ``mnist_std`` and shift by ``mnist_mean``."""
    rescaled = x * mnist_std
    return rescaled + mnist_mean
def normalize_mnist(x):
    """Apply the MNIST normalization: subtract ``mnist_mean`` and divide by ``mnist_std``."""
    centered = x - mnist_mean
    return centered / mnist_std
Finally, we need a helper function to make it easy to backup and restore resources:
def restore_or_create(file_or_path, recreate: bool = False, *, pickle_module=None, pickle_load_args=None):
    """Build a decorator that caches a zero-argument function's result on disk.

    The decorated function first tries ``torch.load(file_or_path)``. If the file
    does not exist (or ``recreate`` is true), the wrapped function is called and
    its result is written with ``torch.save`` before being returned.

    Args:
        file_or_path: Target passed to ``torch.load`` / ``torch.save``.
        recreate: If true, skip loading and always recompute and overwrite.
        pickle_module: Optional pickle module forwarded to both load and save.
        pickle_load_args: Optional extra keyword arguments for ``torch.load``.

    Returns:
        A decorator turning ``func`` into a cached zero-argument callable.
    """
    pickle_module_arg = {"pickle_module": pickle_module} if pickle_module else {}
    # Merge once; an explicit pickle_module wins over one inside pickle_load_args.
    load_kwargs = {**(pickle_load_args or {}), **pickle_module_arg}

    def delegate(func):
        @wraps(func)
        def wrapper():
            if not recreate:
                try:
                    # Keyword expansion is essential: passed positionally, the dict
                    # would be interpreted as ``map_location`` and the options lost.
                    return torch.load(file_or_path, **load_kwargs)
                except FileNotFoundError:
                    # No cache yet: fall through and compute the result.
                    pass
            result = func()
            torch.save(result, file_or_path, pickle_protocol=-1, **pickle_module_arg)
            return result

        return wrapper

    return delegate
This code is from https://github.com/pytorch/examples/blob/master/vae/main.py.
Note, we train it on whitened MNIST, and have adapted the loss function for that.
class VAE(nn.Module):
    """Fully connected variational autoencoder operating on flattened 784-pixel images.

    The decoder emits values in the normalized MNIST scale: its sigmoid output is
    shifted by ``mnist_mean`` and divided by ``mnist_std``.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()
        # Encoder: x -> h1 -> h2 -> (mu, log_var)
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # Decoder: z -> h2 -> h1 -> x
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        mu = self.fc31(hidden)
        log_var = self.fc32(hidden)
        return mu, log_var

    def sampling(self, mu, log_var):
        """Draw a latent sample via the reparameterization ``mu + eps * std``."""
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return noise * std + mu

    def decoder(self, z):
        """Decode latent codes into normalized pixel values."""
        hidden = F.relu(self.fc5(F.relu(self.fc4(z))))
        pixels = F.sigmoid(self.fc6(hidden))
        return (pixels - mnist_mean) / mnist_std

    def forward(self, x):
        """Encode ``x`` (viewed as rows of 784 values), sample, and decode."""
        flat = x.view(-1, 784)
        mu, log_var = self.encoder(flat)
        latent = self.sampling(mu, log_var)
        return self.decoder(latent), mu, log_var
def loss_function(recon_x, x, mu, log_var):
    """VAE loss: summed binary cross-entropy on unnormalized pixels plus the KL term.

    Both the reconstruction and the target are passed through
    ``unnormalize_mnist`` before the cross-entropy is evaluated; the target is
    viewed as rows of 784 values.
    """
    prediction = unnormalize_mnist(recon_x)
    target = unnormalize_mnist(x).view(-1, 784)
    reconstruction_term = F.binary_cross_entropy(prediction, target, reduction="sum")
    # KL divergence between N(mu, exp(log_var)) and the unit Gaussian.
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconstruction_term + kl_term
def train_vae_epoch(vae, optimizer, train_loader, epoch):
    """Train ``vae`` for one pass over ``train_loader`` and print the average loss.

    Args:
        vae: The model to optimize; batches are moved to the device of its parameters.
        optimizer: Optimizer over ``vae``'s parameters.
        train_loader: Iterable of ``(data, label)`` batches with a ``dataset`` attribute.
        epoch: Epoch number, used only in the log line.
    """
    vae.train()
    # Follow the model's device instead of assuming CUDA is present.
    device = next(vae.parameters()).device
    train_loss = 0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, train_loss / len(train_loader.dataset)))
def test_vae(vae, test_loader):
    """Evaluate ``vae`` on ``test_loader`` and print the per-sample average loss.

    Batches are moved to the device of the model's parameters, and gradients are
    disabled for the whole evaluation.
    """
    vae.eval()
    # Follow the model's device instead of assuming CUDA is present.
    device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
    test_loss /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(test_loss))
@restore_or_create("vae_ambiguous_mnist.model", recreate=False)
def train_vae():
    """Train the MNIST VAE and return it (cached via ``restore_or_create``).

    Uses Adam with a multi-step learning-rate schedule. After every epoch a grid
    of 64 decoded random latents is written to ``./samples/sample_<epoch>.png``.
    """
    z_dim = 32
    # Pick the device once so the model and the preview latents always agree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # build model
    vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=z_dim).to(device)
    optimizer = optim.Adam(vae.parameters())
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50, 65], gamma=0.2)
    os.makedirs("./samples", exist_ok=True)
    for epoch in tqdm(range(1, 75)):
        train_vae_epoch(vae, optimizer, mnist_train_loader, epoch)
        test_vae(vae, mnist_test_loader)
        scheduler.step()
        with torch.no_grad():
            # Sample on the CPU generator, then move, to keep the same random stream.
            z = torch.randn(64, z_dim).to(device)
            sample = vae.decoder(z)
            save_image(
                unnormalize_mnist(sample.view(64, 1, 28, 28)),
                f"./samples/sample_{epoch}.png",
            )
    return vae
# Train the VAE, or restore it from "vae_ambiguous_mnist.model" if that file exists.
vae = train_vae()
# Bare expression; presumably kept so a notebook cell displays the model.
vae
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import PercentFormatter
from torchvision.utils import make_grid
def show(img):
    """Draw an image tensor with matplotlib, moving its first axis to the last position."""
    pixels = np.transpose(img.cpu().numpy(), (1, 2, 0))
    plt.imshow(pixels, interpolation="nearest")
Before we calibrate the VAE properly, we use the following method to post process the samples which will help them match the histogram of MNIST better.
The original LeNet-5 paper states (using inverted colors):
The values of the input pixels are normalized so that the background level (white) corresponds to a value of -0.1 and the foreground (black) corresponds to 1.175. This makes the mean input roughly zero and the variance roughly one, which accelerates learning.
def post_process_vae_output(sample):
    """Unnormalize VAE output, apply the affine stretch ``1.275 * x - 0.1``, and clamp to [0, 1]."""
    pixels = unnormalize_mnist(sample)
    stretched = 1.275 * pixels - 0.1
    return torch.clamp(stretched, 0, 1)
Sample images from the latent space and visualize their histogram:
def pixel_histogram(samples, num_bins):
    """Print and plot the histogram of all values in ``samples``.

    The printed counts are normalized by the number of elements; the y axis of
    the bar plot is formatted as a percentage of all elements.
    """
    counts, bin_edges = np.histogram(samples.cpu().flatten().numpy(), bins=num_bins)
    total = samples.numel()
    print(counts / total, bin_edges)
    plt.bar(bin_edges[:-1], counts, align="edge", width=1 / num_bins)
    axis = plt.gca()
    axis.yaxis.set_major_formatter(PercentFormatter(xmax=total))
    plt.show()
# Decode 128 random latents, show them as a grid, and plot their pixel histogram.
plt.figure(figsize=(8, 16))
with torch.no_grad():
    z = torch.randn(128, 32).cuda()
    sample = vae.decoder(z).cuda()
    unnormalized = post_process_vae_output(sample.view(128, 1, 28, 28))
    grid = make_grid(unnormalized, normalize=False, scale_each=False)
    show(grid)
    plt.show()
    pixel_histogram(unnormalized, num_bins=20)
Compare to MNIST test samples:
# Show the first MNIST test batch as a grid, with its value range and pixel histogram.
plt.figure(figsize=(8, 16))
data, _ = next(iter(mnist_test_loader))
unnormalized = unnormalize_mnist(data.view(batch_size, 1, 28, 28))
print(unnormalized.min(), unnormalized.max())
grid = make_grid(unnormalized, normalize=False, scale_each=False)
show(grid)
plt.show()
pixel_histogram(unnormalized, num_bins=20)
The histograms are already quite similar. The sample visualizations seem to validate this further.
Calibrating the VAE
Reconstructing MNIST's training set
Let's calibrate the output of the VAE to match real MNIST as much as possible. We want to improve over post_process_vae_output
as much as possible. The simplest solution to this is to pass all MNIST training samples through the VAE and look at the histogram after they have been decoded.
# For each digit class c in 0..9: the positions in the MNIST training set whose target equals c.
class_indices = [np.nonzero((mnist_train_dataset.targets == c).cpu().numpy())[0] for c in range(10)]
def pred_entropy(pre_softmaxs):  # shape: [MC samples from pred / ensemble, num data, classes]
    """Entropy (in nats) of the softmax prediction averaged over the first dimension.

    Terms that evaluate to NaN (zero probability times ``log 0``) count as 0.
    """
    member_probs = F.softmax(pre_softmaxs, dim=-1)
    mean_probs = member_probs.mean(dim=0)
    nats = -torch.log(mean_probs) * mean_probs
    nats[torch.isnan(nats)] = 0.0
    return nats.sum(dim=-1)
def avg_entropy(pre_softmaxs):  # shape: [MC samples from pred / ensemble, num data, classes]
    """Average over the first dimension of each member's softmax entropy (in nats).

    NOTE: this is the average entropy, not the entropy of the averaged prediction.
    """
    log_probs = F.log_softmax(pre_softmaxs, dim=-1)
    probs = F.softmax(pre_softmaxs, dim=-1)
    nats = -log_probs * probs
    nats[torch.isnan(nats)] = 0.0
    per_member_entropy = nats.sum(dim=-1)  # sum over classes
    return per_member_entropy.mean(dim=0)  # average over MC samples!
@restore_or_create("decoded_mnist_train.pt", recreate=False)
@torch.no_grad()
def decode_mnist_train_samples():
    """Reconstruct the MNIST training set with the VAE and label it with the LeNet ensemble.

    Every image is encoded to its posterior mean and decoded again. The ensemble
    (module-level ``lenets``) classifies the post-processed reconstruction, and
    the agreement with the true labels is printed as an accuracy.

    Returns:
        ``(decoded_images, decoded_labels)`` as CPU tensors; the images are the
        raw decoder outputs, not the post-processed ones.
    """
    image_chunks = torch.split(normalize_mnist(mnist_train_dataset.data / 255.0), 1024)
    label_chunks = torch.split(torch.as_tensor(mnist_train_dataset.targets), 1024)

    decoded_images, decoded_labels = [], []
    num_correct = 0
    num_seen = 0
    for data, labels in tqdm(zip(image_chunks, label_chunks)):
        latent_means = vae.encoder(data.cuda().view(-1, 784))[0]
        batch_images = vae.decoder(latent_means).view(-1, 1, 28, 28)

        lenet_inputs = normalize_mnist(post_process_vae_output(batch_images)).cuda()
        ensemble_logits = torch.stack([lenet(lenet_inputs) for lenet in lenets])
        batch_labels = F.softmax(ensemble_logits, dim=-1).mean(dim=0).argmax(dim=-1).cpu()

        num_correct += (batch_labels == labels).sum().item()
        num_seen += len(labels)
        decoded_images.append(batch_images.cpu())
        decoded_labels.append(batch_labels)

    all_images = torch.cat(decoded_images).cpu()
    all_labels = torch.cat(decoded_labels).cpu()
    print(f"Acc: {100*num_correct/num_seen}")
    return all_images, all_labels
# Decode the MNIST training set (or restore the result from "decoded_mnist_train.pt").
(
decoded_mnist_train_images,
decoded_mnist_train_labels,
) = decode_mnist_train_samples()
# Pixel histograms: real MNIST (scaled to 0..1) vs. the post-processed reconstructions.
num_histogram_bins = 20
pixel_histogram(mnist_train_dataset.data / 255.0, num_histogram_bins)
pixel_histogram(post_process_vae_output(decoded_mnist_train_images), num_bins=num_histogram_bins)
We see two peaks: 82% of the pixels are close to 0 and 7.9% are close to 1.
The histograms already look quite similar. Can we do better though?
We will compute and adjust the unnormalized samples to match the unnormalized MNIST samples as much as possible, which yields adjusted_post_process_vae_output
. The adjusted histogram mostly matches MNIST better now. This "optimization" was done using graduate student descent while treating the scale and shift in adjusted_post_process_vae_output
as hyperparameters:
# Mean and standard deviation over all values of the decoded training images.
vae_mnist_mean = torch.mean(decoded_mnist_train_images)
vae_mnist_std = torch.std(decoded_mnist_train_images)
print(vae_mnist_mean, vae_mnist_std)
# Value range of the decoded training images.
vae_mnist_min, vae_mnist_max = (
decoded_mnist_train_images.min(),
decoded_mnist_train_images.max(),
)
print(vae_mnist_min, vae_mnist_max)
For comparison, let's verify that our MNIST normalization constants are sensible:
# Empirical mean and std of the 0..1-scaled training pixels (bare expression; no effect outside a notebook cell).
torch.mean(mnist_train_dataset.data / 255.0), torch.std(mnist_train_dataset.data / 255.0)
# The normalization constants defined earlier, for comparison (bare expression).
mnist_mean, mnist_std
Anyhow, here is the adjusted post-processing function:
def adjusted_post_process_vae_output(sample):
    """Calibrated post-processing of VAE output into the [0, 1] pixel range.

    Rather than unnormalizing directly, the sample is first standardized with
    the statistics of the decoded training images (``vae_mnist_mean`` and
    ``vae_mnist_std``), then unnormalized with the MNIST constants, stretched by
    ``1.20 * x - 0.15``, and clamped.
    """
    standardized = (sample - vae_mnist_mean) / vae_mnist_std
    pixels = unnormalize_mnist(standardized)
    stretched = 1.20 * pixels - 0.15
    return torch.clamp(stretched, 0, 1)
# Same comparison with twice as many bins: MNIST, adjusted reconstruction, original reconstruction.
num_histogram_bins = 20 * 2
pixel_histogram(mnist_train_dataset.data / 255.0, num_histogram_bins)
pixel_histogram(adjusted_post_process_vae_output(decoded_mnist_train_images), num_bins=num_histogram_bins)
pixel_histogram(post_process_vae_output(decoded_mnist_train_images), num_bins=num_histogram_bins)
from scipy.stats import wasserstein_distance
def compare_distributions(pixels_a, pixels_b):
    """Wasserstein distance between the flattened values of two CPU tensors."""
    flat_a = pixels_a.flatten().numpy()
    flat_b = pixels_b.flatten().numpy()
    return wasserstein_distance(flat_a, flat_b)
# Distance between MNIST's pixel distribution and the reconstruction under the original post-processing.
print(
"MNIST vs reconstruction:",
compare_distributions(mnist_train_dataset.data / 255.0, post_process_vae_output(decoded_mnist_train_images)),
)
# The same distance under the adjusted post-processing.
print(
"MNIST vs adjusted reconstruction:",
compare_distributions(
mnist_train_dataset.data / 255.0, adjusted_post_process_vae_output(decoded_mnist_train_images)
),
)
adjusted_post_process_vae_output
improves the histogram more than post_process_vae_output
, which is what we wanted.
# Side-by-side grids of 128 evenly strided samples (starting at index 5).
plt.figure(figsize=(24, 16))
# Left: reconstructions with the original post-processing.
plt.subplot(1, 3, 1)
sample = decoded_mnist_train_images[5 :: len(decoded_mnist_train_images) // 128][:128]
unnormalized = post_process_vae_output(sample.view(128, 1, 28, 28))
show(make_grid(unnormalized, normalize=False, scale_each=False))
# Middle: the same reconstructions with the adjusted post-processing.
plt.subplot(1, 3, 2)
sample = decoded_mnist_train_images[5 :: len(decoded_mnist_train_images) // 128][:128]
unnormalized = adjusted_post_process_vae_output(sample.view(128, 1, 28, 28))
show(make_grid(unnormalized, normalize=False, scale_each=False))
# Right: the corresponding real MNIST images, scaled to 0..1.
plt.subplot(1, 3, 3)
sample = mnist_train_dataset.data[5 :: len(mnist_train_dataset.data) // 128][:128] / 255.0
unnormalized = sample.view(128, 1, 28, 28)
show(make_grid(unnormalized, normalize=False, scale_each=False))
plt.show()
We want samples that are ambiguous but not out-of-distribution. Epistemic uncertainty should be low, but aleatoric uncertainty should not. Someone looking at the sample should recognize that it is a digit, but should not be able to tell which one with certainty.
We can measure epistemic uncertainty using the mutual information of an ensemble. When the mutual information is low, that is epistemic uncertainty is low, the predictive entropy will tell us the aleatoric uncertainty.
A slightly adapted version of LeNet-5 (https://ieeexplore.ieee.org/document/726791). We don't use padding, but use max pooling and ReLUs.
class LeNet(nn.Module):
    """LeNet-style classifier: two 5x5 convolutions with ReLU and 2x2 max pooling,
    followed by three fully connected layers. ``forward`` returns the raw
    (pre-softmax) class scores.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        # Flatten everything but the batch dimension for the linear layers.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
def lenet_train(epoch, train_loader, optimizer, lenet):
    """Train ``lenet`` for one epoch with cross-entropy and log the progress.

    Args:
        epoch: Epoch number, used only in the log lines.
        train_loader: Iterable of ``(data, label)`` batches with a ``dataset`` attribute.
        optimizer: Optimizer over ``lenet``'s parameters.
        lenet: The model to train; batches are moved to the device of its parameters.
    """
    lenet.train()
    # Follow the model's device instead of assuming CUDA is present.
    device = next(lenet.parameters()).device
    train_loss = 0.0
    for batch_idx, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        out = lenet(data)
        loss = F.cross_entropy(out, label)
        loss.backward()
        # cross_entropy returns the batch mean; weight it by the batch size so
        # that dividing by the dataset size below yields a per-sample average.
        train_loss += loss.item() * len(data)
        optimizer.step()
        if batch_idx % 100 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    # Already a per-sample mean; do not divide by the batch size again.
                    loss.item(),
                )
            )
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, train_loss / len(train_loader.dataset)))
def lenet_test(test_loader, lenet):
    """Evaluate ``lenet`` on ``test_loader``.

    Prints and returns the cross-entropy averaged over all samples of
    ``test_loader.dataset``. Batches are moved to the device of the model's
    parameters.
    """
    lenet.eval()
    # Follow the model's device instead of assuming CUDA is present.
    device = next(lenet.parameters()).device
    test_loss = 0.0
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            label = label.to(device)
            out = lenet(data)
            # sum up batch loss (reduction="sum" so a short last batch is weighted correctly)
            test_loss += F.cross_entropy(out, label, reduction="sum").item()
    test_loss /= len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(test_loss))
    return test_loss
# Train (or restore from "lenet_model_<i>.model") the LeNet ensemble members.
lenets = []
num_ensemble_components = 5
for i in tqdm(range(num_ensemble_components)):

    @restore_or_create(f"lenet_model_{i}.model")
    def train_lenet_model():
        """Train one LeNet and return the weights with the lowest test loss."""
        lenet = LeNet(num_classes=10).cuda()
        optimizer = optim.Adam(lenet.parameters())
        best_loss = None
        best_epoch = None
        for epoch in tqdm(
            range(1, 101 // num_ensemble_components)
        ):  # lenet seems to converge after about 10 epochs anyways..
            lenet_train(epoch, mnist_train_loader, optimizer, lenet)
            test_loss = lenet_test(mnist_test_loader, lenet)
            # Compare against None explicitly: a loss of exactly 0.0 is falsy and
            # would otherwise be treated as "no best loss yet".
            if best_loss is None or best_loss > test_loss:
                torch.save(lenet.state_dict(), "./tmp_lenet_best.model")
                best_loss = test_loss
                best_epoch = epoch
                print("New best model", best_epoch, " with ", best_loss)
        print("Best epoch", best_epoch)
        # Roll back to the checkpoint of the best epoch.
        lenet.load_state_dict(torch.load("./tmp_lenet_best.model"))
        return lenet

    lenets.append(train_lenet_model())
@restore_or_create("stratified_ambiguous_samples.pt", recreate=False)
@torch.no_grad()
def create_stratified_ambiguous_samples():
    """Sample random latents and collect decoded images stratified by predictive entropy.

    Latents are drawn from a unit Gaussian and decoded by the VAE. Samples whose
    ensemble mutual information exceeds ``mi_threshold`` are rejected. The rest
    are sorted into predictive-entropy bins, and sampling continues until every
    bin has reached its target count. The last batch added to a bin may push it
    past its target.

    Returns:
        A CPU tensor with the raw decoder outputs of all buckets concatenated.
    """
    mi_threshold = 0.05
    bin_edges = [0, 0.05, 0.5, 1.0, 10.0]
    target_counts = [60000, 30000, 20000, 10000]

    num_bins = len(bin_edges) - 1
    bucket_counts = [0] * num_bins
    buckets = [[] for _ in range(num_bins)]
    assert len(target_counts) == len(bucket_counts) == len(buckets) == num_bins

    batch_size = 16384
    total_target_count = sum(target_counts)
    sample_progress = tqdm(range(total_target_count))

    while any(count < target for count, target in zip(bucket_counts, target_counts)):
        latents = torch.randn(batch_size, 32).cuda()
        decoded_lerps = vae.decoder(latents).view(batch_size, 1, 28, 28)

        lenet_inputs = normalize_mnist(adjusted_post_process_vae_output(decoded_lerps)).cuda()
        ensemble_logits = torch.stack([lenet(lenet_inputs) for lenet in lenets])

        batch_pred_entropies = pred_entropy(ensemble_logits)
        batch_mi = batch_pred_entropies - avg_entropy(ensemble_logits)
        # Low mutual information: keep only samples the ensemble agrees on.
        id_samples = batch_mi <= mi_threshold

        for bucket_index in range(num_bins):
            if bucket_counts[bucket_index] >= target_counts[bucket_index]:
                continue
            lower, upper = bin_edges[bucket_index], bin_edges[bucket_index + 1]
            in_bucket = id_samples & (lower <= batch_pred_entropies) & (batch_pred_entropies < upper)
            num_new_samples = in_bucket.sum().item()
            if num_new_samples == 0:
                continue
            buckets[bucket_index].append(decoded_lerps[in_bucket])
            bucket_counts[bucket_index] += num_new_samples
            sample_progress.update(num_new_samples)
    sample_progress.close()

    filled = [torch.cat(bucket).cpu() for bucket in buckets if bucket]
    return torch.cat(filled)
# Generate the random-latent samples, or restore them from "stratified_ambiguous_samples.pt".
stratified_ambiguous_samples_random_latent = create_stratified_ambiguous_samples().cpu()
@restore_or_create("stratified_ambiguous_samples_convex_mixture.pt", recreate=False)
@torch.no_grad()
def create_stratified_ambiguous_mixture_samples():
    """Create ambiguous samples by convexly mixing latent codes of digits from different classes.

    For each clique, 2 to 4 distinct classes are drawn and one training image is
    picked per class. The images are encoded to their posterior means, and
    ``num_samples_per_clique`` random convex combinations of those means are
    decoded. Samples whose ensemble mutual information exceeds ``mi_threshold``
    are rejected. The rest are sorted into predictive-entropy bins until every
    bin has reached its target count.

    Returns:
        A CPU tensor with the raw decoder outputs of all buckets concatenated.
    """
    mi_threshold = 0.05
    bin_edges = [0, 0.05, 0.5, 1.0, 1.6]
    target_counts = [60000, 60000, 40000, 20000]

    num_bins = len(bin_edges) - 1
    bucket_counts = [0] * num_bins
    buckets = [[] for _ in range(num_bins)]
    assert len(target_counts) == len(bucket_counts) == len(buckets) == num_bins

    total_target_count = sum(target_counts)
    num_samples_per_clique = 64
    num_label_range = list(range(2, 5))
    batch_size = 65536 // num_samples_per_clique

    sample_progress = tqdm(total=total_target_count)
    while any(count < target for count, target in zip(bucket_counts, target_counts)):
        # Buckets may overshoot their target; keep the progress total in sync.
        sample_progress.total = sum(max(count, target) for count, target in zip(bucket_counts, target_counts))

        for num_labels in num_label_range:
            batch_data = []
            for _ in range(batch_size):
                classes = np.random.choice(10, num_labels, replace=False)
                # class_indices[c] holds the dataset positions of class c, so draw
                # from that array. A draw from range(len(...)) is only a position
                # inside the array and would select an image of an arbitrary class.
                dataset_indices = [int(np.random.choice(class_indices[c])) for c in classes]
                clique = torch.stack([mnist_train_dataset[i][0] for i in dataset_indices]).cuda()
                batch_data.append(clique)
            batch_data = torch.stack(batch_data)

            mu, _ = vae.encoder(batch_data.view(-1, 784))

            # Random convex weights: sort uniform draws between fixed 0 and 1
            # endpoints and take the gaps, which are non-negative and sum to 1.
            mix_weights = torch.rand(batch_size * num_samples_per_clique, num_labels + 1).cuda()
            mix_weights[:, 0] = 0
            mix_weights[:, -1] = 1
            mix_weights = torch.sort(mix_weights, dim=-1)[0]
            mix_weights = mix_weights[:, 1:] - mix_weights[:, :-1]

            mixed_encodings = mix_weights.view(batch_size, num_samples_per_clique, num_labels) @ mu.view(
                batch_size, num_labels, -1
            )
            decoded_lerps = vae.decoder(mixed_encodings).view(batch_size * num_samples_per_clique, 1, 28, 28)

            normalized_samples = normalize_mnist(adjusted_post_process_vae_output(decoded_lerps)).cuda()
            lenet_pre_softmaxs = torch.stack([lenet(normalized_samples) for lenet in lenets])

            batch_pred_entropies = pred_entropy(lenet_pre_softmaxs)
            batch_avg_entropies = avg_entropy(lenet_pre_softmaxs)
            batch_mi = batch_pred_entropies - batch_avg_entropies
            # Low mutual information: keep only samples the ensemble agrees on.
            id_samples = batch_mi <= mi_threshold

            for bucket_index, (count, target, bucket, lower, upper) in enumerate(
                zip(bucket_counts, target_counts, buckets, bin_edges[:-1], bin_edges[1:])
            ):
                bucket_samples = id_samples & (lower <= batch_pred_entropies) & (batch_pred_entropies < upper)
                num_new_samples = bucket_samples.sum().item()
                if num_new_samples == 0 or count >= target:
                    continue
                bucket.append(decoded_lerps[bucket_samples].cpu())
                bucket_counts[bucket_index] += num_new_samples
                sample_progress.update(num_new_samples)
    sample_progress.close()

    buckets = [torch.cat(bucket).cpu() for bucket in buckets if bucket]
    return torch.cat(buckets).cpu()
# Generate the stratified convex-mixture samples, or restore them from "stratified_ambiguous_samples_convex_mixture.pt".
stratified_ambiguous_samples_convex_mixture = create_stratified_ambiguous_mixture_samples().cpu()
@restore_or_create("ambiguous_samples_convex_mixture.pt", recreate=False)
@torch.no_grad()
def create_ambiguous_mixture_samples():
    """Create a large pool of ambiguous samples from convex latent mixtures (single entropy bin).

    Works like the stratified mixture sampler, but uses one predictive-entropy
    bin ``[0.05, 3)`` and draws ``label_batch_size`` cliques per class
    combination at once. Samples whose ensemble mutual information exceeds
    ``mi_threshold`` are rejected.

    Returns:
        A CPU tensor with the raw decoder outputs of all accepted samples.
    """
    mi_threshold = 0.05
    bin_edges = [0.05, 3]
    target_counts = [2 ** 20 * 4]

    num_bins = len(bin_edges) - 1
    bucket_counts = [0] * num_bins
    buckets = [[] for _ in range(num_bins)]
    assert len(target_counts) == len(bucket_counts) == len(buckets) == num_bins

    total_target_count = sum(target_counts)
    num_samples_per_clique = 64
    num_label_range = list(range(2, 5))
    batch_size = 65536 // num_samples_per_clique
    label_batch_size = 32

    sample_progress = tqdm(total=total_target_count)
    while any(count < target for count, target in zip(bucket_counts, target_counts)):
        # Buckets may overshoot their target; keep the progress total in sync.
        sample_progress.total = sum(max(count, target) for count, target in zip(bucket_counts, target_counts))

        for num_labels in num_label_range:
            batch_data = []
            for _ in range(batch_size // label_batch_size):
                classes = np.random.choice(10, num_labels, replace=False)
                # class_indices[c] holds the dataset positions of class c, so draw
                # from that array. A draw from range(len(...)) is only a position
                # inside the array and would select images of arbitrary classes.
                dataset_indices = [np.random.choice(class_indices[c], label_batch_size) for c in classes]
                # Stack along dim=1 so each row groups one image per chosen class.
                data = torch.stack(
                    [normalize_mnist(mnist_train_dataset.data[i] / 255.0) for i in dataset_indices], dim=1
                ).cuda()
                batch_data.append(data)
            batch_data = torch.stack(batch_data)

            mu, _ = vae.encoder(batch_data.view(-1, 784))

            # Random convex weights: sort uniform draws between fixed 0 and 1
            # endpoints and take the gaps, which are non-negative and sum to 1.
            mix_weights = torch.rand(batch_size * num_samples_per_clique, num_labels + 1).cuda()
            mix_weights[:, 0] = 0
            mix_weights[:, -1] = 1
            mix_weights = torch.sort(mix_weights, dim=-1)[0]
            mix_weights = mix_weights[:, 1:] - mix_weights[:, :-1]

            mixed_encodings = mix_weights.view(batch_size, num_samples_per_clique, num_labels) @ mu.view(
                batch_size, num_labels, -1
            )
            decoded_lerps = vae.decoder(mixed_encodings).view(batch_size * num_samples_per_clique, 1, 28, 28)

            normalized_samples = normalize_mnist(adjusted_post_process_vae_output(decoded_lerps)).cuda()
            lenet_pre_softmaxs = torch.stack([lenet(normalized_samples) for lenet in lenets])

            batch_pred_entropies = pred_entropy(lenet_pre_softmaxs)
            batch_avg_entropies = avg_entropy(lenet_pre_softmaxs)
            batch_mi = batch_pred_entropies - batch_avg_entropies
            # Low mutual information: keep only samples the ensemble agrees on.
            id_samples = batch_mi <= mi_threshold

            for bucket_index, (count, target, bucket, lower, upper) in enumerate(
                zip(bucket_counts, target_counts, buckets, bin_edges[:-1], bin_edges[1:])
            ):
                bucket_samples = id_samples & (lower <= batch_pred_entropies) & (batch_pred_entropies < upper)
                num_new_samples = bucket_samples.sum().item()
                if num_new_samples == 0 or count >= target:
                    continue
                bucket.append(decoded_lerps[bucket_samples].cpu())
                bucket_counts[bucket_index] += num_new_samples
                sample_progress.update(num_new_samples)
    sample_progress.close()

    buckets = [torch.cat(bucket).cpu() for bucket in buckets if bucket]
    return torch.cat(buckets).cpu()
# Generate the large convex-mixture pool, or restore it from "ambiguous_samples_convex_mixture.pt".
ambiguous_samples_convex_mixture = create_ambiguous_mixture_samples().cpu()
We compute various statistics to examine the samples we have generated further.
from dataclasses import dataclass
@dataclass
class SampleEvaluation:
    """Ensemble statistics for a set of samples, as produced by ``evaluate_samples``."""

    # Entropy of the ensemble-averaged prediction, one value per sample.
    pred_entropies: torch.Tensor
    # Predictive entropy minus the average per-member entropy.
    mutual_infos: torch.Tensor
    # Pre-softmax outputs of every ensemble member.
    softmax_logits: torch.Tensor
    # Softmax of ``softmax_logits`` per ensemble member.
    softmax_predictions: torch.Tensor
    # Softmax predictions averaged over the ensemble members.
    predictions: torch.Tensor
    # Argmax class of ``predictions``.
    single_labels: torch.Tensor
@torch.no_grad()
def evaluate_samples(samples):
    """Run the LeNet ensemble on raw VAE outputs and collect uncertainty statistics.

    ``samples`` is processed in chunks of 32768. Each chunk is post-processed
    with ``adjusted_post_process_vae_output``, normalized, and fed to every
    member of ``lenets``. All results are gathered on the CPU.

    Returns:
        A ``SampleEvaluation`` covering all samples.
    """
    entropy_chunks, avg_entropy_chunks = [], []
    logit_chunks, member_prob_chunks = [], []
    mean_prob_chunks, label_chunks = [], []

    for decoded_image_batch in tqdm(torch.split(samples, 32768)):
        lenet_inputs = normalize_mnist(adjusted_post_process_vae_output(decoded_image_batch.cuda()))
        ensemble_logits = torch.stack([lenet(lenet_inputs) for lenet in lenets])

        entropy_chunks.append(pred_entropy(ensemble_logits).cpu())
        avg_entropy_chunks.append(avg_entropy(ensemble_logits).cpu())

        # swap MC samples with batch dims
        per_sample_logits = torch.transpose(ensemble_logits, 0, 1)
        member_probs = F.softmax(per_sample_logits, dim=-1)
        mean_probs = member_probs.mean(dim=1)

        logit_chunks.append(per_sample_logits.cpu())
        member_prob_chunks.append(member_probs.cpu())
        mean_prob_chunks.append(mean_probs.cpu())
        label_chunks.append(mean_probs.argmax(-1).cpu())

    pred_entropies = torch.cat(entropy_chunks)
    avg_entropies = torch.cat(avg_entropy_chunks)
    return SampleEvaluation(
        pred_entropies=pred_entropies,
        mutual_infos=pred_entropies - avg_entropies,
        softmax_logits=torch.cat(logit_chunks),
        softmax_predictions=torch.cat(member_prob_chunks),
        predictions=torch.cat(mean_prob_chunks),
        single_labels=torch.cat(label_chunks),
    )
# Evaluate each sample set with the ensemble; every result is cached in its own .pt file.
# restore_or_create wraps a zero-argument callable, hence the lambdas.
stratified_ambiguous_samples_random_latent_evaluation = restore_or_create(
"stratified_ambiguous_samples_random_latent_evaluation.pt", recreate=False
)(lambda: evaluate_samples(stratified_ambiguous_samples_random_latent))()
stratified_ambiguous_samples_convex_mixture_evaluation = restore_or_create(
"stratified_ambiguous_samples_convex_mixture_evaluation.pt", recreate=False
)(lambda: evaluate_samples(stratified_ambiguous_samples_convex_mixture))()
ambiguous_samples_convex_mixture_evaluation = restore_or_create(
"ambiguous_samples_convex_mixture_evaluation.pt", recreate=False
)(lambda: evaluate_samples(ambiguous_samples_convex_mixture))()
# Predictive-entropy histograms (log-scaled y axis, 20 bins) for the three sample sets.
plt.yscale("log")
plt.hist(stratified_ambiguous_samples_random_latent_evaluation.pred_entropies.numpy(), bins=20)
plt.title("Predictive Entropy of `stratified_ambiguous_samples_random_latent_evaluation`")
plt.show()
plt.yscale("log")
plt.hist(stratified_ambiguous_samples_convex_mixture_evaluation.pred_entropies.numpy(), bins=20)
plt.title("Predictive Entropy of `stratified_ambiguous_samples_convex_mixture_evaluation`")
plt.show()
plt.yscale("log")
plt.hist(ambiguous_samples_convex_mixture_evaluation.pred_entropies.numpy(), bins=20)
plt.title("Predictive Entropy of `ambiguous_samples_convex_mixture_evaluation`")
plt.show()
def show_stratified_samples(
    samples, pred_entropies, num_pe_bins, num_bin_samples, bin_range=None, predictions=None, one_per_class=False
):
    """Show a grid with one row of random samples per predictive-entropy bin.

    Args:
        samples: Raw VAE outputs to choose from.
        pred_entropies: Predictive entropy of each sample.
        num_pe_bins: Number of entropy bins (rows of the grid).
        num_bin_samples: Number of samples drawn per bin (columns of the grid).
        bin_range: Optional ``(lower, upper)`` range for the bin edges.
        predictions: Optional per-sample class probabilities. When given, each
            row's samples are printed as the most likely classes needed to
            exceed 0.8 cumulative probability.
        one_per_class: If true, draw one sample per predicted class in each bin.
            Requires ``num_bin_samples == 10`` and ``predictions``.
    """
    row_batches = []
    edges = np.histogram_bin_edges(pred_entropies, bins=num_pe_bins, range=bin_range)
    for lower, upper in zip(edges[:-1], edges[1:]):
        bucket_mask = (pred_entropies >= lower) & (pred_entropies < upper)
        bucket_samples = samples[bucket_mask]

        if one_per_class:
            assert num_bin_samples == 10
            assert predictions is not None
            bucket_labels = predictions[bucket_mask].argmax(dim=-1)
            per_class_picks = [
                np.random.choice(torch.nonzero(bucket_labels == c, as_tuple=False)[:, 0], size=1) for c in range(10)
            ]
            sample_indices = np.concatenate(per_class_picks)
        else:
            sample_indices = np.random.choice(len(bucket_samples), size=num_bin_samples, replace=False)

        row_batches.append(bucket_samples[sample_indices])

        if predictions is not None:
            chosen_predictions = predictions[bucket_mask][sample_indices]
            sorted_probs, class_order = torch.sort(chosen_predictions, dim=-1, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            row_labels = []
            for sample_classes, cumulative_prob in zip(class_order, cumulative_probs):
                # Most likely classes, up to and including the one that pushes
                # the cumulative probability above 0.8.
                majority_labels = []
                for class_id, mass in zip(sample_classes, cumulative_prob):
                    majority_labels.append(class_id.item())
                    if mass > 0.8:
                        break
                row_labels.append(majority_labels)
            print(row_labels)

    stratified_samples = torch.cat(row_batches)
    plt.figure(figsize=(num_pe_bins, num_bin_samples), dpi=150)
    grid = make_grid(
        adjusted_post_process_vae_output(stratified_samples.view(-1, 1, 28, 28)),
        normalize=False,
        nrow=num_bin_samples,
    )
    show(grid)
    plt.show()
# Random-latent samples: 15 entropy bins with 4 samples each.
show_stratified_samples(
stratified_ambiguous_samples_random_latent,
stratified_ambiguous_samples_random_latent_evaluation.pred_entropies,
15,
4,
bin_range=[0.05, 1.1],
predictions=stratified_ambiguous_samples_random_latent_evaluation.predictions,
)
# Stratified convex-mixture samples, same layout.
show_stratified_samples(
stratified_ambiguous_samples_convex_mixture,
stratified_ambiguous_samples_convex_mixture_evaluation.pred_entropies,
15,
4,
bin_range=[0.05, 1.1],
predictions=stratified_ambiguous_samples_convex_mixture_evaluation.predictions,
)
# Large convex-mixture pool, same layout.
show_stratified_samples(
ambiguous_samples_convex_mixture,
ambiguous_samples_convex_mixture_evaluation.pred_entropies,
15,
4,
bin_range=[0.05, 1.1],
predictions=ambiguous_samples_convex_mixture_evaluation.predictions,
)
# Large convex-mixture pool: 8 entropy bins with one sample per predicted class.
show_stratified_samples(
ambiguous_samples_convex_mixture,
ambiguous_samples_convex_mixture_evaluation.pred_entropies,
8,
10,
bin_range=[0.05, 1.2],
predictions=ambiguous_samples_convex_mixture_evaluation.predictions,
one_per_class=True,
)
The convex mixture samples look much better overall. However, we have to limit the entropy range to [0.1, 1.4]
as there are not sufficiently many samples beyond that:
# Histogram of the predictive entropies that fall in the range [1.4, 1.6].
plt.hist(
ambiguous_samples_convex_mixture_evaluation.pred_entropies.numpy(),
bins=10,
range=[1.4, 1.6],
)
def show_label_buckets(entropies, labels, num_bins, range=None):
    """Plot one label histogram per entropy bin, side by side.

    Args:
        entropies: Per-sample entropies used to assign samples to bins.
        labels: Per-sample class labels.
        num_bins: Number of entropy bins (one subplot each).
        range: Optional ``(lower, upper)`` range for the bin edges.
    """
    edges = np.histogram_bin_edges(entropies.numpy(), bins=num_bins, range=range)
    plt.figure(figsize=(num_bins * 4, 3))
    for column, (lower, upper) in enumerate(zip(edges[:-1], edges[1:]), start=1):
        plt.subplot(1, num_bins, column)
        in_bucket = (lower <= entropies) & (entropies < upper)
        plt.title(f"{lower:.2}-{upper:.2}")
        # One bar per digit class 0..9.
        plt.hist(labels[in_bucket].numpy(), bins=10, range=[-0.5, 9.5])
    plt.show()
# Label distribution per entropy bin (10 bins over [0.05, 1.4]) for each sample set.
show_label_buckets(
stratified_ambiguous_samples_random_latent_evaluation.pred_entropies,
stratified_ambiguous_samples_random_latent_evaluation.single_labels,
10,
range=[0.05, 1.4],
)
show_label_buckets(
stratified_ambiguous_samples_convex_mixture_evaluation.pred_entropies,
stratified_ambiguous_samples_convex_mixture_evaluation.single_labels,
10,
range=[0.05, 1.4],
)
show_label_buckets(
ambiguous_samples_convex_mixture_evaluation.pred_entropies,
ambiguous_samples_convex_mixture_evaluation.single_labels,
10,
range=[0.05, 1.4],
)
import matplotlib as mpl
def show_probs_vs_entropy(evaluation: SampleEvaluation, num_bins, entropy_range=None):
    """Draw, for each of the 10 classes, a 2-D histogram of probability vs entropy.

    Args:
        evaluation: object exposing ``pred_entropies`` (per-sample entropies)
            and ``predictions`` (per-sample class probabilities, indexed as
            ``predictions[:, class]``).
        num_bins: number of bins per axis of every 2-D histogram; also scales
            the figure width.
        entropy_range: ``[low, high]`` range of the entropy (y) axis; the
            probability (x) axis is fixed to ``[0.1, 1]``.

    The figure is created but not shown, so callers can still add a suptitle.
    """
    entropy_values = evaluation.pred_entropies.numpy()
    class_probs = evaluation.predictions

    plt.figure(figsize=(num_bins * 4, 3.8))
    for class_idx in range(10):
        plt.subplot(1, 10, class_idx + 1)
        hist = plt.hist2d(
            class_probs[:, class_idx].numpy(),
            entropy_values,
            bins=num_bins,
            norm=mpl.colors.LogNorm(vmax=2e4),
            range=[[0.1, 1], entropy_range],
        )
        plt.xlabel("Class Prediction Probability")
        plt.ylabel("Predictive Entropy")
        plt.title(f"Class {class_idx}")
        # The last element returned by hist2d is the image used for the colorbar.
        plt.colorbar(hist[-1], use_gridspec=True, ax=plt.gca())
# Probability-vs-entropy plots (10 bins, entropy axis [0.1, 1.6]) for the three
# candidate sample sets; each figure gets the evaluation's name as suptitle.
show_probs_vs_entropy(stratified_ambiguous_samples_random_latent_evaluation, 10, entropy_range=[0.1, 1.6])
plt.suptitle("`stratified_ambiguous_samples_random_latent_evaluation`")
show_probs_vs_entropy(stratified_ambiguous_samples_convex_mixture_evaluation, 10, entropy_range=[0.1, 1.6])
plt.suptitle("`stratified_ambiguous_samples_convex_mixture_evaluation`")
show_probs_vs_entropy(ambiguous_samples_convex_mixture_evaluation, 10, entropy_range=[0.1, 1.6])
plt.suptitle("`ambiguous_samples_convex_mixture_evaluation`")
plt.show()
# From here on the non-stratified convex-mixture samples and their evaluation
# are used as "the" ambiguous samples. The stratified alternative is kept
# below, disabled, for reference:
# ambiguous_samples_evaluation = stratified_ambiguous_samples_convex_mixture_evaluation
ambiguous_samples = ambiguous_samples_convex_mixture
ambiguous_samples_evaluation = ambiguous_samples_convex_mixture_evaluation
def entropy(p):
    """Return the entropy (in nats) of `p` along its last dimension.

    Args:
        p: tensor of probabilities; the last dimension indexes the classes.

    Returns:
        Tensor with the last dimension summed out. Terms that evaluate to NaN
        (e.g. ``0 * log(0)``) are counted as 0.
    """
    pointwise = -p * torch.log(p)
    # 0 * log(0) yields NaN; such terms contribute nothing to the sum.
    pointwise = torch.where(torch.isnan(pointwise), torch.zeros_like(pointwise), pointwise)
    return pointwise.sum(dim=-1)
def kl(p, q):
    """Return the KL divergence KL(p || q) in nats along the last dimension.

    Args:
        p: tensor of probabilities (the reference distribution).
        q: tensor of probabilities, broadcastable against `p`.

    Returns:
        Tensor with the last dimension summed out. Terms that evaluate to NaN
        (e.g. where ``p`` is 0) are counted as 0.
    """
    log_ratio = torch.log(q) - torch.log(p)
    terms = -p * log_ratio
    # NaN terms (0 * inf, inf - inf) are dropped from the sum.
    terms = torch.where(torch.isnan(terms), torch.zeros_like(terms), terms)
    return terms.sum(dim=-1)
def show_prob_buckets(entropies, predictions, num_bins, entropy_range=None):
    """Plot the mean predicted class distribution of every entropy bucket.

    The entropy axis is split into ``num_bins`` equal-width buckets (optionally
    restricted to ``entropy_range``). For each bucket the predictions of the
    samples with entropy in ``[lower, upper)`` are averaged, the entropy of
    that mean distribution is printed, and the distribution is drawn as a bar
    chart with a fixed y-limit of 0.3.

    Args:
        entropies: 1-D tensor of per-sample entropies.
        predictions: per-sample class probabilities, indexable by a boolean
            mask over ``entropies``; 10 classes are plotted.
        num_bins: number of entropy buckets (one subplot each).
        entropy_range: optional ``[low, high]`` range of the bucket edges.
    """
    bin_edges = np.histogram_bin_edges(entropies.numpy(), bins=num_bins, range=entropy_range)
    class_positions = range(10)

    plt.figure(figsize=(num_bins * 4, 3))
    for panel, (lower, upper) in enumerate(zip(bin_edges[:-1], bin_edges[1:]), start=1):
        plt.subplot(1, num_bins, panel)
        in_bucket = (lower <= entropies) & (entropies < upper)
        mean_prediction = predictions[in_bucket].mean(dim=0)
        print(entropy(mean_prediction))
        plt.title(f"{lower:.2}-{upper:.2}")
        plt.ylim(0, 0.3)
        plt.bar(class_positions, mean_prediction.numpy(), width=1)
    plt.show()
import math
def subsample_preserve_histogram(
    train_num_samples,
    test_num_samples,
    entropies,
    num_bins=50,
    entropy_range=(0.05, 1.6),
):
    """Draw train/test indices whose entropy histogram is a flattened version of the input's.

    The entropies are binned into ``num_bins`` equal-width bins over
    ``entropy_range``. Each bin receives a share of the total budget
    proportional to ``count ** 0.2`` (which compresses the differences between
    well- and sparsely-populated bins). Inside a bin, positions are drawn
    uniformly over the bin's entropy interval and mapped to the sample with
    the closest entropy below the draw, so samples are picked with
    replacement and may repeat.

    Args:
        train_num_samples: requested number of training indices.
        test_num_samples: requested number of test indices.
        entropies: 1-D floating-point tensor of per-sample entropies.
        num_bins: number of histogram bins (default 50).
        entropy_range: ``(low, high)`` range of the histogram (default
            ``(0.05, 1.6)``).

    Returns:
        ``(train_indices, test_indices)`` tensors of indices into
        ``entropies``. Because per-bin counts are rounded up, the totals are
        approximately (not exactly) the requested sizes.
    """
    hist, bin_edges = np.histogram(entropies, bins=num_bins, range=entropy_range)
    subsampled_weights = hist ** 0.2 / len(entropies) ** 0.2
    subsampled_total = np.sum(subsampled_weights)
    # `+ num_bins` budgets for the per-bin rounding below.
    subsampled_counts = subsampled_weights / subsampled_total * (train_num_samples + test_num_samples + num_bins)
    print(subsampled_counts)

    train_ratio = train_num_samples / (train_num_samples + test_num_samples)
    # Plain Python floats keep the tensor arithmetic below in the tensors' own dtype.
    edges = bin_edges.tolist()

    train_stratified_indices = []
    test_stratified_indices = []
    for lower_edge, upper_edge, target_count in zip(edges[:-1], edges[1:], subsampled_counts):
        mask = (entropies >= lower_edge) & (entropies < upper_edge)
        bucket_indices = torch.nonzero(mask, as_tuple=False)[:, 0]
        if bucket_indices.numel() == 0:
            # np.histogram closes the last bin on the right while the mask is
            # half-open, so a bin can have a positive target but no members.
            continue

        sorted_entropies, original_indices = torch.sort(entropies[bucket_indices])
        uniform_draws = torch.rand(math.ceil(target_count)) * (upper_edge - lower_edge) + lower_edge
        # searchsorted needs both operands in the same dtype.
        uniform_draws = uniform_draws.to(sorted_entropies.dtype)
        drawn_indices = torch.searchsorted(sorted_entropies, uniform_draws)
        # A draw below the smallest entropy yields position 0; clamp so that
        # `- 1` selects the smallest sample instead of wrapping around to the
        # largest one via index -1.
        nearest_below = (drawn_indices - 1).clamp(min=0)
        stratified_bucket_indices = bucket_indices[original_indices[nearest_below]]

        train_bucket_size = math.ceil(len(stratified_bucket_indices) * train_ratio)
        train_stratified_indices.extend(stratified_bucket_indices[:train_bucket_size].tolist())
        test_stratified_indices.extend(stratified_bucket_indices[train_bucket_size:].tolist())

    return torch.as_tensor(train_stratified_indices), torch.as_tensor(test_stratified_indices)
@restore_or_create("subsample_indices_train_test.pt", recreate=False)
def get_subsample_train_test_indices():
    """Subsample 6000 train and 6000 test indices from the ambiguous samples.

    Both splits get the same budget, a tenth of 60000. The indices are drawn
    by ``subsample_preserve_histogram`` from the predictive entropies of
    ``ambiguous_samples_evaluation``.
    """
    samples_per_split = 60_000 // 10
    return subsample_preserve_histogram(
        samples_per_split,
        samples_per_split,
        ambiguous_samples_evaluation.pred_entropies,
    )
# Obtain the subsampled train/test index tensors via the decorated helper
# above.
subsample_indices_train, subsample_indices_test = get_subsample_train_test_indices()
# Leftover from an earlier variant, kept disabled:
# indices_test = entropy_balanced_indices[6000:8000]
indices_train = subsample_indices_train
indices_test = subsample_indices_test
# Notebook cell output: sizes of both splits.
len(indices_train), len(indices_test)
# Visual sanity checks of the subsampled splits.

# Sample grid for the training split (one_per_class=True).
show_stratified_samples(
ambiguous_samples[indices_train],
ambiguous_samples_evaluation.pred_entropies[indices_train],
8,
10,
bin_range=[0.05, 0.8],
predictions=ambiguous_samples_evaluation.predictions[indices_train],
one_per_class=True,
)
# Entropy histogram of the training split on a log-scaled y axis.
plt.yscale("log")
plt.hist(
ambiguous_samples_evaluation.pred_entropies[indices_train].numpy(),
bins=16,
range=[0.05, 1.6],
)
plt.show()
# Entropy histogram of the test split, again with a log-scaled y axis.
plt.yscale("log")
plt.hist(
ambiguous_samples_evaluation.pred_entropies[indices_test].numpy(),
bins=16,
range=[0.05, 1.6],
)
# Notebook cell output: sizes of both splits.
len(indices_train), len(indices_test)
# Entropy histogram of all ambiguous samples (before subsampling) for
# comparison.
plt.hist(
ambiguous_samples_evaluation.pred_entropies.numpy(),
bins=50,
range=[0.05, 1.45],
)
# Mean class distribution per entropy bucket: training split ...
show_prob_buckets(
ambiguous_samples_evaluation.pred_entropies[indices_train],
ambiguous_samples_evaluation.predictions[indices_train],
5,
entropy_range=[0.05, 1.40],
)
# ... versus all ambiguous samples.
show_prob_buckets(
ambiguous_samples_evaluation.pred_entropies,
ambiguous_samples_evaluation.predictions,
5,
entropy_range=[0.05, 1.40],
)
# Another sample grid for the training split with different grid settings.
show_stratified_samples(
ambiguous_samples[indices_train],
ambiguous_samples_evaluation.pred_entropies[indices_train],
10,
4,
bin_range=[0.05, 1.0],
predictions=ambiguous_samples_evaluation.predictions[indices_train],
)
plt.show()
# Shuffle each split independently and keep at most 6000 indices per split.
shuffled_balanced_amnist_indices_train = indices_train[torch.randperm(len(indices_train))][:6000]
shuffled_balanced_amnist_indices_test = indices_test[torch.randperm(len(indices_test))][:6000]
# Train indices come first, test indices last; the later train/test slicing
# relies on this order.
shuffled_balanced_amnist_indices = torch.cat(
[shuffled_balanced_amnist_indices_train, shuffled_balanced_amnist_indices_test]
)
# Notebook cell output: total number of selected samples.
len(shuffled_balanced_amnist_indices)
# Ensemble predictions of the selected samples; labels are drawn from these.
amnist_predictions = ambiguous_samples_evaluation.predictions[shuffled_balanced_amnist_indices]
len(amnist_predictions)
def draw_labels(predictions, num_labels):
    """Sample class labels from per-sample categorical distributions.

    Args:
        predictions: tensor of non-negative class weights, one row per sample.
        num_labels: number of labels to draw per row (with replacement).

    Returns:
        Integer tensor with ``num_labels`` sampled class indices per row.
    """
    return torch.multinomial(predictions, num_labels, replacement=True)
@restore_or_create("internal_amnist_labels.pt", recreate=True)
def create_amnist_labels():
    """Draw 10 labels per selected sample from its ensemble prediction."""
    return draw_labels(amnist_predictions, 10)
# Materialise the final Ambiguous-MNIST tensors and write them to disk.
amnist_labels = create_amnist_labels()
amnist_samples = ambiguous_samples[shuffled_balanced_amnist_indices]
# Post-processed version of the raw samples (helper defined earlier in the
# file).
amnist_unnormalized_samples = adjusted_post_process_vae_output(amnist_samples)
torch.save(amnist_predictions, "amnist_predictions.pt")
torch.save(amnist_labels, "amnist_labels.pt")
# Raw samples and post-processed samples are stored under different names.
torch.save(amnist_samples, "amnist_raw_samples.pt")
torch.save(amnist_unnormalized_samples, "amnist_samples.pt")
# Reload from disk and repeat every sample 10 times along the channel axis
# (one copy per drawn label), then flatten to single-channel 28x28 images so
# that images and the flattened labels line up one-to-one.
# NOTE: `amnist_samples` and `amnist_labels` are rebound here to the expanded
# versions.
amnist_samples = torch.load("amnist_samples.pt").expand(-1, 10, 28, 28).reshape(-1, 1, 28, 28)
amnist_labels = torch.load("amnist_labels.pt").reshape(-1)
# Notebook cell output: shapes of the expanded tensors.
amnist_samples.shape, amnist_labels.shape
amnist_normalized_samples = normalize_mnist(amnist_samples)
# Split the expanded tensors: the last 60000 entries form the test set and
# everything before them the train set. With up to 6000 test samples x 10
# labels this matches the concatenation order used when the indices were
# built (train first, test last).
# NOTE(review): the hard-coded 60000 assumes the test split really contains
# 6000 samples — confirm if the subsample sizes ever change.
amnist_train_dataset = torch.utils.data.TensorDataset(amnist_normalized_samples[:-60000], amnist_labels[:-60000])
amnist_test_dataset = torch.utils.data.TensorDataset(amnist_normalized_samples[-60000:], amnist_labels[-60000:])
# Replace the label tensors with NumPy arrays while keeping the image tensors.
# NOTE(review): presumably done so the label type matches the MNIST datasets
# these are concatenated with — confirm.
amnist_train_dataset.tensors = (
amnist_train_dataset.tensors[0],
amnist_train_dataset.tensors[1].numpy(),
)
amnist_test_dataset.tensors = (
amnist_test_dataset.tensors[0],
amnist_test_dataset.tensors[1].numpy(),
)
# Data loaders for Ambiguous-MNIST; only the training loader shuffles.
# Memory is pinned only when CUDA is available.
amnist_train_loader = torch.utils.data.DataLoader(
dataset=amnist_train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=torch.cuda.is_available(),
)
amnist_test_loader = torch.utils.data.DataLoader(
dataset=amnist_test_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
)
# DirtyMNIST = MNIST + Ambiguous-MNIST, concatenated per split.
dmnist_train_dataset = torch.utils.data.ConcatDataset([mnist_train_dataset, amnist_train_dataset])
dmnist_test_dataset = torch.utils.data.ConcatDataset([mnist_test_dataset, amnist_test_dataset])
dmnist_train_loader = torch.utils.data.DataLoader(
dataset=dmnist_train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=torch.cuda.is_available(),
)
dmnist_test_loader = torch.utils.data.DataLoader(
dataset=dmnist_test_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
)
# FashionMNIST, downloaded to ./fmnist_data/ if necessary. Images are
# converted to tensors and normalised with mean 0.1307 and std 0.3081.
# NOTE(review): these look like the usual MNIST statistics rather than
# FashionMNIST's own — presumably intentional so all inputs share one
# normalisation; confirm.
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.1307,),
(0.3081,),
),
]
)
fmnist_train_dataset = datasets.FashionMNIST(
root="./fmnist_data/",
train=True,
download=True,
transform=transform,
)
fmnist_test_dataset = datasets.FashionMNIST(
root="./fmnist_data/",
train=False,
download=True,
transform=transform,
)
# Loaders mirror the other datasets: shuffled for training, ordered for test.
fmnist_train_loader = torch.utils.data.DataLoader(
fmnist_train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=torch.cuda.is_available(),
)
fmnist_test_loader = torch.utils.data.DataLoader(
fmnist_test_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
)
# Train an ensemble of LeNets on DirtyMNIST; every member is wrapped in
# restore_or_create under its own file name.
dmnist_lenets = []
num_ensemble_components = 5
for i in tqdm(range(num_ensemble_components)):

    @restore_or_create(f"dmnist_lenet_model_{i}.model", recreate=True)
    def train_lenet_model():
        """Train one LeNet on DirtyMNIST and return the best checkpoint.

        After every epoch the model is scored with ``lenet_test`` on
        ``mnist_test_loader``; the weights with the lowest loss so far are
        saved to ``./tmp_lenet_best.model`` and loaded back before returning.

        Returns:
            The trained LeNet (on the GPU) carrying the best epoch's weights.
        """
        lenet = LeNet(num_classes=10).cuda()
        optimizer = optim.Adam(lenet.parameters())

        best_loss = None
        best_epoch = None
        # NOTE(review): only 5 epochs are run, although the original remark
        # here says LeNet seems to converge after about 10 — confirm which
        # one is intended.
        for epoch in tqdm(range(5)):
            lenet_train(epoch, dmnist_train_loader, optimizer, lenet)
            # NOTE(review): checkpoints are selected on the MNIST *test*
            # loader; confirm this is acceptable for the reported numbers.
            test_loss = lenet_test(mnist_test_loader, lenet)
            # Compare against None explicitly: a truthiness check would treat
            # a legitimate best loss of exactly 0.0 as "not set yet".
            if best_loss is None or best_loss > test_loss:
                torch.save(lenet.state_dict(), "./tmp_lenet_best.model")
                best_loss = test_loss
                best_epoch = epoch
                print("New best model", best_epoch, " with ", best_loss)

        print("Best epoch", best_epoch)
        lenet.load_state_dict(torch.load("./tmp_lenet_best.model"))
        return lenet

    # The function is called right away, so the loop variable captured in the
    # decorator's file name is the current one.
    dmnist_lenets.append(train_lenet_model())
@dataclass
class Evaluation:
    """Result of running `evaluate` over one data loader."""

    # Ensemble class probabilities, one row per evaluated sample.
    predictions: torch.Tensor
    # Entropy of each row of `predictions`.
    entropies: torch.Tensor
    # Labels yielded by the data loader, concatenated over all batches.
    labels: torch.Tensor
@torch.no_grad()
def evaluate(test_loader, lenets):
    """Evaluate an ensemble on a data loader.

    Every model is switched to eval mode. For each batch the outputs of all
    ensemble members are averaged and passed through a softmax to obtain the
    ensemble prediction.

    NOTE(review): the members' outputs are averaged *before* the softmax;
    whether these are logits or log-probabilities depends on LeNet — confirm
    this is the intended way of combining the ensemble.

    Args:
        test_loader: iterable yielding ``(data, label)`` batches.
        lenets: models of the ensemble; they are expected to live on the GPU.

    Returns:
        An ``Evaluation`` with CPU tensors: the ensemble predictions, their
        entropies and the labels in loader order.
    """
    for lenet in lenets:
        lenet.eval()

    labels = []
    predictions = []
    # Gradients are already disabled by the decorator, so no nested
    # torch.no_grad() context is needed.
    for data, label in tqdm(test_loader):
        data = data.cuda()  # moved once and shared by all ensemble members
        batch_outputs = torch.stack([lenet(data) for lenet in lenets])
        predictions.append(F.softmax(batch_outputs.mean(dim=0), dim=-1))
        # Labels are only collected, so they never need to visit the GPU.
        labels.append(label)

    labels = torch.cat(labels).cpu()
    predictions = torch.cat(predictions).cpu()
    entropies = entropy(predictions)
    return Evaluation(predictions, entropies, labels)
# Evaluate the DirtyMNIST-trained ensemble on all four test sets.
dmnist_test_evaluation = evaluate(dmnist_test_loader, dmnist_lenets)
amnist_test_evaluation = evaluate(amnist_test_loader, dmnist_lenets)
mnist_test_evaluation = evaluate(mnist_test_loader, dmnist_lenets)
fmnist_test_evaluation = evaluate(fmnist_test_loader, dmnist_lenets)
# Accuracy on MNIST in percent: share of samples whose arg-max prediction
# equals the label.
print(
"MNIST accuracy",
(mnist_test_evaluation.predictions.argmax(dim=-1) == mnist_test_evaluation.labels).sum().item()
/ len(mnist_test_evaluation.labels)
* 100,
)
import seaborn as sns
def plot_entropies(evaluation: Evaluation, **kwargs):
    """Draw a normalised histogram of an evaluation's entropies on the current axes.

    The histogram uses 15 bins spanning ``[0, log(10)]`` and the
    ``"probability"`` statistic; extra keyword arguments are forwarded to
    ``sns.histplot``.
    """
    upper_bound = np.log(10)
    sns.histplot(
        evaluation.entropies.numpy(),
        stat="probability",
        binrange=[0.0, upper_bound],
        bins=15,
        **kwargs,
    )
# One wide figure with five panels comparing the entropy histograms of the
# evaluated test sets.
plt.figure(figsize=(5 * 4, 4 / 1.6))
# Panel 1: DirtyMNIST overlaid with FashionMNIST (semi-transparent).
plt.subplot(1, 5, 1)
plot_entropies(
dmnist_test_evaluation,
alpha=0.4,
color=sns.color_palette()[0],
label="DMNIST",
)
plot_entropies(
fmnist_test_evaluation,
alpha=0.4,
color=sns.color_palette()[2],
label="FMNIST",
)
plt.legend()
plt.title("DMNIST vs FMNIST")
# Panel 2: MNIST overlaid with Ambiguous-MNIST.
plt.subplot(1, 5, 2)
plot_entropies(
mnist_test_evaluation,
alpha=0.8,
color=sns.color_palette()[0],
label="MNIST",
)
plot_entropies(
amnist_test_evaluation,
alpha=0.8,
color=sns.color_palette()[1],
label="AMNIST",
)
plt.title("MNIST vs AMNIST")
plt.legend()
# Panels 3-5: Ambiguous-MNIST, DirtyMNIST and FashionMNIST on their own.
plt.subplot(1, 5, 3)
plot_entropies(
amnist_test_evaluation,
alpha=0.8,
color=sns.color_palette()[0],
label="AMNIST",
)
plt.title("AMNIST")
plt.subplot(1, 5, 4)
plot_entropies(
dmnist_test_evaluation,
alpha=0.8,
color=sns.color_palette()[0],
label="DMNIST",
)
plt.title("DMNIST")
plt.subplot(1, 5, 5)
plot_entropies(
fmnist_test_evaluation,
alpha=0.8,
color=sns.color_palette()[0],
label="FMNIST",
)
plt.title("FMNIST")
plt.show()
This shows that training on DirtyMNIST (MNIST + Ambiguous-MNIST) leads to predictions that correctly span a wide range of aleatoric uncertainty. As a result, their entropies overlap with those of OOD data (here FashionMNIST), which is also assigned high entropy. We note that this overlap is high even though we use a LeNet ensemble.