
Python PySpark: How to Slice a PySpark DataFrame into Two Row-Wise DataFrames

Splitting a PySpark DataFrame into two smaller DataFrames by rows is a common operation in data processing - whether you need to create training and test sets, separate data for parallel processing, or paginate large datasets. Unlike Pandas, PySpark DataFrames don't support direct row indexing, so slicing requires different approaches.

This guide covers four methods to split a PySpark DataFrame row-wise, each suited to different use cases.

Sample DataFrame

All examples use the following PySpark DataFrame:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SlicingDemo').getOrCreate()

rows = [
    ['James Anderson', 69, 'New York'],
    ['Emily Chen', 66, 'Chicago'],
    ['Michael Davis', 9, 'Houston'],
    ['Sarah Martinez', 15, 'Phoenix']
]

columns = ['Player', 'Titles', 'City']
df = spark.createDataFrame(rows, columns)
df.show()

Output:

+--------------+------+--------+
| Player|Titles| City|
+--------------+------+--------+
|James Anderson| 69|New York|
| Emily Chen| 66| Chicago|
| Michael Davis| 9| Houston|
|Sarah Martinez| 15| Phoenix|
+--------------+------+--------+

Method 1: Using limit() and subtract()

This approach takes the first N rows with limit() and gets the remaining rows by subtracting the first slice from the original DataFrame.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SlicingDemo').getOrCreate()

rows = [
    ['James Anderson', 69, 'New York'],
    ['Emily Chen', 66, 'Chicago'],
    ['Michael Davis', 9, 'Houston'],
    ['Sarah Martinez', 15, 'Phoenix']
]

columns = ['Player', 'Titles', 'City']
df = spark.createDataFrame(rows, columns)

# First slice: top 3 rows
df1 = df.limit(3)

# Second slice: everything not in the first slice
df2 = df.subtract(df1)

print("First slice:")
df1.show()

print("Second slice:")
df2.show()

Output:

First slice:
+--------------+------+--------+
| Player|Titles| City|
+--------------+------+--------+
|James Anderson| 69|New York|
| Emily Chen| 66| Chicago|
| Michael Davis| 9| Houston|
+--------------+------+--------+

Second slice:
+--------------+------+-------+
| Player|Titles| City|
+--------------+------+-------+
|Sarah Martinez| 15|Phoenix|
+--------------+------+-------+
warning

subtract() performs a set difference operation, which removes duplicate rows. If your DataFrame contains duplicate rows, some may be unintentionally removed from the second slice. Additionally, subtract() does not preserve row order. Use this method only when rows are unique and order doesn't matter.

Method 2: Using randomSplit()

The randomSplit() method splits a DataFrame into multiple parts based on specified weight ratios. Rows are distributed randomly, making this ideal for creating training and test datasets.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SlicingDemo').getOrCreate()

rows = [
    ['James Anderson', 69, 'New York'],
    ['Emily Chen', 66, 'Chicago'],
    ['Michael Davis', 9, 'Houston'],
    ['Sarah Martinez', 15, 'Phoenix']
]

columns = ['Player', 'Titles', 'City']
df = spark.createDataFrame(rows, columns)

# Split: approximately 70% and 30%
df1, df2 = df.randomSplit([0.7, 0.3], seed=42)

print(f"First slice ({df1.count()} rows):")
df1.show()

print(f"Second slice ({df2.count()} rows):")
df2.show()

Output (reproducible here because seed=42 is fixed; without a seed, the split varies by run):

First slice (3 rows):
+--------------+------+-------+
| Player|Titles| City|
+--------------+------+-------+
| Emily Chen| 66|Chicago|
| Michael Davis| 9|Houston|
|Sarah Martinez| 15|Phoenix|
+--------------+------+-------+

Second slice (1 rows):
+--------------+------+--------+
| Player|Titles| City|
+--------------+------+--------+
|James Anderson| 69|New York|
+--------------+------+--------+
tip

The seed parameter ensures reproducible splits. Always set a seed when you need consistent results across runs, such as in machine learning experiments:

df1, df2 = df.randomSplit([0.8, 0.2], seed=42)
info

The weights are approximate - the actual number of rows in each split may not exactly match the specified ratio, especially with small datasets. The ratios become more accurate with larger DataFrames.

Method 3: Using collect() and List Slicing

This method converts the DataFrame to a list of Row objects, slices the list using Python's standard slicing, and converts the slices back to PySpark DataFrames.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SlicingDemo').getOrCreate()

rows = [
    ['James Anderson', 69, 'New York'],
    ['Emily Chen', 66, 'Chicago'],
    ['Michael Davis', 9, 'Houston'],
    ['Sarah Martinez', 15, 'Phoenix']
]

columns = ['Player', 'Titles', 'City']
df = spark.createDataFrame(rows, columns)

# Collect all rows into a Python list
row_list = df.collect()

# Slice the list
part1 = row_list[:1] # First row
part2 = row_list[1:] # Remaining rows

# Convert slices back to PySpark DataFrames
slice1 = spark.createDataFrame(part1, columns)
slice2 = spark.createDataFrame(part2, columns)

print("First slice:")
slice1.show()

print("Second slice:")
slice2.show()

Output:

First slice:
+--------------+------+--------+
| Player|Titles| City|
+--------------+------+--------+
|James Anderson| 69|New York|
+--------------+------+--------+

Second slice:
+--------------+------+-------+
| Player|Titles| City|
+--------------+------+-------+
| Emily Chen| 66|Chicago|
| Michael Davis| 9|Houston|
|Sarah Martinez| 15|Phoenix|
+--------------+------+-------+

This approach gives you exact control over which rows go into each slice based on their position.

danger

collect() loads the entire DataFrame into the driver's memory. For large datasets (millions of rows), this can cause OutOfMemoryError and crash your application. Use this method only for small DataFrames.

Method 4: Converting to Pandas and Using iloc[]

This method converts the PySpark DataFrame to a Pandas DataFrame, uses Pandas' iloc[] for precise row slicing, and converts the results back to PySpark.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('SlicingDemo').getOrCreate()

rows = [
    ['James Anderson', 69, 'New York'],
    ['Emily Chen', 66, 'Chicago'],
    ['Michael Davis', 9, 'Houston'],
    ['Sarah Martinez', 15, 'Phoenix']
]

columns = ['Player', 'Titles', 'City']
df = spark.createDataFrame(rows, columns)

# Convert to Pandas
pandas_df = df.toPandas()

# Slice using iloc
pdf1 = pandas_df.iloc[:2] # First 2 rows
pdf2 = pandas_df.iloc[2:] # Last 2 rows

# Convert back to PySpark
df1 = spark.createDataFrame(pdf1)
df2 = spark.createDataFrame(pdf2)

print("First slice:")
df1.show()

print("Second slice:")
df2.show()

Output:

First slice:
+--------------+------+--------+
| Player|Titles| City|
+--------------+------+--------+
|James Anderson| 69|New York|
| Emily Chen| 66| Chicago|
+--------------+------+--------+

Second slice:
+--------------+------+-------+
| Player|Titles| City|
+--------------+------+-------+
| Michael Davis| 9|Houston|
|Sarah Martinez| 15|Phoenix|
+--------------+------+-------+

This gives you the full power of Pandas slicing, including step values (iloc[::2]), negative indexing, and more.
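A few of those Pandas slicing features, shown in isolation (pure Pandas, no Spark session needed for the sketch; the sample data mirrors the DataFrame above):

```python
import pandas as pd

pdf = pd.DataFrame({
    'Player': ['James Anderson', 'Emily Chen', 'Michael Davis', 'Sarah Martinez'],
    'Titles': [69, 66, 9, 15],
    'City': ['New York', 'Chicago', 'Houston', 'Phoenix'],
})

print(pdf.iloc[::2])   # step slicing: positions 0 and 2
print(pdf.iloc[-2:])   # negative indexing: the last two rows
print(pdf.iloc[1:3])   # half-open range: positions 1 and 2
```

Any of these slices can be passed back through `spark.createDataFrame()` exactly as in the example above.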

warning

Like collect(), toPandas() loads the entire DataFrame into the driver's memory. This method is only suitable for DataFrames that fit comfortably in memory. For large-scale data, use limit() + subtract() or randomSplit().

Method Comparison

Method                     Preserves Order   Exact Split         Handles Duplicates        Large Data Safe
limit() + subtract()       No                Yes (by count)      No (removes duplicates)   Yes
randomSplit()              Yes               Approximate         Yes                       Yes
collect() + list slicing   Yes               Yes (by position)   Yes                       No (memory risk)
toPandas() + iloc[]        Yes               Yes (by position)   Yes                       No (memory risk)

When to Use Each Method

  • randomSplit(): Best for machine learning train/test splits where random distribution matters and exact counts are not critical.
  • limit() + subtract(): Best for distributed environments where you need the first N rows without collecting data to the driver. Be aware of duplicate handling.
  • collect() or toPandas(): Best for small DataFrames where you need precise positional slicing and full control over the split point.

Choose the method that fits your data size and whether you need deterministic, position-based slicing or approximate, distributed splitting.