How to Count Rows Based on Conditions in PySpark
Counting records that meet specific criteria is fundamental to data validation, quality checks, and reporting in distributed data processing. PySpark provides several efficient methods for conditional counting that scale across massive datasets, but choosing the wrong approach can mean scanning your data multiple times unnecessarily.
This guide covers basic filtering, multi-condition logic, grouped counts, and optimized techniques for calculating multiple metrics in a single pass over your data.
Basic Conditional Counting with filter() and where()
The simplest way to count rows matching a condition is to filter the DataFrame and then call .count(). The methods .filter() and .where() are interchangeable and behave identically:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
# Count using column expression
electronics_count = df.filter(col("category") == "Electronics").count()
# Count using SQL string syntax
expensive_count = df.where("price > 500").count()
print(f"Electronics items: {electronics_count}")
print(f"Expensive items (>$500): {expensive_count}")
Output:
Electronics items: 3
Expensive items (>$500): 3
The .count() method is a Spark action that triggers actual computation across the cluster. Each separate .filter().count() call causes a full data scan. If you need multiple counts, consider the aggregation techniques covered later in this guide to avoid redundant scans.
Combining Multiple Conditions
Combine conditions using the bitwise operators & (AND), | (OR), and ~ (NOT). Each individual condition must be wrapped in parentheses:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
# AND: Electronics priced over 1000
premium_electronics = df.filter(
(col("category") == "Electronics") & (col("price") > 1000)
).count()
# OR: Electronics or items over 1000
elec_or_expensive = df.filter(
(col("category") == "Electronics") | (col("price") > 1000)
).count()
# NOT: Non-electronics
non_electronics = df.filter(
~(col("category") == "Electronics")
).count()
print(f"Premium electronics: {premium_electronics}")
print(f"Electronics or expensive: {elec_or_expensive}")
print(f"Non-electronics: {non_electronics}")
Output:
Premium electronics: 2
Electronics or expensive: 3
Non-electronics: 2
Each condition must be enclosed in parentheses when using bitwise operators. Without them, Python's operator precedence causes confusing errors:
# Wrong: missing parentheses
df.filter(col("category") == "Electronics" & col("price") > 1000)
# Correct: each condition wrapped in parentheses
df.filter((col("category") == "Electronics") & (col("price") > 1000))
This is one of the most common PySpark mistakes. The & operator binds more tightly than ==, so without parentheses Python tries to evaluate "Electronics" & col("price") first, which fails.
Multiple Counts in a Single Pass
When you need several different counts from the same DataFrame, calling .filter().count() multiple times scans the data once per call. A much more efficient approach uses .agg() with sum(when(...)) to compute all counts in a single pass:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
summary = df.agg(
count("*").alias("total_rows"),
sum(when(col("category") == "Electronics", 1).otherwise(0)).alias("electronics_count"),
sum(when(col("price") > 500, 1).otherwise(0)).alias("expensive_count"),
sum(when(col("price") < 100, 1).otherwise(0)).alias("cheap_count")
)
summary.show()
Output:
+----------+-----------------+---------------+-----------+
|total_rows|electronics_count|expensive_count|cheap_count|
+----------+-----------------+---------------+-----------+
| 5| 3| 3| 1|
+----------+-----------------+---------------+-----------+
The agg(sum(when(...))) pattern is significantly faster than multiple .filter().count() calls because it processes the entire dataset only once. On large datasets, this difference can be substantial.
Counting with groupBy()
To count records within each group, use .groupBy() combined with an aggregation function:
from pyspark.sql import SparkSession
from pyspark.sql.functions import count
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
category_counts = df.groupBy("category").agg(
count("*").alias("item_count")
)
category_counts.show()
Output:
+-----------+----------+
| category|item_count|
+-----------+----------+
| Books| 1|
| Clothing| 1|
|Electronics| 3|
+-----------+----------+
Conditional Counts Within Groups
You can combine grouping with conditional aggregation to break down counts further within each group:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
detailed_counts = df.groupBy("category").agg(
count("*").alias("total"),
sum(when(col("price") > 500, 1).otherwise(0)).alias("expensive"),
sum(when(col("price") <= 500, 1).otherwise(0)).alias("affordable")
)
detailed_counts.show()
Output:
+-----------+-----+---------+----------+
| category|total|expensive|affordable|
+-----------+-----+---------+----------+
| Books| 1| 0| 1|
| Clothing| 1| 0| 1|
|Electronics| 3| 3| 0|
+-----------+-----+---------+----------+
Distinct and Approximate Counts
Counting unique values requires different functions. PySpark provides both exact and approximate options:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, approx_count_distinct
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", 800, "2024-01-17"),
("Clothing", 150, "2024-01-18"),
("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
# Exact distinct count
exact_categories = df.select(countDistinct("category")).collect()[0][0]
# Approximate distinct count (faster for large datasets)
approx_categories = df.select(approx_count_distinct("category")).collect()[0][0]
# Distinct count with a filter applied first
distinct_expensive_categories = df.filter(col("price") > 100).select(
countDistinct("category")
).collect()[0][0]
print(f"Exact distinct categories: {exact_categories}")
print(f"Approx distinct categories: {approx_categories}")
print(f"Distinct categories (price > 100): {distinct_expensive_categories}")
Output:
Exact distinct categories: 3
Approx distinct categories: 3
Distinct categories (price > 100): 2
approx_count_distinct() uses the HyperLogLog algorithm to estimate the number of distinct values without requiring a full shuffle. On datasets with billions of rows, this can be orders of magnitude faster than countDistinct() with only a small accuracy trade-off (typically within 2% of the true count).
Method Comparison
| Approach | Data Scans | Best For |
|---|---|---|
| df.filter(cond).count() | 1 per call | A single, simple count |
| df.agg(sum(when(...))) | 1 total | Multiple conditional counts efficiently |
| df.groupBy().count() | 1 total | Counts broken down by category |
| countDistinct() | 1 total | Exact unique value counts |
| approx_count_distinct() | 1 total | Fast estimated unique counts on large data |
Practical Example: Data Quality Report
A common real-world use case is generating a data quality summary that checks for nulls, invalid values, and outliers. Using the single-pass aggregation pattern, you can compute all of these metrics at once:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count
spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()
data = [
("Electronics", 1200, "2024-01-15"),
("Books", 25, "2024-01-16"),
("Electronics", None, "2024-01-17"),
("Clothing", -50, "2024-01-18"),
("Electronics", 2000, "2024-01-19"),
(None, 150, "2024-01-20")
]
df = spark.createDataFrame(data, ["category", "price", "date"])
def generate_quality_report(df):
"""Generate data quality metrics in a single pass."""
return df.agg(
count("*").alias("total_records"),
sum(when(col("price").isNull(), 1).otherwise(0)).alias("null_prices"),
sum(when(col("price") < 0, 1).otherwise(0)).alias("negative_prices"),
sum(when(col("category").isNull(), 1).otherwise(0)).alias("null_categories"),
sum(when(col("price") > 10000, 1).otherwise(0)).alias("outlier_prices")
)
quality_report = generate_quality_report(df)
quality_report.show()
Output:
+-------------+-----------+---------------+---------------+--------------+
|total_records|null_prices|negative_prices|null_categories|outlier_prices|
+-------------+-----------+---------------+---------------+--------------+
| 6| 1| 1| 1| 0|
+-------------+-----------+---------------+---------------+--------------+
This pattern scans the data exactly once regardless of how many quality metrics you compute, making it ideal for large-scale data validation in production pipelines.
Summary
- For a single, simple count, .filter().count() is clear and readable.
- When you need multiple conditional counts from the same DataFrame, use agg(sum(when(...))) to compute them all in a single data scan.
- For counts broken down by category, combine .groupBy() with conditional aggregation.
- When counting unique values on very large datasets, consider approx_count_distinct() for a significant speed improvement with minimal accuracy loss.
Always remember to wrap each condition in parentheses when combining them with &, |, or ~ to avoid operator precedence errors.