
How to Count Rows Based on Conditions in PySpark

Counting records that meet specific criteria is fundamental to data validation, quality checks, and reporting in distributed data processing. PySpark provides several efficient methods for conditional counting that scale across massive datasets, but choosing the wrong approach can mean scanning your data multiple times unnecessarily.

This guide covers basic filtering, multi-condition logic, grouped counts, and optimized techniques for calculating multiple metrics in a single pass over your data.

Basic Conditional Counting with filter() and where()

The simplest way to count rows matching a condition is to filter the DataFrame and then call .count(). The methods .filter() and .where() are interchangeable and behave identically:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

# Count using column expression
electronics_count = df.filter(col("category") == "Electronics").count()

# Count using SQL string syntax
expensive_count = df.where("price > 500").count()

print(f"Electronics items: {electronics_count}")
print(f"Expensive items (>$500): {expensive_count}")

Output:

Electronics items: 3
Expensive items (>$500): 3
Note: The .count() method is a Spark action that triggers actual computation across the cluster. Each separate .filter().count() call causes a full data scan. If you need multiple counts, consider the aggregation techniques covered later in this guide to avoid redundant scans.

Combining Multiple Conditions

Combine conditions using the bitwise operators & (AND), | (OR), and ~ (NOT). Each individual condition must be wrapped in parentheses:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

# AND: Electronics priced over 1000
premium_electronics = df.filter(
    (col("category") == "Electronics") & (col("price") > 1000)
).count()

# OR: Electronics or items over 1000
elec_or_expensive = df.filter(
    (col("category") == "Electronics") | (col("price") > 1000)
).count()

# NOT: Non-electronics
non_electronics = df.filter(
    ~(col("category") == "Electronics")
).count()

print(f"Premium electronics: {premium_electronics}")
print(f"Electronics or expensive: {elec_or_expensive}")
print(f"Non-electronics: {non_electronics}")

Output:

Premium electronics: 2
Electronics or expensive: 3
Non-electronics: 2
Warning: Each condition must be enclosed in parentheses when using bitwise operators. Without them, Python's operator precedence causes confusing errors:

# Wrong: missing parentheses
df.filter(col("category") == "Electronics" & col("price") > 1000)

# Correct: each condition wrapped in parentheses
df.filter((col("category") == "Electronics") & (col("price") > 1000))

This is one of the most common PySpark mistakes. The & operator binds more tightly than ==, so without parentheses Python evaluates "Electronics" & col("price") first; the surrounding chained comparison then tries to coerce a Column to a boolean, and the expression fails with a confusing error (typically "Cannot convert column into bool").

Multiple Counts in a Single Pass

When you need several different counts from the same DataFrame, calling .filter().count() multiple times scans the data once per call. A much more efficient approach uses .agg() with sum(when(...)) to compute all counts in a single pass:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

summary = df.agg(
    count("*").alias("total_rows"),
    sum(when(col("category") == "Electronics", 1).otherwise(0)).alias("electronics_count"),
    sum(when(col("price") > 500, 1).otherwise(0)).alias("expensive_count"),
    sum(when(col("price") < 100, 1).otherwise(0)).alias("cheap_count")
)

summary.show()

Output:

+----------+-----------------+---------------+-----------+
|total_rows|electronics_count|expensive_count|cheap_count|
+----------+-----------------+---------------+-----------+
| 5| 3| 3| 1|
+----------+-----------------+---------------+-----------+
Tip: The agg(sum(when(...))) pattern is significantly faster than multiple .filter().count() calls because it processes the entire dataset only once. On large datasets, this difference can be substantial.

Counting with groupBy()

To count records within each group, use .groupBy() combined with an aggregation function:

from pyspark.sql import SparkSession
from pyspark.sql.functions import count

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

category_counts = df.groupBy("category").agg(
    count("*").alias("item_count")
)

category_counts.show()

Output:

+-----------+----------+
| category|item_count|
+-----------+----------+
| Books| 1|
| Clothing| 1|
|Electronics| 3|
+-----------+----------+

Conditional Counts Within Groups

You can combine grouping with conditional aggregation to break down counts further within each group:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

detailed_counts = df.groupBy("category").agg(
    count("*").alias("total"),
    sum(when(col("price") > 500, 1).otherwise(0)).alias("expensive"),
    sum(when(col("price") <= 500, 1).otherwise(0)).alias("affordable")
)

detailed_counts.show()

Output:

+-----------+-----+---------+----------+
| category|total|expensive|affordable|
+-----------+-----+---------+----------+
| Books| 1| 0| 1|
| Clothing| 1| 0| 1|
|Electronics| 3| 3| 0|
+-----------+-----+---------+----------+

Distinct and Approximate Counts

Counting unique values requires different functions. PySpark provides both exact and approximate options:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, approx_count_distinct

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", 800, "2024-01-17"),
    ("Clothing", 150, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

# Exact distinct count
exact_categories = df.select(countDistinct("category")).collect()[0][0]

# Approximate distinct count (faster for large datasets)
approx_categories = df.select(approx_count_distinct("category")).collect()[0][0]

# Distinct count with a filter applied first
distinct_expensive_categories = df.filter(col("price") > 100).select(
    countDistinct("category")
).collect()[0][0]

print(f"Exact distinct categories: {exact_categories}")
print(f"Approx distinct categories: {approx_categories}")
print(f"Distinct categories (price > 100): {distinct_expensive_categories}")

Output:

Exact distinct categories: 3
Approx distinct categories: 3
Distinct categories (price > 100): 2
Note: approx_count_distinct() uses the HyperLogLog algorithm to estimate the number of distinct values without shuffling every distinct value across the cluster. On datasets with billions of rows, this can be orders of magnitude faster than countDistinct(), at the cost of a bounded estimation error: the default maximum relative standard deviation is 5%, tunable via the rsd parameter.

Method Comparison

Approach                    | Data Scans  | Best For
df.filter(cond).count()     | 1 per call  | A single, simple count
df.agg(sum(when(...)))      | 1 total     | Multiple conditional counts efficiently
df.groupBy().count()        | 1 total     | Counts broken down by category
countDistinct()             | 1 total     | Exact unique value counts
approx_count_distinct()     | 1 total     | Fast estimated unique counts on large data

Practical Example: Data Quality Report

A common real-world use case is generating a data quality summary that checks for nulls, invalid values, and outliers. Using the single-pass aggregation pattern, you can compute all of these metrics at once:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, when, count

spark = SparkSession.builder.appName("ConditionalCount").getOrCreate()

data = [
    ("Electronics", 1200, "2024-01-15"),
    ("Books", 25, "2024-01-16"),
    ("Electronics", None, "2024-01-17"),
    ("Clothing", -50, "2024-01-18"),
    ("Electronics", 2000, "2024-01-19"),
    (None, 150, "2024-01-20")
]
df = spark.createDataFrame(data, ["category", "price", "date"])

def generate_quality_report(df):
    """Generate data quality metrics in a single pass."""
    return df.agg(
        count("*").alias("total_records"),
        sum(when(col("price").isNull(), 1).otherwise(0)).alias("null_prices"),
        sum(when(col("price") < 0, 1).otherwise(0)).alias("negative_prices"),
        sum(when(col("category").isNull(), 1).otherwise(0)).alias("null_categories"),
        sum(when(col("price") > 10000, 1).otherwise(0)).alias("outlier_prices")
    )

quality_report = generate_quality_report(df)
quality_report.show()

Output:

+-------------+-----------+---------------+---------------+--------------+
|total_records|null_prices|negative_prices|null_categories|outlier_prices|
+-------------+-----------+---------------+---------------+--------------+
| 6| 1| 1| 1| 0|
+-------------+-----------+---------------+---------------+--------------+

This pattern scans the data exactly once regardless of how many quality metrics you compute, making it ideal for large-scale data validation in production pipelines.

Summary

  • For a single simple count, .filter().count() is clear and readable.
  • When you need multiple conditional counts from the same DataFrame, use agg(sum(when(...))) to compute them all in a single data scan.
  • For counts broken down by category, combine .groupBy() with conditional aggregation.
  • When counting unique values on very large datasets, consider approx_count_distinct() for a significant speed improvement with minimal accuracy loss.

Always remember to wrap each condition in parentheses when combining them with &, |, or ~ to avoid operator precedence errors.