
How to Divide Datasets Into Mini-Batches in Python

In machine learning and deep learning, processing an entire dataset at once is often impossible due to memory constraints. Instead, we use mini-batches: small, manageable chunks of data processed sequentially.

This guide explains how to implement a memory-efficient data pipeline in Python using generators to slice a dataset into mini-batches.

Understanding Generators and Slicing

To divide a list into chunks, we rely on two Python features:

  1. List Slicing: list[start:end] extracts a specific portion of the data. Python handles out-of-bounds indices gracefully (e.g., if you ask for 5 items but only 3 remain, it simply returns the 3); see the short example after the note below.
  2. Generators (yield): Instead of creating a list of lists containing all batches (which doubles memory usage), we use yield to return one batch at a time.
note

Using range(start, stop, step) allows us to jump through the list indices by the size of the batch.
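
A quick illustration of both behaviors, using nothing beyond plain Python lists:

items = [10, 20, 30, 40, 50]

# Slicing past the end never raises an error; it simply returns what is left.
print(items[3:7])   # [40, 50]

# range(start, stop, step) produces the start index of each chunk.
print(list(range(0, len(items), 2)))   # [0, 2, 4]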

Step 1: Implementing the Data Pipeline

We will create a function data_pipeline that accepts raw data and a batch size.

Memory Inefficient vs. Efficient Approach

# ⛔️ Inefficient: Creating a list of all batches at once
# This consumes 2x memory (Original Data + Batch Structure)
def bad_pipeline(data, batch_size):
    batches = []
    for i in range(0, len(data), batch_size):
        batches.append(data[i : i + batch_size])
    return batches

# ✅ Correct: Using a Generator to yield batches on demand
from typing import Generator, List

def data_pipeline(data: List[List[int]], batch_size: int) -> Generator[List[List[int]], None, None]:
    """
    Yields mini-batches from the dataset sequentially.
    """
    # Iterate from 0 to len(data) in steps of 'batch_size'
    for i in range(0, len(data), batch_size):
        # Slice the list.
        # If i + batch_size > len(data), Python returns the remaining items.
        batch = data[i : i + batch_size]
        yield batch

Step 2: Testing the Mini-Batch Generator

To verify the logic, we will iterate through the generator. We expect the final batch to contain fewer items if the total count isn't perfectly divisible by the batch size.

if __name__ == "__main__":
    # Sample dataset: 5 items
    dataset = [
        [1, 2],
        [1, 3],
        [3, 5],
        [2, 1],
        [3, 3]
    ]

    BATCH_SIZE = 2

    print(f"Total items: {len(dataset)}")
    print(f"Batch size: {BATCH_SIZE}")
    print("-" * 20)

    # Initialize the generator
    batch_generator = data_pipeline(dataset, BATCH_SIZE)

    # Consume the generator batch by batch
    for i, batch in enumerate(batch_generator):
        print(f"Batch {i+1}: {batch}")

Output Validation:

Total items: 5
Batch size: 2
--------------------
Batch 1: [[1, 2], [1, 3]]
Batch 2: [[3, 5], [2, 1]]
Batch 3: [[3, 3]]
tip

Notice Batch 3. The generator correctly handled the "remainder." Even though we requested a batch size of 2, only 1 item was left, so it returned just that item without crashing or raising an IndexError.
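
As a quick cross-check on that output, the expected batch count can be computed with a ceiling division. The sketch below uses only the standard math module and the same numbers as the test above:

import math

total_items = 5   # len(dataset) from the test above
BATCH_SIZE = 2

# Round up: the final, smaller batch still counts as one batch.
expected_batches = math.ceil(total_items / BATCH_SIZE)
print(expected_batches)  # 3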

Conclusion

Creating a mini-batch generator is a fundamental skill for data processing.

  1. Use range(0, len(data), batch_size) to calculate start indices.
  2. Use data[i : i + batch_size] to slice the data safely.
  3. Use yield to create a memory-efficient iterator.
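
For reference, the same three steps can be collapsed into a single generator expression. This is an equivalent sketch, not a replacement for the data_pipeline function documented above:

dataset = [[1, 2], [1, 3], [3, 5], [2, 1], [3, 3]]
BATCH_SIZE = 2

# Same slicing and stepping logic as data_pipeline, written inline
batches = (dataset[i : i + BATCH_SIZE] for i in range(0, len(dataset), BATCH_SIZE))

for batch in batches:
    print(batch)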