Python Pandas: How to Flatten MultiIndex Columns

Hierarchical (MultiIndex) columns are a common byproduct of groupby aggregations and pivot table operations in Pandas. While they represent the data structure accurately, they complicate column access, make exports messy, and require tuple-based referencing like df[('Value', 'sum')] instead of simple strings. Flattening these multi-level column names into clean, single-level strings makes your DataFrames easier to work with.

In this guide, you will learn how to prevent MultiIndex columns from forming in the first place, flatten existing ones using different techniques, and handle pivot table columns with custom formatting.

Prevention: Using Named Aggregation

The best way to deal with MultiIndex columns is to avoid creating them. Named aggregation lets you specify both the aggregation function and the output column name in a single step:

import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A', 'B', 'B'],
    'Value': [10, 20, 30, 40]
})

result = df.groupby('Group').agg(
    total=('Value', 'sum'),
    average=('Value', 'mean'),
    count=('Value', 'count')
).reset_index()

print(result)

Output:

  Group  total  average  count
0     A     30     15.0      2
1     B     70     35.0      2

The result has clean, descriptive column names with no MultiIndex. This approach is recommended for new code because it produces self-documenting output.

Tip: Named aggregation uses the syntax output_name=('source_column', 'function'). This gives you full control over the final column names without any post-processing.
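The tuple syntax above can also be written with pd.NamedAgg, which spells out the column and aggfunc fields by name. The following is a minimal sketch; the Weight column and the output names value_total and weight_max are made up for illustration:

```python
import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A', 'B', 'B'],
    'Value': [10, 20, 30, 40],
    'Weight': [1.0, 2.0, 3.0, 4.0],  # hypothetical second column
})

# pd.NamedAgg is a namedtuple with the fields 'column' and 'aggfunc'
summary = df.groupby('Group').agg(
    value_total=pd.NamedAgg(column='Value', aggfunc='sum'),
    weight_max=pd.NamedAgg(column='Weight', aggfunc='max'),
).reset_index()

print(summary.columns.tolist())
# ['Group', 'value_total', 'weight_max']
```

Because each output column names its own source column, several source columns can be aggregated in one call without producing a MultiIndex.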

Flattening Existing MultiIndex Columns

When you already have a DataFrame with MultiIndex columns, such as from an aggregation that used a list of functions, you can flatten them by joining the labels from each level of every column:

import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A', 'B'],
    'Value': [10, 20, 30],
    'Count': [1, 2, 3]
})

# This creates MultiIndex columns
result = df.groupby('Group').agg(['sum', 'mean'])

print("Before flattening:")
print(result.columns.tolist())
print()

# Flatten by joining the tuple elements with underscore
result.columns = ['_'.join(col) for col in result.columns]
result = result.reset_index()

print("After flattening:")
print(result)

Output:

Before flattening:
[('Value', 'sum'), ('Value', 'mean'), ('Count', 'sum'), ('Count', 'mean')]

After flattening:
  Group  Value_sum  Value_mean  Count_sum  Count_mean
0     A         30        15.0          3         1.5
1     B         30        30.0          3         3.0

Each MultiIndex column is a tuple like ('Value', 'sum'). The list comprehension joins these tuple elements with an underscore to produce 'Value_sum'.
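The same join can be expressed without a list comprehension by passing the bound method '_'.join to the column index's map method, which calls it once per column tuple. This is an equivalent sketch on the same data, not a different technique:

```python
import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A', 'B'],
    'Value': [10, 20, 30],
    'Count': [1, 2, 3],
})

result = df.groupby('Group').agg(['sum', 'mean'])

# map applies '_'.join to each column tuple, e.g. ('Value', 'sum')
result.columns = result.columns.map('_'.join)

print(result.columns.tolist())
# ['Value_sum', 'Value_mean', 'Count_sum', 'Count_mean']
```

Both forms assume every level label is a string; str.join raises a TypeError on other types.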

Using to_flat_index() for Complex Hierarchies

to_flat_index() converts a MultiIndex into a regular Index of tuples. This makes the conversion explicit and works the same way for any number of levels. The example below uses two levels for brevity:

import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A', 'B'],
    'Value': [10, 20, 30]
})

result = df.groupby('Group').agg(['sum', 'mean'])

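# Convert to flat tuples, then join; strip('_') drops a stray separator
result.columns = ['_'.join(col).strip('_') for col in result.columns.to_flat_index()]

print(result.columns.tolist())

Output:

['Value_sum', 'Value_mean']

The .strip('_') call removes the leading or trailing underscore that appears when one level is an empty string, for example the ('Group', '') column that reset_index() creates when it runs before flattening. A bare .strip() removes only whitespace, so it helps only when the separator is a space.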
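To see the method on a deeper hierarchy, here is a small sketch with three column levels. The region / metric / statistic labels and the values are invented for illustration:

```python
import pandas as pd

# Hypothetical three-level columns: region / metric / statistic
columns = pd.MultiIndex.from_tuples([
    ('North', 'Sales', 'sum'),
    ('North', 'Sales', 'mean'),
    ('South', 'Sales', 'sum'),
])
df = pd.DataFrame([[100, 50.0, 80]], columns=columns)

# to_flat_index() returns a plain Index whose elements are tuples
flat = df.columns.to_flat_index()
print(isinstance(flat, pd.MultiIndex))  # False
print(flat.tolist())

# The join works the same way regardless of the number of levels
df.columns = ['_'.join(col) for col in flat]
print(df.columns.tolist())
# ['North_Sales_sum', 'North_Sales_mean', 'South_Sales_sum']
```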

Flattening Pivot Table Columns

Pivot tables frequently produce MultiIndex columns, especially when using multiple aggregation functions:

import pandas as pd

df = pd.DataFrame({
    'Date': ['2024-01', '2024-01', '2024-02'],
    'Category': ['A', 'B', 'A'],
    'Sales': [100, 200, 150]
})

pivot = df.pivot_table(
    values='Sales',
    index='Date',
    columns='Category',
    aggfunc=['sum', 'mean']
)

print("Before flattening:")
print(pivot.columns.tolist())
print()

# Flatten the column names
pivot.columns = ['_'.join(col) for col in pivot.columns]

print("After flattening:")
print(pivot.columns.tolist())
print()

pivot = pivot.reset_index()
print(pivot)

Output:

Before flattening:
[('sum', 'A'), ('sum', 'B'), ('mean', 'A'), ('mean', 'B')]

After flattening:
['sum_A', 'sum_B', 'mean_A', 'mean_B']

      Date  sum_A  sum_B  mean_A  mean_B
0  2024-01  100.0  200.0   100.0   200.0
1  2024-02  150.0    NaN   150.0     NaN
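Pivot columns are often built from non-string values such as years, and '_'.join(col) raises a TypeError on those. One way around this, sketched below with made-up Region and Year data, is to convert each label with str() before joining:

```python
import pandas as pd

df = pd.DataFrame({
    'Region': ['East', 'East', 'West'],
    'Year': [2023, 2024, 2023],  # integer labels, hypothetical data
    'Sales': [100, 200, 150],
})

pivot = df.pivot_table(
    values='Sales',
    index='Region',
    columns='Year',
    aggfunc=['sum', 'mean'],
)

# Columns are tuples like ('sum', 2023); convert every part to str first
pivot.columns = ['_'.join(str(part) for part in col) for col in pivot.columns]

print(pivot.columns.tolist())
# ['sum_2023', 'sum_2024', 'mean_2023', 'mean_2024']
```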

Custom Separators and Formatting

The join-based approach is flexible enough to accommodate different naming conventions:

import pandas as pd

df = pd.DataFrame({
    'Group': ['A', 'A'],
    'Value': [10, 20]
})

result = df.groupby('Group').agg(['sum', 'mean'])

# Using a different separator
result_dash = result.copy()
result_dash.columns = [' - '.join(col) for col in result_dash.columns]
print(f"Dash separator: {result_dash.columns.tolist()}")

# Reversing the order (function first)
result_reversed = result.copy()
result_reversed.columns = [f"{func}_{col}" for col, func in result_reversed.columns]
print(f"Reversed order: {result_reversed.columns.tolist()}")

# Using only the function name when there is one source column
result_func = result.copy()
result_func.columns = [func for _, func in result_func.columns]
print(f"Function only: {result_func.columns.tolist()}")

Output:

Dash separator: ['Value - sum', 'Value - mean']
Reversed order: ['sum_Value', 'mean_Value']
Function only: ['sum', 'mean']
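If the same flattening is needed in several places, the variations above could be wrapped in a small helper. The function below is a hypothetical sketch (flatten_columns is not part of pandas) that takes the separator as a parameter, converts labels to strings, and skips empty labels:

```python
import pandas as pd

def flatten_columns(df, sep='_'):
    """Return a copy of df with MultiIndex columns joined into strings.

    Hypothetical helper, not part of pandas.
    """
    out = df.copy()
    if isinstance(out.columns, pd.MultiIndex):
        out.columns = [
            # Skip empty labels so no stray separator is left behind
            sep.join(str(part) for part in col if str(part) != '')
            for col in out.columns
        ]
    return out

df = pd.DataFrame({'Group': ['A', 'A'], 'Value': [10, 20]})
result = df.groupby('Group').agg(['sum', 'mean'])

print(flatten_columns(result).columns.tolist())
# ['Value_sum', 'Value_mean']
print(flatten_columns(result, sep='.').columns.tolist())
# ['Value.sum', 'Value.mean']
```

Returning a copy leaves the original DataFrame untouched, and a DataFrame with single-level columns passes through with its names unchanged.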

When to Use Each Approach

| Method | When to Use | Result Example |
| --- | --- | --- |
| Named aggregation | Writing new aggregation code | total, average |
| '_'.join(col) | Flattening existing MultiIndex | Value_sum |
| .to_flat_index() | Complex or deeply nested hierarchies | Flat tuples |
| Custom f-string format | Specific naming conventions needed | sum_Value |

Prefer named aggregation whenever possible to avoid MultiIndex columns entirely. It produces cleaner, self-documenting code from the start.

For existing MultiIndex columns, use df.columns = ['_'.join(col) for col in df.columns] to flatten them into readable single-level names.

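Follow with reset_index() when the grouping keys sit in the row index and you need them as regular columns. Call it after flattening: if it runs first, each key becomes a column tuple such as ('Group', '') and the join leaves a trailing separator.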
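The effect of calling reset_index() before flattening can be checked directly. This is a minimal sketch on a small made-up frame:

```python
import pandas as pd

df = pd.DataFrame({'Group': ['A', 'A', 'B'], 'Value': [10, 20, 30]})
agg = df.groupby('Group').agg(['sum', 'mean'])

# Resetting first moves 'Group' into the columns as ('Group', '')
reset_first = agg.reset_index()
print(reset_first.columns.tolist())
# [('Group', ''), ('Value', 'sum'), ('Value', 'mean')]

# A plain join leaves a trailing underscore on the key column
naive = ['_'.join(col) for col in reset_first.columns]
print(naive)
# ['Group_', 'Value_sum', 'Value_mean']

# Stripping the separator character cleans it up
cleaned = ['_'.join(col).strip('_') for col in reset_first.columns]
print(cleaned)
# ['Group', 'Value_sum', 'Value_mean']
```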