How to Divide Datasets Into Mini-Batches in Python
In machine learning and deep learning, processing an entire dataset at once is often impossible due to memory constraints. Instead, we use mini-batches: small, manageable chunks of data that are processed sequentially.
This guide explains how to implement a memory-efficient data pipeline in Python using generators to slice a dataset into mini-batches.
Understanding Generators and Slicing
To divide a list into chunks, we rely on two Python features:
- List Slicing: data[start:end] extracts a specific portion of the data. Python handles out-of-bounds indices gracefully (e.g., if you ask for 5 items but only 3 remain, it simply returns the 3).
- Generators (yield): Instead of building a list of lists containing every batch (which keeps the full batch structure in memory alongside the original data), we use yield to return one batch at a time.
Using range(start, stop, step) lets us jump through the list indices in steps of the batch size.
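To see both mechanics in isolation, here is a minimal, standalone sketch (the sample values are illustrative and not part of the pipeline we build below):
items = [10, 20, 30, 40, 50]   # 5 items
batch_size = 2

# range(start, stop, step) produces the start index of each batch
print(list(range(0, len(items), batch_size)))  # [0, 2, 4]

# Slicing past the end is safe: only the remaining items are returned
print(items[4 : 4 + batch_size])               # [50], no IndexError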
Step 1: Implementing the Data Pipeline
We will create a function data_pipeline that accepts raw data and a batch size.
Memory Inefficient vs. Efficient Approach
# ⛔️ Inefficient: building a list of all batches at once
# Every batch is materialized up front, so the full batch structure
# is held in memory alongside the original data
def bad_pipeline(data, batch_size):
    batches = []
    for i in range(0, len(data), batch_size):
        batches.append(data[i : i + batch_size])
    return batches
# ✅ Correct: Using a Generator to yield batches on demand
from typing import Generator, List
def data_pipeline(data: List[List[int]], batch_size: int) -> Generator[List[List[int]], None, None]:
"""
Yields mini-batches from the dataset sequentially.
"""
# Iterate from 0 to len(data) with steps of 'batch_size'
for i in range(0, len(data), batch_size):
# Slice the list.
# If i + batch_size > len(data), Python returns the remaining items.
batch = data[i : i + batch_size]
yield batch
Step 2: Testing the Mini-Batch Generator
To verify the logic, we will iterate through the generator. We expect the final batch to contain fewer items if the total count isn't perfectly divisible by the batch size.
if __name__ == "__main__":
    # Sample dataset: 5 items
    dataset = [
        [1, 2],
        [1, 3],
        [3, 5],
        [2, 1],
        [3, 3]
    ]
    BATCH_SIZE = 2

    print(f"Total items: {len(dataset)}")
    print(f"Batch size: {BATCH_SIZE}")
    print("-" * 20)

    # Initialize the generator (no batches are created yet)
    batch_generator = data_pipeline(dataset, BATCH_SIZE)

    # Consume the generator batch by batch
    for i, batch in enumerate(batch_generator):
        print(f"Batch {i+1}: {batch}")
Output Validation:
Total items: 5
Batch size: 2
--------------------
Batch 1: [[1, 2], [1, 3]]
Batch 2: [[3, 5], [2, 1]]
Batch 3: [[3, 3]]
Notice Batch 3. The generator correctly handled the "remainder." Even though we requested a batch size of 2, only 1 item was left, so it returned just that item without crashing or raising an IndexError.
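Because data_pipeline is a generator, a batch only exists while you are using it; nothing is precomputed. As a quick sketch, you can also pull batches manually with next(), reusing the dataset and BATCH_SIZE defined above:
gen = data_pipeline(dataset, BATCH_SIZE)
print(next(gen))  # [[1, 2], [1, 3]]
print(next(gen))  # [[3, 5], [2, 1]]
print(next(gen))  # [[3, 3]] -- the remainder batch
# Calling next(gen) again would raise StopIteration, signalling exhaustion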
Conclusion
Creating a mini-batch generator is a fundamental skill for data processing.
- Use range(0, len(data), batch_size) to calculate the start indices.
- Use data[i : i + batch_size] to slice the data safely.
- Use yield to create a memory-efficient iterator.
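If you only need the batching logic inline, the same idea can be written as a generator expression; this is an equivalent sketch of data_pipeline, not a different technique:
batch_generator = (
    dataset[i : i + BATCH_SIZE] for i in range(0, len(dataset), BATCH_SIZE)
)
for batch in batch_generator:
    print(batch)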