Python NumPy: How to Calculate Square Roots of Negative Numbers with NumPy emath
Standard np.sqrt() returns NaN (with a RuntimeWarning) for negative real inputs, because it only produces real results and a negative number has no real square root. NumPy's emath module handles such inputs by returning complex results instead.
The Problem with np.sqrt()
The standard square root function doesn't handle negative numbers:
import numpy as np
arr = np.array([4, -1, -9, 16])
result = np.sqrt(arr)
print(result) # [ 2. nan nan 4.]
# Warning: RuntimeWarning: invalid value encountered in sqrt
This behavior is problematic when your data legitimately contains negative values that need complex square roots.
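If you only need to know which entries have no real square root, one option is to silence the warning locally with np.errstate and inspect the NaNs. A minimal sketch (the variable names are illustrative):

```python
import numpy as np

arr = np.array([4, -1, -9, 16])

# Suppress the "invalid value" RuntimeWarning only inside this block
with np.errstate(invalid="ignore"):
    result = np.sqrt(arr)

# NaN marks every entry that has no real square root
failed = np.isnan(result)
print(failed)       # [False  True  True False]
print(arr[failed])  # [-1 -9]
```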
The Solution: np.emath.sqrt()
The emath (extended math) module returns complex numbers when needed:
import numpy as np
arr = np.array([4, -1, -9, 16])
result = np.emath.sqrt(arr)
print(result)
Output:
[2.+0.j 0.+1.j 0.+3.j 4.+0.j]
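For comparison, casting the input to a complex dtype yourself gives the same values, because np.sqrt returns the principal complex root for complex input. The difference is that this route always produces a complex array, even when every input is non-negative. A short sketch:

```python
import numpy as np

arr = np.array([4, -1, -9, 16])

# np.sqrt accepts complex input and returns principal complex roots
via_cast = np.sqrt(arr.astype(np.complex128))
via_emath = np.emath.sqrt(arr)

print(via_cast)                          # [2.+0.j 0.+1.j 0.+3.j 4.+0.j]
print(np.allclose(via_cast, via_emath))  # True
```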
Understanding the Output
import numpy as np
# √(-1) = i = 1j in Python
print(np.emath.sqrt(-1)) # 1j
# √(-9) = 3i = 3j
print(np.emath.sqrt(-9)) # 3j
# √(4) = 2; the input is non-negative, so the result stays real
print(np.emath.sqrt(4)) # 2.0
# √(-4) = 2i
print(np.emath.sqrt(-4)) # 2j
Output:
1j
3j
2.0
2j
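emath.sqrt() returns complex numbers only when it has to: if any input is negative, the whole result is complex (which is why 4 became 2.+0.j in the array example); if every input is non-negative, you get ordinary real floats, as with √(4) above. When you do get a complex result, use .real to extract the real part if that is all you need.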
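A quick way to see which dtype np.emath.sqrt actually returns is to check .dtype on two small arrays, one with and one without a negative value. A minimal check:

```python
import numpy as np

all_positive = np.emath.sqrt(np.array([4.0, 9.0]))
one_negative = np.emath.sqrt(np.array([4.0, -9.0]))

print(all_positive.dtype)  # float64
print(one_negative.dtype)  # complex128: one negative value promotes the whole array
print(one_negative)        # [2.+0.j 0.+3.j]
```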
Working with Complex Numbers
emath.sqrt() handles complex inputs correctly:
import numpy as np
# Complex array
complex_arr = np.array([1+2j, 3+4j, -1+0j])
result = np.emath.sqrt(complex_arr)
print(result)
# [1.27201965+0.78615138j 2. +1.j 0. +1.j ]
# Verify: square the results
print(result ** 2)
# [1.+2.j 3.+4.j -1.+0.j] (Back to original!)
Output:
[1.27201965+0.78615138j 2. +1.j 0. +1.j ]
[ 1.+2.j 3.+4.j -1.+0.j]
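Every nonzero complex number has two square roots. np.emath.sqrt returns the principal one, whose real part is non-negative; the other root is its negation. A short sketch confirming that both square back to the input:

```python
import numpy as np

complex_arr = np.array([1+2j, 3+4j, -1+0j])
principal = np.emath.sqrt(complex_arr)

# Principal roots have a non-negative real part
print(np.all(principal.real >= 0))  # True

# Negating a square root gives the other root; both square to the input
print(np.allclose(principal**2, complex_arr))     # True
print(np.allclose((-principal)**2, complex_arr))  # True
```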
Extracting Real and Imaginary Parts
import numpy as np
arr = np.array([4, -1, -9, 16])
result = np.emath.sqrt(arr)
# Get real parts
real_parts = result.real
print(f"Real parts: {real_parts}")
# Real parts: [2. 0. 0. 4.]
# Get imaginary parts
imag_parts = result.imag
print(f"Imaginary parts: {imag_parts}")
# Imaginary parts: [0. 1. 3. 0.]
# Get magnitude (absolute value)
magnitudes = np.abs(result)
print(f"Magnitudes: {magnitudes}")
# Magnitudes: [2. 1. 3. 4.]
Output:
Real parts: [2. 0. 0. 4.]
Imaginary parts: [0. 1. 3. 0.]
Magnitudes: [2. 1. 3. 4.]
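Because the magnitude of a square root equals the square root of the magnitude, |sqrt(z)| = sqrt(|z|), the magnitudes above can also be computed without complex arithmetic. A small check of that identity:

```python
import numpy as np

arr = np.array([4, -1, -9, 16])

# Magnitude of each complex root
via_emath = np.abs(np.emath.sqrt(arr))
# Same values from the real square root of the absolute value
via_real = np.sqrt(np.abs(arr))

print(via_real)                          # [2. 1. 3. 4.]
print(np.allclose(via_emath, via_real))  # True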
Other emath Functions
The emath module provides extended versions of several other functions, including log, log2, log10, logn, power, arcsin, arccos, and arctanh:
Logarithm of Negative Numbers
import numpy as np
# Standard log fails for negative numbers
print(np.log(-1)) # nan (with warning)
# emath handles it correctly
print(np.emath.log(-1)) # 3.141592653589793j (= iπ)
# Because e^(iπ) = -1
arr = np.array([1, -1, np.e, -np.e])
print(np.emath.log(arr))
# [0.+0.j 0.+3.14159265j 1.+0.j 1.+3.14159265j]
Output:
RuntimeWarning: invalid value encountered in log
print(np.log(-1)) # nan (with warning)
nan
3.141592653589793j
[0.+0.j 0.+3.14159265j 1.+0.j 1.+3.14159265j]
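The same idea carries over to the other logarithms in emath. A small sketch with np.emath.log10 and np.emath.logn(n, x) (base first, then value): for a negative input, the real part is the logarithm of the absolute value and the imaginary part is π divided by the natural log of the base. The last line checks that np.exp undoes the complex natural log.

```python
import numpy as np

# log10 of a negative number: real part log10(100) = 2, imaginary part pi / ln(10)
z10 = np.emath.log10(-100)
print(np.isclose(z10.real, 2.0), np.isclose(z10.imag, np.pi / np.log(10)))  # True True

# logn(n, x) uses base n: here the base-2 logarithm of -8
z2 = np.emath.logn(2, -8)
print(np.isclose(z2.real, 3.0), np.isclose(z2.imag, np.pi / np.log(2)))  # True True

# exp undoes the complex natural log (up to rounding)
print(np.isclose(np.exp(np.emath.log(-1)), -1))  # True
```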
Power with Complex Results
import numpy as np
# Fractional powers of negative numbers
print(np.emath.power(-8, 1/3)) # ≈ (1+1.732j): the principal complex cube root, not the real root -2
# Standard power gives nan
print(np.power(-8, 1/3)) # nan
# Works with arrays
arr = np.array([-8, -27, 8, 27])
print(np.emath.power(arr, 1/3))
Output:
(1+1.732050807568877j)
RuntimeWarning: invalid value encountered in power
print(np.power(-8, 1/3)) # nan
nan
[1. +1.73205081j 1.5+2.59807621j 2. +0.j 3. +0.j ]
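If what you actually want is the real cube root of a negative number (-2 for -8), the principal complex root from np.emath.power is a different number. np.cbrt computes real cube roots directly. A minimal comparison:

```python
import numpy as np

arr = np.array([-8, -27, 8, 27])

# Real cube roots: the sign is preserved and no complex numbers are involved
print(np.cbrt(arr))  # [-2. -3.  2.  3.]

# The principal complex root differs for negative inputs,
# but it is still a valid cube root of the same number
principal = np.emath.power(arr, 1/3)
print(np.allclose(principal ** 3, arr))  # True
```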
Arcsin and Arccos Beyond [-1, 1]
import numpy as np
# Standard arcsin fails outside [-1, 1]
print(np.arcsin(2)) # nan
# emath returns complex result
print(np.emath.arcsin(2)) # (1.5707963267948966+1.3169578969248166j)
# Same for arccos
print(np.emath.arccos(2)) # -1.3169578969248166j
Output:
RuntimeWarning: invalid value encountered in arcsin
print(np.arcsin(2))
nan
(1.5707963267948966+1.3169578969248166j)
-1.3169578969248166j
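As with log, you can sanity-check these complex results by applying the forward function, since np.sin, np.cos, and np.tanh all accept complex input. The sketch below also includes np.emath.arctanh, which handles inputs outside [-1, 1] in the same way:

```python
import numpy as np

asin2 = np.emath.arcsin(2)
acos2 = np.emath.arccos(2)
atanh2 = np.emath.arctanh(2)

# Applying the forward function recovers the original input (up to rounding)
print(np.isclose(np.sin(asin2), 2))    # True
print(np.isclose(np.cos(acos2), 2))    # True
print(np.isclose(np.tanh(atanh2), 2))  # True
```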
Practical Example: Quadratic Formula
When the discriminant b² - 4ac is negative, the quadratic formula produces complex roots:
import numpy as np
def quadratic_roots(a, b, c):
    """Solve ax² + bx + c = 0, handling complex roots."""
    discriminant = b**2 - 4*a*c
    # Use emath.sqrt to handle negative discriminants
    sqrt_disc = np.emath.sqrt(discriminant)
    root1 = (-b + sqrt_disc) / (2*a)
    root2 = (-b - sqrt_disc) / (2*a)
    return root1, root2
# Real roots: x² - 5x + 6 = 0
r1, r2 = quadratic_roots(1, -5, 6)
print(f"x² - 5x + 6 = 0: x = {r1.real}, {r2.real}")
# x = 3.0, 2.0
# Complex roots: x² + 1 = 0
r1, r2 = quadratic_roots(1, 0, 1)
print(f"x² + 1 = 0: x = {r1}, {r2}")
# x = 1j, -1j
# Complex roots: x² - 2x + 5 = 0
r1, r2 = quadratic_roots(1, -2, 5)
print(f"x² - 2x + 5 = 0: x = {r1}, {r2}")
# x = (1+2j), (1-2j)
Output:
x² - 5x + 6 = 0: x = 3.0, 2.0
x² + 1 = 0: x = 1j, -1j
x² - 2x + 5 = 0: x = (1+2j), (1-2j)
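Because np.emath.sqrt works element-wise, the same quadratic_roots function (repeated here in condensed form so the block runs on its own) can solve several equations in one call when the coefficients are arrays. This sketch checks every root by substituting it back into its equation:

```python
import numpy as np

def quadratic_roots(a, b, c):
    """Solve ax² + bx + c = 0, handling complex roots."""
    discriminant = b**2 - 4*a*c
    sqrt_disc = np.emath.sqrt(discriminant)
    return (-b + sqrt_disc) / (2*a), (-b - sqrt_disc) / (2*a)

# Coefficients of x² - 5x + 6, x² + 1 and x² - 2x + 5, one equation per position
a = np.array([1, 1, 1])
b = np.array([-5, 0, -2])
c = np.array([6, 1, 5])

# One negative discriminant makes both root arrays complex
r1, r2 = quadratic_roots(a, b, c)

# Every root should satisfy its own equation
for r in (r1, r2):
    print(np.allclose(a*r**2 + b*r + c, 0))  # True
```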
Converting Back to Real (When Appropriate)
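np.emath.sqrt() already returns real floats when no input is negative, so .real is only needed when the result is complex. It is harmless on a real array, which makes it a safe way to guarantee a real result once you know the values should be real:
import numpy as np
arr = np.array([4, 9, 16, 25])
# No negative inputs, so emath returns a real float64 array
result = np.emath.sqrt(arr)
print(result)
# [2. 3. 4. 5.]
# .real works on both real and complex arrays
real_result = result.real
print(real_result)
# [2. 3. 4. 5.]
# For complex results, first check that the imaginary parts are negligible
if np.allclose(result.imag, 0):
    real_result = result.real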
Output:
[2. 3. 4. 5.]
[2. 3. 4. 5.]
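NumPy also has a built-in for the "drop the imaginary part only if it is negligible" pattern: np.real_if_close(a, tol=100) returns the real parts when every imaginary part is within tol machine epsilons of zero, and returns the input unchanged otherwise. A short sketch:

```python
import numpy as np

# Imaginary parts at rounding-error level, e.g. left over from complex arithmetic
nearly_real = np.array([2 + 1e-16j, 3 - 2e-16j])
print(np.real_if_close(nearly_real))  # [2. 3.]

# Genuine imaginary parts are kept: the array comes back complex
roots = np.emath.sqrt(np.array([4, -9]))
print(np.real_if_close(roots).dtype)  # complex128
```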
Performance Consideration
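emath functions are slightly slower than their standard counterparts because they first check the input for out-of-domain values (such as negatives for sqrt) and convert it to complex if any are found: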
import numpy as np
import timeit
arr = np.abs(np.random.randn(100000)) # Positive values only
# Standard sqrt is faster for positive numbers
standard_time = timeit.timeit(lambda: np.sqrt(arr), number=100)
emath_time = timeit.timeit(lambda: np.emath.sqrt(arr), number=100)
print(f"np.sqrt: {standard_time:.4f}s")
print(f"np.emath.sqrt: {emath_time:.4f}s")
Output (example timings; yours will vary by machine):
np.sqrt: 0.0146s
np.emath.sqrt: 0.0229s
Use np.sqrt() when you're certain all values are non-negative. Use np.emath.sqrt() when negative values are possible or expected.
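To see where the extra time goes, here is a simplified, hypothetical re-creation of the dispatch logic: scan for negative real values, promote to complex only if any are found, then call the ordinary np.sqrt. emath_like_sqrt is our own name, not a NumPy API, and it skips details of NumPy's real implementation (such as keeping single-precision inputs in single precision):

```python
import numpy as np

def emath_like_sqrt(x):
    """Hypothetical, simplified re-creation of np.emath.sqrt's dispatch logic."""
    x = np.asarray(x)
    # The extra pass over the data: look for real values below zero
    if np.isrealobj(x) and np.any(x < 0):
        x = x.astype(np.complex128)  # promote only when needed
    return np.sqrt(x)

positive = np.array([1.0, 4.0, 9.0])
mixed = np.array([4, -1, -9, 16])

print(emath_like_sqrt(positive))  # [1. 2. 3.]
print(np.allclose(emath_like_sqrt(mixed), np.emath.sqrt(mixed)))  # True
```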
Summary of emath Functions
| Function | Standard Behavior | emath Behavior |
|---|---|---|
| sqrt(-1) | NaN | 1j |
| log(-1) | NaN | πj |
| power(-8, 1/3) | NaN | Complex cube root |
| arcsin(2) | NaN | Complex result |
| arccos(2) | NaN | Complex result |
Summary
Use np.emath.sqrt() when your data may contain negative numbers and you need complex results instead of NaN. The emath module extends standard NumPy functions to handle inputs that would otherwise produce NaN, and it still returns real results when no such inputs are present. For data you know is non-negative, stick with np.sqrt() for better performance.