Simple FFT in Haskell
This article develops a simple implementation of the fast Fourier transform in Haskell.
Raw performance of the algorithm is explicitly not a goal here; for instance, I use things like nub, Writer, and lists for simplicity. On the other hand, I do pay attention to the algorithmic complexity in terms of the number of arithmetic operations performed; that analysis will be done in a subsequent article.
Background
The discrete Fourier transform (DFT) turns \(n\) complex numbers \(a_0,a_1,\ldots,a_{n-1}\) into \(n\) complex numbers
\[f_k = \sum_{l=0}^{n-1} e^{- 2 \pi i k l / n} a_l.\]
An alternative way to think about \(f_k\) is as the values of the polynomial
\[P(x)=\sum_{l=0}^{n-1} a_l x^l\]
at the \(n\) points \(w^0,w^1,\ldots,w^{n-1}\), where \(w=e^{-2 \pi i / n}\) is a primitive \(n\)th root of unity. Indeed, \(w^{kl}=e^{-2 \pi i k l / n}\), so \(f_k = P(w^k)\).
The naive calculation requires \(\Theta(n^2)\) operations; our goal is to reduce that number to \(\Theta(n \log n)\).
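For reference, the quadratic baseline can be written down directly. The sketch below assumes Complex Double for concreteness, and naiveDft, evalPoly and dftAsPoly are hypothetical names introduced only for this illustration: naiveDft transcribes the sum defining \(f_k\), while dftAsPoly evaluates \(P\) at each \(w^k\) with Horner's rule. Both perform \(\Theta(n^2)\) operations and should agree up to rounding.

```haskell
import Data.Complex

-- Hypothetical reference implementation: a direct transcription of
--   f_k = sum_{l=0}^{n-1} exp(-2 pi i k l / n) * a_l
-- Each of the n outputs sums n terms, so this takes Theta(n^2) operations.
naiveDft :: [Complex Double] -> [Complex Double]
naiveDft coeffs =
  [ sum [ cis (-2 * pi * fromIntegral (k * l) / fromIntegral n) * a
        | (l, a) <- zip [0 ..] coeffs ]
  | k <- [0 .. n - 1] ]
  where
    n = length coeffs

-- Horner's rule: evaluate a_0 + a_1 x + ... + a_{n-1} x^(n-1).
evalPoly :: [Complex Double] -> Complex Double -> Complex Double
evalPoly coeffs x = foldr (\a acc -> a + x * acc) 0 coeffs

-- The same DFT viewed as evaluating P at w^0, ..., w^(n-1),
-- where w = exp(-2 pi i / n).
dftAsPoly :: [Complex Double] -> [Complex Double]
dftAsPoly coeffs = [evalPoly coeffs (w ^ k) | k <- [0 .. n - 1]]
  where
    n = length coeffs
    w = cis (-2 * pi / fromIntegral n)
```

Such a baseline is handy for spot-checking an FFT on small inputs; for example, the DFT of [1,2,3,4] is \([10, -2+2i, -2, -2-2i]\) up to rounding.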
An excellent explanation of the algorithm (which inspired this article in the first place) is given by Daniel Gusfield in his video lectures; he calls it “the most important algorithm that most computer scientists have never studied”. You only need to watch the first two lectures (and maybe the beginning of the third one) to understand the algorithm and this article.
Roots of unity
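Roots of unity could in principle be represented in Cartesian form by the Complex a type. However, floating-point rounding would make it very hard to compare them for equality (two computations of the same root may differ in the last bits), and comparing them for equality is exactly what we are going to do to achieve subquadratic complexity.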
So here’s a small module just for representing these special complex numbers in the polar form, taking advantage of the fact that their absolute values are always 1 and their phases are rational multiples of \(\pi\).
module RootOfUnity
  ( U -- abstract
  , mkU
  , toComplex
  , u_pow
  , u_sqr
  ) where
import Data.Complex
-- | U q corresponds to the complex number exp(2 i pi q)
newtype U = U Rational
  deriving (Show, Eq, Ord)
-- | Convert a U number to the equivalent complex number
toComplex :: Floating a => U -> Complex a
toComplex (U q) = mkPolar 1 (2 * pi * realToFrac q)
-- | Smart constructor for U numbers; automatically performs normalization
mkU :: Rational -> U
mkU q = U (q - realToFrac (floor q))
-- | Raise a U number to a power
u_pow :: U -> Integer -> U
u_pow (U q) p = mkU (fromIntegral p*q)
-- | Square a U number
u_sqr :: U -> U
u_sqr x = u_pow x 2
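To see what the normalization in mkU buys, the following standalone sketch mirrors mkU and u_pow on bare Rational phases (normalizePhase and phasePow are hypothetical stand-ins, not the module's API). Because the phases are exact, two different ways of computing the same root yield identical values, which is what makes Eq, Ord and nub reliable for U.

```haskell
import Data.Ratio ((%)) -- (%) builds exact rationals, e.g. 3 % 8

-- Hypothetical stand-in for mkU: reduce a phase q (in full turns) to [0, 1),
-- so that every root of unity has exactly one representation.
normalizePhase :: Rational -> Rational
normalizePhase q = q - fromInteger (floor q)

-- Hypothetical stand-in for u_pow: raising exp(2 pi i q) to the power p
-- multiplies the phase by p, followed by normalization.
phasePow :: Rational -> Integer -> Rational
phasePow q p = normalizePhase (fromInteger p * q)
```

For instance, normalizePhase ((-1) % 8) is 7 % 8, and phasePow (1 % 8) 6 and phasePow (7 % 8) 2 both come out as exactly 3 % 4.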
Fast Fourier transform
{-# LANGUAGE ScopedTypeVariables #-}
module FFT (fft) where
import Data.Complex
import Data.Ratio
import Data.Monoid
import qualified Data.Map as Map
import Data.List
import Data.Bifunctor
import Control.Monad.Trans.Writer
import RootOfUnity
So we want to evaluate the polynomial \(P(x)=\sum_{l=0}^{n-1}a_lx^l\) at points \(w^k\). The trick is to represent \(P(x)\) as \(A_e(x^2) + x A_o(x^2)\), where \(A_e(x)=a_0+a_2 x + \ldots\) and \(A_o(x)=a_1+a_3 x + \ldots\) are polynomials constructed out of the even-numbered and odd-numbered coefficients of \(P\), respectively.
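For example, with \(n = 4\):
\[P(x) = a_0 + a_1 x + a_2 x^2 + a_3 x^3,\qquad A_e(x) = a_0 + a_2 x,\qquad A_o(x) = a_1 + a_3 x,\]
\[A_e(x^2) + x A_o(x^2) = (a_0 + a_2 x^2) + x (a_1 + a_3 x^2) = a_0 + a_1 x + a_2 x^2 + a_3 x^3 = P(x).\]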
When \(x\) is a root of unity, so is \(x^2\); this allows us to apply the algorithm recursively to evaluate \(A_e\) and \(A_o\) for the squared numbers.
But the real boon comes when \(n\) is even; then there will be half as many of these squared numbers, because \(w^k\) and \(w^{k+n/2}\), when squared, both give the same number \(w^{2k}\). This is when the divide and conquer strategy really pays off.
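The collapse follows from \(w^n = 1\): \((w^{k+n/2})^2 = w^{2k} w^n = w^{2k}\). A small sketch can count the distinct points at each level of the recursion, representing each root by its phase as a fraction of a full turn in the spirit of the U type. The helpers normalize, initialPoints, squarePoints and levelSizes are hypothetical and only illustrate the counting.

```haskell
import Data.List (nub)
import Data.Ratio ((%))

-- Hypothetical helper: reduce a phase, measured in full turns, to [0, 1).
normalize :: Rational -> Rational
normalize q = q - fromInteger (floor q)

-- Phases of w^0, ..., w^(n-1) for w = exp(-2 pi i / n), i.e. -k/n mod 1.
initialPoints :: Integer -> [Rational]
initialPoints n = [normalize ((-k) % n) | k <- [0 .. n - 1]]

-- One level of the recursion: square every point (double its phase)
-- and drop duplicates, as nub $ u_sqr <$> pts does in the text.
squarePoints :: [Rational] -> [Rational]
squarePoints = nub . map (normalize . (* 2))

-- Number of distinct evaluation points at each of the first d levels.
levelSizes :: Integer -> Int -> [Int]
levelSizes n d = map length (take d (iterate squarePoints (initialPoints n)))
```

For \(n = 8\), levelSizes 8 4 gives [8,4,2,1]: the point set halves at every level. For odd \(n\) such as 5, squaring merely permutes the points (levelSizes 5 2 gives [5,5]), and for \(n = 6\) the set halves once, to the three cube roots of unity, and then stays at 3.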
We will represent a polynomial \(\sum_{l=0}^{n-1}a_lx^l\) in Haskell as a list of coefficients [a_0,a_1,...], starting with \(a_0\).
To be able to split a polynomial into the even and odd parts, let’s define a corresponding list function:
split :: [a] -> ([a], [a])
split = foldr f ([], [])
  where
    f a (r1, r2) = (a : r2, r1)
(I think I learned the idea of this elegant implementation from Dominic Steinitz.)
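The trick in split is that the accumulated pair swaps roles at every step: the new element is prepended to the old second list, which becomes the new first list, while the old first list becomes the new second. For comparison, here is a sketch of an equivalent, more explicit definition by pattern matching (splitRec is a hypothetical name, not from the text), next to the foldr version.

```haskell
-- The foldr-based split from the text.
split :: [a] -> ([a], [a])
split = foldr f ([], [])
  where
    f a (r1, r2) = (a : r2, r1)

-- Hypothetical explicit equivalent: peel off two elements at a time,
-- sending the first to the even list and the second to the odd list.
splitRec :: [a] -> ([a], [a])
splitRec (x : y : rest) = let (es, os) = splitRec rest in (x : es, y : os)
splitRec [x] = ([x], [])
splitRec [] = ([], [])
```

Both give split [0..5] == ([0,2,4],[1,3,5]), that is, the even-indexed and the odd-indexed elements.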
Now, the core of the algorithm: a function that evaluates a polynomial at a given list of points on the unit circle. It tracks the number of performed arithmetic operations through a Writer monad over the Sum monoid.
evalFourier
  :: forall a . RealFloat a
  => [Complex a] -- ^ polynomial coefficients, starting from a_0
  -> [U] -- ^ points at which to evaluate the polynomial
  -> Writer (Sum Int) [Complex a]
If the polynomial is a constant, there’s not much to calculate. This is our base case.
evalFourier [] pts = return $ 0 <$ pts
evalFourier [c] pts = return $ c <$ pts
Otherwise, use the recursive algorithm outlined above.
evalFourier coeffs pts = do
  let
    squares = nub $ u_sqr <$> pts -- values of x^2
    (even_coeffs, odd_coeffs) = split coeffs
  even_values <- evalFourier even_coeffs squares
  odd_values <- evalFourier odd_coeffs squares
  let
    -- a mapping from x^2 to (A_e(x^2), A_o(x^2))
    square_map =
      Map.fromList
      . zip squares
      $ zip even_values odd_values
    -- evaluate the polynomial at a single point
    eval1 :: U -> Writer (Sum Int) (Complex a)
    eval1 x = do
      let (ye,yo) = (square_map Map.! u_sqr x)
          r = ye + toComplex x * yo
      tell $ Sum 2 -- this took two arithmetic operations
      return r
  mapM eval1 pts
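The operation counting in eval1 follows a general pattern: do the arithmetic, then tell the Writer how many operations it took. As a standalone illustration of the same pattern (hornerCounted is a hypothetical function, not part of the FFT module), here is Horner's rule for evaluating a polynomial at one point, counting one multiplication and one addition per step.

```haskell
import Control.Monad.Trans.Writer (Writer, runWriter, tell)
import Data.Monoid (Sum (..))

-- Hypothetical example: Horner evaluation of a_0 + a_1 x + ... that reports
-- the number of arithmetic operations through a Writer over Sum Int.
hornerCounted :: Num a => [a] -> a -> Writer (Sum Int) a
hornerCounted [] _ = return 0
hornerCounted [c] _ = return c
hornerCounted (c : cs) x = do
  r <- hornerCounted cs x
  tell (Sum 2) -- one multiplication and one addition
  return (c + x * r)
```

For instance, runWriter (hornerCounted [1,2,3,4] 2) gives (49, Sum 6): the value \(1 + 2\cdot 2 + 3\cdot 4 + 4\cdot 8 = 49\), obtained in three multiply-add steps.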
The actual fft function is a thin wrapper around evalFourier: it supplies the specific points \(w^0,w^1,\ldots,w^{n-1}\) and unpacks the Writer result. It returns the result of the DFT together with the number of arithmetic operations performed.
fft :: RealFloat a => [Complex a] -> ([Complex a], Int)
fft coeffs =
  second getSum
  . runWriter
  . evalFourier coeffs
  . map (u_pow w)
  $ [0 .. n-1]
  where
    n = genericLength coeffs
    w = mkU (-1 % n)
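As a worked example, for \(n = 4\) the wrapper computes w = mkU (-1 % 4), which normalizes to U (3 % 4), so the points passed to evalFourier are
\[(w^0, w^1, w^2, w^3) = (1, -i, -1, i),\]
with phases \(0, 3/4, 1/2, 1/4\) of a full turn. The first level of recursion evaluates \(A_e\) and \(A_o\) at the distinct squares \(\{1, -1\}\), and the level below that at \(\{1\}\), where the coefficient lists have length 1 and the base case applies.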