
How to Split a DataFrame by Column Value in PySpark

When working with large PySpark DataFrames, you often need to split the data into separate DataFrames based on the values in a specific column - for example, separating customers by region, filtering orders by status, or partitioning records by age group. PySpark provides two primary functions for this: filter() and where(), both of which let you apply conditions to extract subsets of your data.

In this guide, you will learn how to split a PySpark DataFrame by column value using both methods, along with advanced techniques for handling multiple splits, complex conditions, and practical patterns for real-world use cases.

Setting Up the Example DataFrame

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("split_demo").getOrCreate()

data = [
("James", 18, "Physics", 3, 10000),
("Emily", 21, "Chemistry", 2, 14000),
("Michael", 18, "Biology", 3, 12000),
("David", 23, "Maths", 4, 11000),
("Sarah", 21, "English", 2, 13000),
]

columns = ["Name", "Age", "Subject", "Class", "Fees"]
df = spark.createDataFrame(data, columns)
df.show()

Output:

+-------+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-------+---+---------+-----+-----+
| James| 18| Physics| 3|10000|
| Emily| 21|Chemistry| 2|14000|
|Michael| 18| Biology| 3|12000|
| David| 23| Maths| 4|11000|
| Sarah| 21| English| 2|13000|
+-------+---+---------+-----+-----+

Method 1: Using filter()

The filter() function returns a new DataFrame containing only the rows that satisfy the given condition. To split a DataFrame, apply the condition once to get matching rows and then negate it to get the remaining rows.

Basic Split by Column Value

# Rows where Age equals 18
df_age_18 = df.filter(df.Age == 18)

# Rows where Age does not equal 18
df_age_not_18 = df.filter(df.Age != 18)

print("Age == 18:")
df_age_18.show()

print("Age != 18:")
df_age_not_18.show()

Output:

Age == 18:
+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
+-------+---+-------+-----+-----+

Age != 18:
+-----+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-----+---+---------+-----+-----+
|Emily| 21|Chemistry| 2|14000|
|David| 23| Maths| 4|11000|
|Sarah| 21| English| 2|13000|
+-----+---+---------+-----+-----+

Using SQL Expression Syntax

You can also pass conditions as SQL expression strings:

df_age_18 = df.filter("Age = 18")
df_age_not_18 = df.filter("Age != 18")

df_age_18.show()

Output:

+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
+-------+---+-------+-----+-----+

Method 2: Using where()

The where() function is functionally identical to filter() - they are aliases of each other. You can use whichever reads more naturally in your code.

# Rows where Age equals 18
df_age_18 = df.where(df.Age == 18)

# Rows where Age does not equal 18
df_age_not_18 = df.where(df.Age != 18)

print("Age == 18:")
df_age_18.show()

print("Age != 18:")
df_age_not_18.show()

Output:

Age == 18:
+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
+-------+---+-------+-----+-----+

Age != 18:
+-----+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-----+---+---------+-----+-----+
|Emily| 21|Chemistry| 2|14000|
|David| 23| Maths| 4|11000|
|Sarah| 21| English| 2|13000|
+-----+---+---------+-----+-----+
info

filter() and where() produce identical results. The choice is purely stylistic - where() may feel more natural if you come from a SQL background, while filter() aligns with functional programming conventions.

Splitting by Multiple Conditions

You can combine multiple conditions using & (AND), | (OR), and ~ (NOT). Each condition must be wrapped in parentheses.

AND Condition

# Students who are 18 AND in Class 3
df_filtered = df.filter((df.Age == 18) & (df.Class == 3))
df_filtered.show()

Output:

+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
+-------+---+-------+-----+-----+

OR Condition

# Students who are 18 OR have Fees greater than 12000
df_filtered = df.filter((df.Age == 18) | (df.Fees > 12000))
df_filtered.show()

Output:

+-------+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-------+---+---------+-----+-----+
| James| 18| Physics| 3|10000|
| Emily| 21|Chemistry| 2|14000|
|Michael| 18| Biology| 3|12000|
| Sarah| 21| English| 2|13000|
+-------+---+---------+-----+-----+

NOT Condition

# Students who are NOT in Class 2
df_filtered = df.filter(~(df.Class == 2))
df_filtered.show()

Output:

+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
| David| 23| Maths| 4|11000|
+-------+---+-------+-----+-----+
caution

Always wrap each condition in parentheses when combining them with &, |, or ~. Without parentheses, Python's operator precedence can produce unexpected results or errors:

# Wrong: missing parentheses
df.filter(df.Age == 18 & df.Class == 3) # raises an error: & binds tighter than ==

# Correct: each condition wrapped in parentheses
df.filter((df.Age == 18) & (df.Class == 3))

Splitting into Multiple DataFrames by Distinct Values

When you need to split a DataFrame into multiple subsets based on all distinct values in a column, use a loop:

# Get all distinct ages
distinct_ages = [row.Age for row in df.select("Age").distinct().collect()]

# Split into a dictionary of DataFrames
splits = {}
for age in sorted(distinct_ages):
    splits[age] = df.filter(df.Age == age)

# Display each split
for age, split_df in splits.items():
    print(f"Age = {age}:")
    split_df.show()

Output:

Age = 18:
+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
| James| 18|Physics| 3|10000|
|Michael| 18|Biology| 3|12000|
+-------+---+-------+-----+-----+

Age = 21:
+-----+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-----+---+---------+-----+-----+
|Emily| 21|Chemistry| 2|14000|
|Sarah| 21| English| 2|13000|
+-----+---+---------+-----+-----+

Age = 23:
+-----+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-----+---+-------+-----+-----+
|David| 23| Maths| 4|11000|
+-----+---+-------+-----+-----+
tip

Using collect() to fetch the distinct values to the driver is usually acceptable, because the number of distinct values in a grouping column is typically small even when the DataFrame itself is large. Be careful with very high-cardinality columns, though: the filters themselves are lazy, but each action on a split (such as show() or a write) launches a separate Spark job that rescans the source data.

Splitting by Ranges

You can split DataFrames using comparison operators for range-based conditions:

from pyspark.sql.functions import col

# Split by fee ranges
df_low = df.filter(col("Fees") < 11000)
df_mid = df.filter((col("Fees") >= 11000) & (col("Fees") < 13000))
df_high = df.filter(col("Fees") >= 13000)

print("Low fees (< 11000):")
df_low.show()

print("Mid fees (11000 - 12999):")
df_mid.show()

print("High fees (>= 13000):")
df_high.show()

Output:

Low fees (< 11000):
+-----+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-----+---+-------+-----+-----+
|James| 18|Physics| 3|10000|
+-----+---+-------+-----+-----+

Mid fees (11000 - 12999):
+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
|Michael| 18|Biology| 3|12000|
| David| 23| Maths| 4|11000|
+-------+---+-------+-----+-----+

High fees (>= 13000):
+-----+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-----+---+---------+-----+-----+
|Emily| 21|Chemistry| 2|14000|
|Sarah| 21| English| 2|13000|
+-----+---+---------+-----+-----+

Using isin() for Multiple Values

When you want to filter by several specific values, use the isin() method:

# Students studying Physics or Chemistry
df_science = df.filter(df.Subject.isin("Physics", "Chemistry"))

# Students NOT studying Physics or Chemistry
df_other = df.filter(~df.Subject.isin("Physics", "Chemistry"))

print("Science subjects:")
df_science.show()

print("Other subjects:")
df_other.show()

Output:

Science subjects:
+-----+---+---------+-----+-----+
| Name|Age| Subject|Class| Fees|
+-----+---+---------+-----+-----+
|James| 18| Physics| 3|10000|
|Emily| 21|Chemistry| 2|14000|
+-----+---+---------+-----+-----+

Other subjects:
+-------+---+-------+-----+-----+
| Name|Age|Subject|Class| Fees|
+-------+---+-------+-----+-----+
|Michael| 18|Biology| 3|12000|
| David| 23| Maths| 4|11000|
| Sarah| 21|English| 2|13000|
+-------+---+-------+-----+-----+

Comparison of Methods

| Method                       | Syntax                   | Best For                                 |
|------------------------------|--------------------------|------------------------------------------|
| filter(df.col == value)      | Column object comparison | Simple, readable conditions              |
| filter("SQL expression")     | SQL string               | Complex expressions, SQL-familiar users  |
| where(df.col == value)       | Identical to filter()    | SQL-style naming preference              |
| filter(col("name") == value) | Using col() function     | Dynamic column references                |

Summary

Splitting a PySpark DataFrame by column value is essential for data partitioning, conditional analysis, and pipeline branching. Key takeaways:

  • filter() and where() are interchangeable - both accept column conditions or SQL expression strings.
  • Use & (AND), | (OR), and ~ (NOT) with parentheses around each condition for compound filters.
  • Use isin() to filter by multiple specific values efficiently.
  • Loop over distinct values with collect() when you need to split into many subsets dynamically.
  • Use comparison operators (<, >=, between) for range-based splitting.