Greedy algorithm and score computation

First, we will implement two helper functions to compute the conditional entropies $H[y_i|w]$ and the entropies $H[y_i]$. Then, we will implement BatchBALD and BALD.

import math
from dataclasses import dataclass
from typing import List

import torch
from toma import toma
from tqdm.auto import tqdm

from batchbald_redux import joint_entropy

We are going to define a couple of sampled distributions for testing our code.

$K=20$ means we use 20 inference samples per input.

import numpy as np

K = 20


def get_mixture_prob_dist(p1, p2, m):
    return (1.0 - m) * np.asarray(p1) + m * np.asarray(p2)


p1 = [0.7, 0.1, 0.1, 0.1]
p2 = [0.3, 0.3, 0.2, 0.2]
y1_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.7, 0.1, 0.1]
p2 = [0.2, 0.3, 0.3, 0.2]
y2_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.7, 0.1]
p2 = [0.2, 0.2, 0.3, 0.3]
y3_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]

p1 = [0.1, 0.1, 0.1, 0.7]
p2 = [0.3, 0.2, 0.2, 0.3]
y4_ws = [get_mixture_prob_dist(p1, p2, m) for m in np.linspace(0, 1, K)]


def nested_to_tensor(l):
    return torch.stack(list(map(torch.as_tensor, l)))


ys_ws = nested_to_tensor([y1_ws, y2_ws, y3_ws, y4_ws])
ys_ws.shape
torch.Size([4, 20, 4])

Conditional Entropies and Batched Entropies

To start with, we write two functions to compute the conditional entropy $H[y_i|w]$ and the entropy $H[y_i]$ for each pool element. Both take probabilities of shape $N \times K \times C$: $N$ pool elements, $K$ inference samples, and $C$ classes.
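
Concretely, with $K$ Monte Carlo samples $w_1, \dots, w_K$ of the model parameters, we estimate $H[y_i|w] \approx -\frac{1}{K} \sum_k \sum_c p(y_i{=}c \mid w_k) \log p(y_i{=}c \mid w_k)$ and $H[y_i] \approx -\sum_c \bar{p}_{i,c} \log \bar{p}_{i,c}$, where $\bar{p}_{i,c} = \frac{1}{K} \sum_k p(y_i{=}c \mid w_k)$ is the marginal predictive distribution. This is exactly what the two functions below compute.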

def compute_conditional_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        nats_n_K_C = probs_n_K_C * torch.log(probs_n_K_C)
        # By convention 0 log 0 = 0: mask out the NaNs produced where p = 0.
        nats_n_K_C[probs_n_K_C == 0] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.close()

    return entropies_N


def compute_entropy(probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(probs_N_K_C, 1024)
    def compute(probs_n_K_C, start: int, end: int):
        mean_probs_n_C = probs_n_K_C.mean(dim=1)
        nats_n_C = mean_probs_n_C * torch.log(mean_probs_n_C)
        # By convention 0 log 0 = 0: mask out the NaNs produced where p = 0.
        nats_n_C[mean_probs_n_C == 0] = 0.0

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
        pbar.update(end - start)

    pbar.close()

    return entropies_N

yus_ws = torch.full((4, K, 4), 0.25, dtype=torch.double)  # uniform test distributions: entropy is ln 4 ≈ 1.3863

assert np.allclose(compute_conditional_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)
assert np.allclose(compute_entropy(yus_ws), [1.3863, 1.3863, 1.3863, 1.3863], atol=0.1)

However, our neural networks usually use a `log_softmax` as the final layer. To avoid having to call `.exp_()` on the outputs, which is easy to miss and annoying to debug, we will instead use versions that take log-probs instead of probs.
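
The only subtlety is averaging over the $K$ samples without leaving log-space: the log of a mean of probabilities is a `logsumexp` minus $\log K$. A quick check of that identity (the tensor names here are just for this example):

log_p = torch.log_softmax(torch.randn(3, K, 4), dim=-1)
mean_log_p = torch.logsumexp(log_p, dim=1) - math.log(K)  # log of the mean probability
assert torch.allclose(mean_log_p, log_p.exp().mean(dim=1).log())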

compute_conditional_entropy[source]

compute_conditional_entropy(log_probs_N_K_C:Tensor)

compute_entropy[source]

compute_entropy(log_probs_N_K_C:Tensor)

def compute_conditional_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Conditional Entropy", leave=False)

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        nats_n_K_C = log_probs_n_K_C * torch.exp(log_probs_n_K_C)

        entropies_N[start:end].copy_(-torch.sum(nats_n_K_C, dim=(1, 2)) / K)
        pbar.update(end - start)

    pbar.close()

    return entropies_N


def compute_entropy(log_probs_N_K_C: torch.Tensor) -> torch.Tensor:
    N, K, C = log_probs_N_K_C.shape

    entropies_N = torch.empty(N, dtype=torch.double)

    pbar = tqdm(total=N, desc="Entropy", leave=False)

    @toma.execute.chunked(log_probs_N_K_C, 1024)
    def compute(log_probs_n_K_C, start: int, end: int):
        # Mean over the K samples in log-space: log(mean p) = logsumexp(log p) - log K.
        mean_log_probs_n_C = torch.logsumexp(log_probs_n_K_C, dim=1) - math.log(K)
        nats_n_C = mean_log_probs_n_C * torch.exp(mean_log_probs_n_C)

        entropies_N[start:end].copy_(-torch.sum(nats_n_C, dim=1))
        pbar.update(end - start)

    pbar.close()

    return entropies_N

Examples

conditional_entropies = compute_conditional_entropy(ys_ws.log())

print(conditional_entropies)

assert np.allclose(conditional_entropies, [1.2069, 1.2069, 1.2069, 1.2069], atol=0.01)
tensor([1.2069, 1.2069, 1.2069, 1.2069], dtype=torch.float64)
entropies = compute_entropy(ys_ws.log())

print(entropies)

assert np.allclose(entropies, [1.2376, 1.2376, 1.2376, 1.2376], atol=0.01)
tensor([1.2376, 1.2376, 1.2376, 1.2376], dtype=torch.float64)

BatchBALD

To compute BatchBALD exactly for a candidate batch, we'd have to compute $I[(y_b)_B;w] = H[(y_b)_B] - H[(y_b)_B|w]$, where $(y_b)_B$ is shorthand for the joint $y_1, \dots, y_B$.

As the $y_b$ are independent given $w$, we can simplify $H[(y_b)_B|w] = \sum_b H[y_b|w]$.

Furthermore, we use a greedy algorithm to build up the candidate batch, so $y_1,\dots,y_{B-1}$ stay fixed while we determine $y_B$: each pool element $y_i$ is scored with $H[(y_b)_{B-1}, y_i] - \sum_{b<B} H[y_b|w] - H[y_i|w]$, and the highest scorer is added as $y_B$. Since $\sum_{b<B} H[y_b|w]$ is constant across candidates, it could be dropped from the maximization as an optimization.

We don't use that optimization here, so that the computed scores are the actual mutual-information values.
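
For intuition, here is a brute-force sketch of the exact joint entropy $H[(y_b)_B]$: enumerate all $C^B$ class configurations and average their joint probabilities over the $K$ parameter samples. This is only feasible for tiny batches; the library's `joint_entropy.DynamicJointEntropy` (used below) computes the same quantity scalably. The function and tensor names here are just for this example:

def exact_joint_entropy(probs_B_K_C: torch.Tensor) -> torch.Tensor:
    B, K, C = probs_B_K_C.shape
    # joint_probs[j, k] = p(configuration j | w_k); start from the empty configuration.
    joint_probs = torch.ones(1, K, dtype=torch.double)
    for b in range(B):
        # Extend every configuration by one class assignment for point b.
        joint_probs = (joint_probs[:, None, :] * probs_B_K_C[b].t()[None, :, :]).reshape(-1, K)
    p_config = joint_probs.mean(dim=1)  # p((y_b)_B = configuration), averaged over w
    return -(p_config * p_config.log()).sum()


exact_joint_entropy(ys_ws[:2].double())  # H[y_1, y_2] for the first two test points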

In the Paper

[Figure: the BatchBALD algorithm as presented in the paper.]

Implementation

class CandidateBatch[source]

CandidateBatch(scores: List[float], indices: List[int])

get_batchbald_batch[source]

get_batchbald_batch(log_probs_N_K_C:Tensor, batch_size:int, num_samples:int, dtype=None, device=None)

@dataclass
class CandidateBatch:
    scores: List[float]
    indices: List[int]


def get_batchbald_batch(
    log_probs_N_K_C: torch.Tensor, batch_size: int, num_samples: int, dtype=None, device=None
) -> CandidateBatch:
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    if batch_size == 0:
        return CandidateBatch(candidate_scores, candidate_indices)

    conditional_entropies_N = compute_conditional_entropy(log_probs_N_K_C)

    batch_joint_entropy = joint_entropy.DynamicJointEntropy(
        num_samples, batch_size - 1, K, C, dtype=dtype, device=device
    )

    # We always keep these on the CPU.
    scores_N = torch.empty(N, dtype=torch.double, pin_memory=torch.cuda.is_available())

    for i in tqdm(range(batch_size), desc="BatchBALD", leave=False):
        if i > 0:
            latest_index = candidate_indices[-1]
            batch_joint_entropy.add_variables(log_probs_N_K_C[latest_index : latest_index + 1])

        # Sum of H[y_b|w] over the points already in the batch.
        shared_conditional_entropies = conditional_entropies_N[candidate_indices].sum()

        batch_joint_entropy.compute_batch(log_probs_N_K_C, output_entropies_B=scores_N)

        # Score: H[(y_b)_{B-1}, y_i] - sum_{b<B} H[y_b|w] - H[y_i|w].
        scores_N -= conditional_entropies_N + shared_conditional_entropies
        scores_N[candidate_indices] = -float("inf")

        candidate_score, candidate_index = scores_N.max(dim=0)

        candidate_indices.append(candidate_index.item())
        candidate_scores.append(candidate_score.item())

    return CandidateBatch(candidate_scores, candidate_indices)

Example

get_batchbald_batch(ys_ws.log().double(), 4, 1000, dtype=torch.double)
CandidateBatch(scores=[0.030715639666234917, 0.05961958627158248, 0.0869107051474467, 0.11275304532467878], indices=[1, 0, 2, 3])
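
Note that the returned scores are cumulative: the $b$-th score is the joint mutual information $I[y_1, \dots, y_b; w]$ after the $b$-th greedy addition, which is why they increase along the batch.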

BALD

BALD is the same as BatchBALD, except that we score points individually, computing $I[y_i;w] = H[y_i] - H[y_i|w]$ for each, and then take the top-$B$ scorers.
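
For a single pool element, this score is straightforward to compute by hand; a quick sketch reusing the test distributions from above (the tensor names are local to this example):

probs_K_C = ys_ws[0]  # K sampled class distributions for one pool element
mean_probs_C = probs_K_C.mean(dim=0)  # marginal predictive distribution
entropy = -(mean_probs_C * mean_probs_C.log()).sum()  # H[y_i]
conditional_entropy = -(probs_K_C * probs_K_C.log()).sum() / K  # H[y_i|w]
print(entropy - conditional_entropy)  # I[y_i;w] ≈ 0.0307, matching the output below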

get_bald_batch[source]

get_bald_batch(log_probs_N_K_C:Tensor, batch_size:int, dtype=None, device=None)

def get_bald_batch(log_probs_N_K_C: torch.Tensor, batch_size: int, dtype=None, device=None) -> CandidateBatch:
    N, K, C = log_probs_N_K_C.shape

    batch_size = min(batch_size, N)

    candidate_indices = []
    candidate_scores = []

    scores_N = -compute_conditional_entropy(log_probs_N_K_C)
    scores_N += compute_entropy(log_probs_N_K_C)

    candidate_scores, candidate_indices = torch.topk(scores_N, batch_size)

    return CandidateBatch(candidate_scores.tolist(), candidate_indices.tolist())

Example

get_bald_batch(ys_ws.log().double(), 4)
CandidateBatch(scores=[0.030715639666234917, 0.030715639666234917, 0.030715639666234695, 0.030715639666234695], indices=[1, 3, 2, 0])
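
All four BALD scores are (numerically) identical here because the four test distributions are permutations of each other. Note also that the top BALD score matches the first BatchBALD score above: for a batch of size one, the two acquisition functions coincide.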