How to Implement a Balanced Batch Generator in Python
In machine learning, training a model on an imbalanced dataset, where one class has significantly more samples than others, can lead to biased predictions. A common strategy to mitigate this is to create a data pipeline that dynamically generates batches with a balanced distribution of classes.
This guide explains how to implement a custom batch generator in Python that performs upsampling and downsampling on the fly to ensure equal class representation in every training step.
Understanding the Strategy
To balance a batch of size N across K unique classes:
- Group the dataset samples by their labels.
- Calculate quota: Determine how many samples (Q) each class should contribute to the batch, where Q = N // K.
- Sample: Randomly select Q samples from each class.
  - If a class has more samples than Q, we downsample (pick a subset).
  - If a class has fewer samples (and we allow replacement), we upsample (repeat samples).
- Fill remainder: If N isn't perfectly divisible by K, randomly fill the remaining slots (a worked example follows this list).
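As a quick sanity check of the quota arithmetic, here is a minimal sketch for a hypothetical case where the division isn't exact (the values N = 10 and K = 3 are made up for illustration):
N = 10               # batch size
K = 3                # number of classes
Q = N // K           # 3 samples per class
remainder = N % K    # 1 leftover slot, filled from a random class
print(Q, remainder)  # 3 1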
Step 1: Grouping Data by Label
First, we need to organize our raw data so that we can access all samples belonging to a specific class instantly. We use Python's collections.defaultdict.
from collections import defaultdict
# Raw data format: List of [feature_vector, one_hot_label]
# Example: [ [[1, 2], [1, 0]], [[3, 4], [0, 1]] ]
def group_data_by_label(data):
    counter = defaultdict(list)
    for x, y in data:
        # ✅ Correct: Convert list label to tuple so it can be a dictionary key
        counter[tuple(y)].append(x)
    return counter
Lists are mutable and therefore unhashable, so they cannot be used as dictionary keys in Python. We must convert the one-hot label [1, 0] to a tuple (1, 0) before using it as a key.
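A quick usage sketch shows the resulting structure (the two-sample dataset below is hypothetical):
grouped = group_data_by_label([
    [[1, 2], [1, 0]],
    [[3, 4], [0, 1]],
])
print(dict(grouped))
# {(1, 0): [[1, 2]], (0, 1): [[3, 4]]}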
Step 2: Calculating Batch Distribution
Given a batch size, we calculate how many samples each class must provide (pre_num) and how many slots remain (num_left) if the division isn't exact.
batch_size = 6
num_classes = 2
# Integer division for equal distribution
pre_num = batch_size // num_classes # 3
# Modulo for the remainder
num_left = batch_size % num_classes # 0
Step 3: Implementing the Sampling Logic
We iterate through our grouped data and select samples.
The implementation below uses random.sample. This function samples without replacement and requires the population to be at least as large as the sample size; otherwise it raises ValueError. For true upsampling, where a class may be smaller than the batch quota, use random.choices (with replacement) instead.
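The difference is easy to demonstrate in isolation (the three-element population below is made up):
import random
population = ["a", "b", "c"]
print(random.sample(population, 2))     # e.g. ['c', 'a'] -- unique picks, no replacement
print(random.choices(population, k=5))  # e.g. ['b', 'b', 'a', 'c', 'b'] -- duplicates allowed
# random.sample(population, 5) would raise ValueError: sample larger than population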
import random
def sample_balanced_batch(counter, pre_num, num_left):
    batch_data = []

    # 1. Fill the main quota for each class
    for label_tuple, features_list in counter.items():
        # Select 'pre_num' samples for this class
        # Use random.sample for downsampling (unique samples)
        # Use random.choices for upsampling (allows duplicates)
        samples = random.sample(features_list, pre_num)

        # Add to batch with the reconstructed label
        batch_data.extend([[sample, list(label_tuple)] for sample in samples])

    # 2. Fill the remainder (num_left) randomly
    all_labels = list(counter.keys())
    for _ in range(num_left):
        # Pick a random class
        y = random.choice(all_labels)
        # Pick a random sample from that class
        x = random.choice(counter[y])
        batch_data.append([x, list(y)])

    return batch_data
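Combined with group_data_by_label from Step 1, a minimal end-to-end usage might look like this (the six-sample dataset is hypothetical, and each class holds at least pre_num samples so random.sample succeeds):
data = [
    [[1, 2], [1, 0]], [[3, 4], [1, 0]], [[5, 6], [1, 0]],
    [[7, 8], [0, 1]], [[9, 0], [0, 1]], [[2, 3], [0, 1]],
]
counter = group_data_by_label(data)
batch = sample_balanced_batch(counter, pre_num=3, num_left=0)
print(len(batch))  # 6 -- three samples from each of the two classes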
Complete Code Implementation
Here is the complete, runnable pipeline combining all steps.
import random
from collections import defaultdict
def unbalanced_data_pipeline(data, batch_size):
    """
    Generates a batch of data with a balanced class distribution.
    """
    # 1. Group data by label
    counter = defaultdict(list)
    for x, y in data:
        counter[tuple(y)].append(x)

    # 2. Calculate distribution
    num_classes = len(counter)
    pre_num = batch_size // num_classes
    num_left = batch_size % num_classes

    batch_data = []

    # 3. Sample for each class
    for label_tuple, features_list in counter.items():
        if len(features_list) >= pre_num:
            # Enough samples: downsample without replacement
            samples = random.sample(features_list, pre_num)
        else:
            # Too few samples: upsample with replacement
            samples = random.choices(features_list, k=pre_num)
        batch_data.extend([[sample, list(label_tuple)] for sample in samples])

    # 4. Handle remainder
    keys_list = list(counter.keys())
    for _ in range(num_left):
        y = random.choice(keys_list)
        x = random.choice(counter[y])
        batch_data.append([x, list(y)])

    return batch_data
if __name__ == "__main__":
    # Test data: imbalanced (3 samples of class [1, 0], 9 samples of class [0, 1])
    data = [
        [[1, 2, 5], [1, 0]], [[1, 6, 0], [1, 0]], [[4, 1, 8], [1, 0]],  # Class A
        [[7, 0, 4], [0, 1]], [[5, 9, 4], [0, 1]], [[2, 0, 1], [0, 1]],  # Class B
        [[1, 9, 3], [0, 1]], [[5, 5, 5], [0, 1]], [[8, 4, 0], [0, 1]],
        [[9, 6, 3], [0, 1]], [[7, 7, 0], [0, 1]], [[0, 3, 4], [0, 1]],
    ]

    print("Generating balanced batch of size 6...")
    batch = unbalanced_data_pipeline(data, 6)

    # Verify balance
    class_a_count = sum(1 for item in batch if item[1] == [1, 0])
    class_b_count = sum(1 for item in batch if item[1] == [0, 1])

    print(f"Batch content: {batch}")
    print(f"Class [1, 0] count: {class_a_count}")
    print(f"Class [0, 1] count: {class_b_count}")
Output:
Generating balanced batch of size 6...
Batch content: [[[1, 2, 5], [1, 0]], [[4, 1, 8], [1, 0]], [[1, 6, 0], [1, 0]], [[5, 5, 5], [0, 1]], [[1, 9, 3], [0, 1]], [[8, 4, 0], [0, 1]]]
Class [1, 0] count: 3
Class [0, 1] count: 3
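To feed an actual training loop, the pipeline can be wrapped in a plain Python generator that yields one freshly balanced batch per step. This is a minimal sketch (the name balanced_batch_generator and the three-step loop are illustrative, not part of the code above):
def balanced_batch_generator(data, batch_size):
    """Yield balanced batches indefinitely, one per training step."""
    while True:
        yield unbalanced_data_pipeline(data, batch_size)

gen = balanced_batch_generator(data, batch_size=6)
for step in range(3):  # e.g. three training steps
    batch = next(gen)
    print(f"Step {step}: got a batch of {len(batch)} samples")
Note that this re-groups the data on every call; for large datasets you could hoist the grouping out of the loop and sample from the precomputed buckets instead.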
Conclusion
By manually controlling the batch generation process, you can ensure that your machine learning model sees a balanced representation of all classes during training.
- Map your data to separate buckets based on labels.
- Calculate the exact number of samples needed per class for your batch size.
- Sample randomly from these buckets to fill the batch.
This technique is a straightforward and effective way to handle imbalanced datasets without needing external libraries.