Python PySpark: How to Split a PySpark DataFrame into Equal Number of Rows

When working with large PySpark DataFrames, there are scenarios where you need to split the data into smaller, equally-sized chunks - for example, to process each chunk in parallel, batch data for API calls, write data in manageable partitions, or distribute workloads across workers. If the operation on each chunk is independent of other rows, splitting and processing in parallel can significantly improve efficiency.

In this guide, you will learn how to split a PySpark DataFrame into equal parts using limit() and subtract(), a more reliable approach using monotonically_increasing_id(), and how to process each split independently before combining results.

Setting Up the Example DataFrame

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("split_demo").getOrCreate()

columns = ["Brand", "Product"]

data = [
("HP", "Laptop"),
("Lenovo", "Mouse"),
("Dell", "Keyboard"),
("Samsung", "Monitor"),
("MSI", "Graphics Card"),
("Asus", "Motherboard"),
("Gigabyte", "Motherboard"),
("Zebronics", "Cabinet"),
("Adata", "RAM"),
("Transcend", "SSD"),
("Kingston", "HDD"),
("Toshiba", "DVD Writer"),
]

df = spark.createDataFrame(data=data, schema=columns)
df.show()

Output:

+---------+-------------+
| Brand| Product|
+---------+-------------+
| HP| Laptop|
| Lenovo| Mouse|
| Dell| Keyboard|
| Samsung| Monitor|
| MSI|Graphics Card|
| Asus| Motherboard|
| Gigabyte| Motherboard|
|Zebronics| Cabinet|
| Adata| RAM|
|Transcend| SSD|
| Kingston| HDD|
| Toshiba| DVD Writer|
+---------+-------------+

The DataFrame has 12 rows and 2 columns. We will split it into equal chunks.

Method 1: Using limit() and subtract()

The simplest approach uses limit() to take the first n rows and subtract() to remove those rows from the remaining data, repeating until all rows are distributed.

def split_dataframe_limit(df, n_splits):
    """Split a DataFrame into n equal parts using limit() and subtract()."""
    each_len = df.count() // n_splits
    remaining_df = df
    splits = []

    for i in range(n_splits):
        if i == n_splits - 1:
            # Last split gets all remaining rows (handles remainder)
            splits.append(remaining_df)
        else:
            chunk = remaining_df.limit(each_len)
            remaining_df = remaining_df.subtract(chunk)
            splits.append(chunk)

    return splits

# Split into 4 equal parts
chunks = split_dataframe_limit(df, n_splits=4)

for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1} ({chunk.count()} rows):")
    chunk.show(truncate=False)

Output:

Chunk 1 (3 rows):
+---------+-------------+
|Brand |Product |
+---------+-------------+
|HP |Laptop |
|Dell |Keyboard |
|Lenovo |Mouse |
+---------+-------------+

Chunk 2 (3 rows):
+-------+-------------+
|Brand |Product |
+-------+-------------+
|Samsung|Monitor |
|Asus |Motherboard |
|MSI |Graphics Card|
+-------+-------------+

Chunk 3 (3 rows):
+---------+-----------+
|Brand |Product |
+---------+-----------+
|Gigabyte |Motherboard|
|Zebronics|Cabinet |
|Adata |RAM |
+---------+-----------+

Chunk 4 (3 rows):
+---------+----------+
|Brand |Product |
+---------+----------+
|Transcend|SSD |
|Kingston |HDD |
|Toshiba |DVD Writer|
+---------+----------+
caution

The limit() + subtract() approach has significant drawbacks:

  • subtract() is expensive - it performs a full shuffle and set difference operation on every iteration.
  • Row order is not guaranteed - subtract() may reorder rows unpredictably.
  • Non-deterministic - if the DataFrame has duplicate rows, subtract() may remove more rows than intended because it performs set-based subtraction.

For production workloads, prefer the monotonically_increasing_id() approach described below.
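To see why duplicate rows are a problem, here is a plain-Python analogy (not PySpark): subtract() follows SQL EXCEPT semantics, which are set-based, so every copy of a matching row disappears in a single subtraction.

```python
# Plain-Python analogy for subtract()'s set-based (SQL EXCEPT) semantics.
# Two copies of ("HP", "Laptop") exist, but one subtraction removes both.
rows = [("HP", "Laptop"), ("HP", "Laptop"), ("Dell", "Keyboard")]
chunk = [("HP", "Laptop")]

remaining = set(rows) - set(chunk)
print(remaining)  # {("Dell", "Keyboard")} - both HP rows are gone
```

A chunk that was supposed to take one of the duplicate rows silently removes all of them from the remainder, so the splits no longer add up to the original row count.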

Method 2: Using row_number() and monotonically_increasing_id()

A more reliable and efficient approach assigns a sequential row index to each row, then filters on contiguous index ranges so that each row lands in exactly one split.

from pyspark.sql.functions import monotonically_increasing_id, row_number
from pyspark.sql.window import Window

def split_dataframe(df, n_splits):
    """Split a DataFrame into n roughly equal parts using row numbering."""
    # Add a sequential row number
    df_with_index = df.withColumn(
        "row_idx",
        row_number().over(Window.orderBy(monotonically_increasing_id())) - 1
    )

    total_rows = df_with_index.count()
    each_len = total_rows // n_splits

    splits = []
    for i in range(n_splits):
        start = i * each_len
        if i == n_splits - 1:
            # Last split gets all remaining rows
            chunk = df_with_index.filter(df_with_index.row_idx >= start)
        else:
            end = start + each_len
            chunk = df_with_index.filter(
                (df_with_index.row_idx >= start) &
                (df_with_index.row_idx < end)
            )
        # Drop the helper column
        splits.append(chunk.drop("row_idx"))

    return splits

# Split into 4 parts
chunks = split_dataframe(df, n_splits=4)

for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1} ({chunk.count()} rows):")
    chunk.show(truncate=False)

Output:

Chunk 1 (3 rows):
+------+--------+
|Brand |Product |
+------+--------+
|HP |Laptop |
|Lenovo|Mouse |
|Dell |Keyboard|
+------+--------+

Chunk 2 (3 rows):
+-------+-------------+
|Brand |Product |
+-------+-------------+
|Samsung|Monitor |
|MSI |Graphics Card|
|Asus |Motherboard |
+-------+-------------+

Chunk 3 (3 rows):
+---------+-----------+
|Brand |Product |
+---------+-----------+
|Gigabyte |Motherboard|
|Zebronics|Cabinet |
|Adata |RAM |
+---------+-----------+

Chunk 4 (3 rows):
+---------+----------+
|Brand |Product |
+---------+----------+
|Transcend|SSD |
|Kingston |HDD |
|Toshiba |DVD Writer|
+---------+----------+
tip

This method is deterministic and handles duplicates correctly. The row numbering ensures each row is assigned to exactly one split, regardless of duplicate values in the data.
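To make that guarantee concrete, here is a plain-Python sketch of the assignment rule split_dataframe implements: each zero-based row index falls into exactly one contiguous range, with the last split absorbing the remainder.

```python
# Assignment rule behind split_dataframe: row index i goes to chunk
# i // each_len, capped at the last chunk (which absorbs any remainder).
total_rows, n_splits = 12, 4
each_len = total_rows // n_splits  # 3

assignments = [min(i // each_len, n_splits - 1) for i in range(total_rows)]
print(assignments)  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Because the rule depends only on the row index, two identical rows still get distinct indices and therefore land in well-defined (possibly different) chunks, and every row is assigned exactly once.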

Handling Uneven Splits

When the total row count is not evenly divisible by the number of splits, you need to decide what happens to the remainder rows. Both methods above assign remaining rows to the last chunk.

# 12 rows split into 5 parts: 2 + 2 + 2 + 2 + 4
chunks = split_dataframe(df, n_splits=5)

for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1}: {chunk.count()} rows")

Output:

Chunk 1: 2 rows
Chunk 2: 2 rows
Chunk 3: 2 rows
Chunk 4: 2 rows
Chunk 5: 4 rows
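The sizes follow directly from integer division; a quick sanity check in plain Python:

```python
# 12 rows, 5 splits: each non-final chunk gets total // n rows,
# and the last chunk takes whatever is left over.
total_rows, n_splits = 12, 5
each_len = total_rows // n_splits                   # 2
last_len = total_rows - each_len * (n_splits - 1)   # 12 - 8 = 4

sizes = [each_len] * (n_splits - 1) + [last_len]
print(sizes)  # [2, 2, 2, 2, 4]
```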

For a more balanced distribution of remainder rows, you can modify the function to spread extra rows across the first few chunks:

def split_dataframe_balanced(df, n_splits):
    """Split DataFrame into n parts with balanced distribution of remainder rows."""
    df_with_index = df.withColumn(
        "row_idx",
        row_number().over(Window.orderBy(monotonically_increasing_id())) - 1
    )

    total_rows = df_with_index.count()
    base_size = total_rows // n_splits
    remainder = total_rows % n_splits

    splits = []
    start = 0
    for i in range(n_splits):
        # Distribute one extra row to the first 'remainder' chunks
        chunk_size = base_size + (1 if i < remainder else 0)
        end = start + chunk_size
        chunk = df_with_index.filter(
            (df_with_index.row_idx >= start) &
            (df_with_index.row_idx < end)
        ).drop("row_idx")
        splits.append(chunk)
        start = end

    return splits

# 12 rows into 5 balanced splits: 3 + 3 + 2 + 2 + 2
chunks = split_dataframe_balanced(df, n_splits=5)

for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1}: {chunk.count()} rows")

Output:

Chunk 1: 3 rows
Chunk 2: 3 rows
Chunk 3: 2 rows
Chunk 4: 2 rows
Chunk 5: 2 rows
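The balanced sizes come straight from divmod: the remainder tells you how many leading chunks receive one extra row.

```python
# Balanced sizing: every chunk gets the base size, and the first
# `remainder` chunks each get one extra row.
total_rows, n_splits = 12, 5
base_size, remainder = divmod(total_rows, n_splits)  # (2, 2)

sizes = [base_size + (1 if i < remainder else 0) for i in range(n_splits)]
print(sizes)  # [3, 3, 2, 2, 2]
```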

Processing Each Split and Combining Results

A common pattern is to split a DataFrame, apply a transformation to each chunk independently, and then combine the results using union().

from pyspark.sql.functions import concat, col, lit

# Split into 4 chunks
chunks = split_dataframe(df, n_splits=4)

# Define the transformation function
def transform_chunk(chunk_df):
    """Concatenate Brand and Product into a single column."""
    return chunk_df.select(
        concat(col("Brand"), lit(" - "), col("Product")).alias("Brand_Product")
    )

# Process each chunk and combine results
result_df = None
for i, chunk in enumerate(chunks):
    transformed = transform_chunk(chunk)
    print(f"Processed Chunk {i + 1}:")
    transformed.show(truncate=False)

    if result_df is None:
        result_df = transformed
    else:
        result_df = result_df.union(transformed)

print("Combined Result:")
result_df.show(truncate=False)

Output:

Processed Chunk 1:
+---------------+
|Brand_Product  |
+---------------+
|HP - Laptop    |
|Lenovo - Mouse |
|Dell - Keyboard|
+---------------+

...

Combined Result:
+----------------------+
|Brand_Product         |
+----------------------+
|HP - Laptop           |
|Lenovo - Mouse        |
|Dell - Keyboard       |
|Samsung - Monitor     |
|MSI - Graphics Card   |
|Asus - Motherboard    |
|Gigabyte - Motherboard|
|Zebronics - Cabinet   |
|Adata - RAM           |
|Transcend - SSD       |
|Kingston - HDD        |
|Toshiba - DVD Writer  |
+----------------------+
info

Since every transformed chunk shares the same schema, you can combine them more concisely with functools.reduce:

from functools import reduce
from pyspark.sql import DataFrame

result_df = reduce(
    DataFrame.union,
    [transform_chunk(chunk) for chunk in chunks]
)

Alternative: Using randomSplit()

PySpark provides a built-in randomSplit() method that splits a DataFrame by approximate proportions. While not guaranteed to produce exact split sizes, it is simpler and avoids the overhead of row numbering.

# Split into approximately 4 equal parts (25% each)
chunks = df.randomSplit([0.25, 0.25, 0.25, 0.25], seed=42)

for i, chunk in enumerate(chunks):
    print(f"Chunk {i + 1}: {chunk.count()} rows")

Output (approximate):

Chunk 1: 2 rows
Chunk 2: 4 rows
Chunk 3: 3 rows
Chunk 4: 3 rows
caution

randomSplit() does not guarantee exact split sizes - the results are approximate and may vary between runs unless you set a seed. Use this method when approximate equality is acceptable. For exact splits, use the row-numbering approach.
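One convenience worth knowing: per the PySpark documentation, randomSplit() normalizes weights that do not sum to 1.0, so you can pass relative weights like [1, 1, 1, 1] instead of fractions. The normalization itself is just division by the total:

```python
# randomSplit normalizes weights, so [1, 1, 1, 1] behaves the same as
# [0.25, 0.25, 0.25, 0.25]. The normalization is w / sum(weights).
weights = [1, 1, 1, 1]
total = sum(weights)
fractions = [w / total for w in weights]
print(fractions)  # [0.25, 0.25, 0.25, 0.25]
```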

Comparison of Methods

| Method                  | Exact Splits   | Deterministic | Performance        | Handles Duplicates |
|-------------------------|----------------|---------------|--------------------|--------------------|
| limit() + subtract()    | ✅ Yes         | ❌ No         | ❌ Slow (shuffles) | ❌ Problematic     |
| row_number() + filter() | ✅ Yes         | ✅ Yes        | ✅ Good            | ✅ Yes             |
| randomSplit()           | ❌ Approximate | ✅ With seed  | ✅ Good            | ✅ Yes             |

Summary

Splitting a PySpark DataFrame into equal parts is useful for batch processing, parallel execution, and managing large datasets. Key takeaways:

  • Use row_number() with monotonically_increasing_id() for reliable, deterministic, and exact splits - this is the recommended approach.
  • Use randomSplit() when approximate split sizes are acceptable and simplicity is preferred.
  • Avoid limit() + subtract() in production due to performance overhead and non-deterministic behavior.
  • Handle remainder rows by either assigning them to the last chunk or distributing them across the first few chunks for balanced splits.
  • Use union() or functools.reduce to combine processed chunks back into a single DataFrame.