First, we will implement two helper functions to compute the conditional entropies $H[y_i|w]$ and the entropies $H[y_i]$. Then, we will implement BatchBALD and BALD.
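Concretely, with $K$ parameter samples $w_1,\dots,w_K$ drawn from the (approximate) posterior, these are the Monte-Carlo estimators that the helpers below implement:

$$H[y_i|w] \approx -\frac{1}{K}\sum_{k=1}^{K}\sum_{c} p(y_i = c \mid w_k)\,\log p(y_i = c \mid w_k),$$

$$H[y_i] \approx -\sum_{c} \hat p_{i,c}\,\log \hat p_{i,c}, \qquad \hat p_{i,c} = \frac{1}{K}\sum_{k=1}^{K} p(y_i = c \mid w_k).$$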
import math
from dataclasses import dataclass
from typing import List
import torch
from toma import toma
from tqdm.auto import tqdm
from batchbald_redux import joint_entropy
We are going to define a couple of sampled distributions to use for testing our code.
$K=20$ means 20 inference samples.
K = 20
import numpy as np
def get_mixture_prob_dist(p1, p2, m):
return (1.0 - m) * np.asarray(p1) + m * np.asarray(p2)
p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]
p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]
p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]
p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]
def nested_to_tensor(l):
return torch.stack(list(map(torch.as_tensor, l)))
ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])
ys_ws.shape
def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
N, K, C = probs_N_K_C.shape
entropies_N = torch.empty(N, dtype=torch.double)
pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)
@toma.execute.chunked(probs_N_K_C, 1024)
def compute(probs_n_K_C, start: int, end: int):
nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
nats_n_K_C[probs_n_K_C == 0] = 0.0
entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
pbar.update(end - start)
pbar.close()
return entropies_N
def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
N, K, C = probs_N_K_C.shape
entropies_N = torch.empty(N, dtype=torch.double)
pbar = tqdm(total=N, desc="Entropy", leave=False)
@toma.execute.chunked(probs_N_K_C, 1024)
def compute(probs_n_K_C, start: int, end: int):
mean_probs_n_C = probs_n_K_C.mean(dim=1)
nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
nats_n_C[mean_probs_n_C == 0] = 0.0
entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
pbar.update(end - start)
pbar.close()
return entropies_N
assert np.allclose(compute_conditional_entropy(ys_ws), [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)
assert np.allclose(compute_entropy(ys_ws), [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)
However, our neural networks usually use a log_softmax as the final layer. To avoid having to call .exp_(), which is easy to miss and annoying to debug, we will instead use a version that takes log_probs instead of probs.
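The only subtlety is taking the mean over the $K$ samples in log-space: $\log \frac{1}{K}\sum_k p_k = \operatorname{logsumexp}_k(\log p_k) - \log K$, which is exactly what the logsumexp line in compute_entropy below does.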
def compute_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
N, K, C = log_probs_N_K_C.shape
entropies_N = torch.empty(N, dtype=torch.double)
pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)
@toma.execute.chunked(log_probs_N_K_C, 1024)
def compute(log_probs_n_K_C, start: int, end: int):
nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)
entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
pbar.update(end - start)
pbar.close()
return entropies_N
def compute_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
N, K, C = log_probs_N_K_C.shape
entropies_N = torch.empty(N, dtype=torch.double)
pbar = tqdm(total=N, desc="Entropy", leave=False)
@toma.execute.chunked(log_probs_N_K_C, 1024)
def compute(log_probs_n_K_C, start: int, end: int):
mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
nats_n_C = mean_log_probs_n_C * torch.exp(mean_log_probs_n_C)
entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
pbar.update(end - start)
pbar.close()
return entropies_N
conditional_entropies = compute_conditional_entropy(ys_ws.log())
print(conditional_entropies)
assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)
entropies = compute_entropy(ys_ws.log())
print(entropies)
assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)
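As an extra sanity check (just an assumption in this notebook, not part of the library), exponentiating the log-probabilities and applying the plain prob-based definitions directly should reproduce the same values; probs_N_K_C, cond_ref_N, and ent_ref_N are hypothetical names introduced only for this check.
# Hypothetical cross-check: the log-space implementations should match the
# prob-based definitions (ys_ws contains no zero probabilities).
probs_N_K_C = ys_ws.double()
cond_ref_N = -torch.sum(probs_N_K_C * probs_N_K_C.log(), dim=(1, 2)) / K
mean_probs_N_C = probs_N_K_C.mean(dim=1)
ent_ref_N = -torch.sum(mean_probs_N_C * mean_probs_N_C.log(), dim=1)
assert torch.allclose(compute_conditional_entropy(ys_ws.log()), cond_ref_N)
assert torch.allclose(compute_entropy(ys_ws.log()), ent_ref_N)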
BatchBALD
To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$.
As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.
Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ stay fixed while we determine $y_{B}$. Since $\sum_{b<B} H[y_b|w]$ is then a constant, it would suffice to score each pool element $y_i$ by $H[(y_b)_{B-1}, y_i] - H[y_i|w]$ and add the highest scorer as $y_{B}$.
We don't use this last optimization here and subtract the constant as well, so the reported scores are the actual mutual-information values.
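To make these scores easy to verify on this toy example, here is a minimal brute-force sketch (exact_batch_mi is a hypothetical helper, not part of batchbald_redux) that computes $I[(y_b)_B;w]$ exactly by enumerating all $C^B$ label configurations; this is only feasible for tiny $B$ and $C$.
import itertools

def exact_batch_mi(log_probs_N_K_C: torch.Tensor, indices: List[int]) -> float:
    # Joint entropy H[(y_b)_B] with p(c_1, ..., c_B) = 1/K sum_k prod_b p(y_b = c_b | w_k).
    sub_B_K_C = log_probs_N_K_C[indices].double()
    B, K, C = sub_B_K_C.shape
    joint_entropy = 0.0
    for config in itertools.product(range(C), repeat=B):
        # log p(config | w_k) for each sample k, shape (K,).
        log_p_K = sum(sub_B_K_C[b, :, c] for b, c in enumerate(config))
        log_p = torch.logsumexp(log_p_K, dim=0) - math.log(K)
        joint_entropy -= (log_p.exp() * log_p).item()
    # Subtract the conditional entropies sum_b H[y_b|w].
    return joint_entropy - compute_conditional_entropy(sub_B_K_C).sum().item()
Because get_batchbald_batch below subtracts the shared conditional entropies as well, the score reported when the $b$-th element is added should roughly match exact_batch_mi of the first $b$ selected indices, up to the sampling noise introduced by num_samples.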
@dataclass
class CandidateBatch:
scores: List[float]
indices: List[int]
def get_batchbald_batch(
log_probs_N_K_C: torch.Tensor, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
N, K, C = log_probs_N_K_C.shape
batch_size = min(batch_size, N)
candidate_indices = []
candidate_scores = []
if batch_size == 0:
return CandidateBatch(candidate_scores, candidate_indices)
conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)
batch_joint_entropy = joint_entropy.DynamicJointEntropy(
num_samples, batch_size - 1, K, C, dtype=dtype, device=device
)
# We always keep these on the CPU.
scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())
for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
if i > 0:
latest_index = candidate_indices[-1]
batch_joint_entropy.add_variables(log_probs_N_K_C[latest_index : latest_index + 1])
shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()
batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)
scores_N -= conditional_entropies_N + shared_conditional_entropies
scores_N[candidate_indices] = -float("inf")
candidate_score, candidate_index = scores_N.max(dim=0)
candidate_indices.append(candidate_index.item())
candidate_scores.append(candidate_score.item())
return CandidateBatch(candidate_scores, candidate_indices)
get_batchbald_batch(ys_ws.log().double(), 4, 1000, dtype=torch.double)
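For comparison, BALD scores every pool element independently by $I[y_i;w] = H[y_i] - H[y_i|w]$ and simply takes the top batch_size scorers, ignoring interactions between the selected points.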
def get_bald_batch(log_probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None) -> CandidateBatch:
N, K, C = log_probs_N_K_C.shape
batch_size = min(batch_size, N)
candidate_indices = []
candidate_scores = []
scores_N = -compute_conditional_entropy(log_probs_N_K_C)
scores_N += compute_entropy(log_probs_N_K_C)
candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)
return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())
get_bald_batch(ys_ws.log().double(), 4)