Simple FFT in Haskell

This article develops a simple implementation of the fast Fourier transform in Haskell.

Raw performance of the algorithm is explicitly not a goal here; for instance, I use things like nub, Writer, and lists for simplicity. On the other hand, I do pay attention to the algorithmic complexity in terms of the number of arithmetic operations performed; the analysis thereof will be done in a subsequent article.

Background

The discrete Fourier transform turns \(n\) complex numbers \(a_0,a_1,\ldots,a_{n-1}\) into \(n\) complex numbers

\[f_k = \sum_{l=0}^{n-1} e^{- 2 \pi i k l / n} a_l.\]

An alternative way to think about \(f_k\) is as the values of the polynomial

\[P(x)=\sum_{l=0}^{n-1} a_l x^l\]

at the \(n\) points \(w^0,w^1,\ldots,w^{n-1}\), where \(w=e^{-2 \pi i / n}\) is a primitive \(n\)th root of unity.

The naive calculation requires \(\Theta(n^2)\) operations; our goal is to reduce that number to \(\Theta(n \log n)\).
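
For reference, the definition above can be transcribed directly into a quadratic-time function. This is only a sketch for comparison and testing, not part of the implementation developed below; the names naiveDFT and twiddle are made up for this illustration.

```haskell
import Data.Complex

-- | Direct evaluation of the DFT definition,
--   f_k = sum over l of exp(-2 pi i k l / n) * a_l.
--   Every output needs n multiplications, hence Theta(n^2) overall.
--   (naiveDFT is a hypothetical helper, not part of the article's modules.)
naiveDFT :: RealFloat a => [Complex a] -> [Complex a]
naiveDFT as =
  [ sum [ a * twiddle k l | (l, a) <- zip [0 ..] as ]
  | k <- [0 .. n - 1]
  ]
  where
    n = length as
    -- exp(-2 pi i k l / n), built in polar form with absolute value 1
    twiddle k l = mkPolar 1 (-2 * pi * fromIntegral (k * l) / fromIntegral n)
```

Up to rounding, naiveDFT [1,1,1,1] is [4,0,0,0], since a constant signal has only a zero-frequency component.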

An excellent explanation of the algorithm (which inspired this article in the first place) is given by Daniel Gusfield in his video lectures; he calls it “the most important algorithm that most computer scientists have never studied”. You only need to watch the first two lectures (and maybe the beginning of the third one) to understand the algorithm and this article.

Roots of unity

Roots of unity could in principle be represented in the Cartesian form by the Complex a type. However, floating-point rounding would make it very hard to compare them reliably for equality, which we are going to do to achieve subquadratic complexity.

So here’s a small module just for representing these special complex numbers in the polar form, taking advantage of the fact that their absolute values are always 1 and their phases are rational multiples of \(\pi\).

module RootOfUnity
  ( U -- abstract
  , mkU
  , toComplex
  , u_pow
  , u_sqr
  ) where

import Data.Complex

-- | U q corresponds to the complex number exp(2 i pi q)
newtype U = U Rational
  deriving (Show, Eq, Ord)

-- | Convert a U number to the equivalent complex number
toComplex :: Floating a => U -> Complex a
toComplex (U q) = mkPolar 1 (2 * pi * realToFrac q)

-- | Smart constructor for U numbers; automatically performs normalization
mkU :: Rational -> U
mkU q = U (q - fromIntegral (floor q :: Integer))

-- | Raise a U number to a power
u_pow :: U -> Integer -> U
u_pow (U q) p = mkU (fromIntegral p*q)

-- | Square a U number
u_sqr :: U -> U
u_sqr x = u_pow x 2
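
To see what the exact representation buys us, here is a small self-contained sketch. It repeats the definitions from the module above so that it can be loaded on its own; rootsOfUnity and distinctSquares are hypothetical helpers introduced only for this illustration.

```haskell
import Data.List (nub)
import Data.Ratio ((%))

-- Local copy of the U type and its operations, mirroring the module above
newtype U = U Rational
  deriving (Show, Eq, Ord)

mkU :: Rational -> U
mkU q = U (q - fromIntegral (floor q :: Integer))

u_pow :: U -> Integer -> U
u_pow (U q) p = mkU (fromIntegral p * q)

u_sqr :: U -> U
u_sqr x = u_pow x 2

-- | The points w^0 .. w^(n-1) for w = exp(-2 pi i / n)
--   (hypothetical helper, assumes n > 0)
rootsOfUnity :: Integer -> [U]
rootsOfUnity n = map (u_pow (mkU (-1 % n))) [0 .. n - 1]

-- | The distinct values among the squares of those points
--   (hypothetical helper); the Eq instance on Rational is exact,
--   so nub removes every duplicate
distinctSquares :: Integer -> [U]
distinctSquares = nub . map u_sqr . rootsOfUnity
```

Because phases are normalized into the interval from 0 (inclusive) to 1 (exclusive), mkU (9 % 8) and mkU (1 % 8) are the same value, and length (distinctSquares 8) is 4, while length (distinctSquares 7) is 7.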

Fast Fourier transform

{-# LANGUAGE ScopedTypeVariables #-}
module FFT (fft) where

import Data.Complex
import Data.Ratio
import Data.Monoid
import qualified Data.Map as Map
import Data.List
import Data.Bifunctor
import Control.Monad.Trans.Writer
import RootOfUnity

So we want to evaluate the polynomial \(P(x)=\sum_{l=0}^{n-1}a_lx^l\) at points \(w^k\). The trick is to represent \(P(x)\) as \(A_e(x^2) + x A_o(x^2)\), where \(A_e(x)=a_0+a_2 x + \ldots\) and \(A_o(x)=a_1+a_3 x + \ldots\) are polynomials constructed out of the even-numbered and odd-numbered coefficients of \(P\), respectively.

When \(x\) is a root of unity, so is \(x^2\); this allows us to apply the algorithm recursively to evaluate \(A_e\) and \(A_o\) for the squared numbers.

But the real boon comes when \(n\) is even; then there will be half as many of these squared numbers, because \(w^k\) and \(w^{k+n/2}\), when squared, both give the same number \(w^{2k}\). This is when the divide and conquer strategy really pays off.
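
As a worked example, take \(n=4\), so that \(w=e^{-2\pi i/4}=-i\) and the evaluation points are \(1,-i,-1,i\). The polynomial splits as

\[P(x)=a_0+a_1x+a_2x^2+a_3x^3=\underbrace{(a_0+a_2x^2)}_{A_e(x^2)}+x\,\underbrace{(a_1+a_3x^2)}_{A_o(x^2)}.\]

The four points have only two distinct squares, \(1\) and \(-1\), so \(A_e\) and \(A_o\) each need to be evaluated at two points rather than four:

\[\begin{aligned}
f_0 &= A_e(1)+A_o(1), & f_2 &= A_e(1)-A_o(1),\\
f_1 &= A_e(-1)-i\,A_o(-1), & f_3 &= A_e(-1)+i\,A_o(-1).
\end{aligned}\]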

We will represent a polynomial \(\sum_{l=0}^{n-1}a_lx^l\) in Haskell as a list of coefficients [a_0,a_1,...], starting with \(a_0\).

To be able to split a polynomial into the even and odd parts, let’s define a corresponding list function:

split :: [a] -> ([a], [a])
split = foldr f ([], [])
  where
    f a (r1, r2) = (a : r2, r1)

(I think I learned the idea of this elegant implementation from Dominic Steinitz.)
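
The swap of r1 and r2 in the fold is easy to miss, so here is the same function next to an explicitly recursive variant that spells it out. The name splitExplicit is made up for this sketch and is not used in the rest of the article.

```haskell
-- | The fold-based version from above, repeated so this block loads on its own
split :: [a] -> ([a], [a])
split = foldr f ([], [])
  where
    f a (r1, r2) = (a : r2, r1)

-- | Hypothetical explicit-recursion equivalent.
--   Relative to the whole list, the tail's even positions are odd
--   positions and vice versa, which is why the two halves trade places.
splitExplicit :: [a] -> ([a], [a])
splitExplicit []       = ([], [])
splitExplicit (x : xs) = (x : tailOdds, tailEvens)
  where
    (tailEvens, tailOdds) = splitExplicit xs
```

For instance, split [0,1,2,3,4] gives ([0,2,4],[1,3]): the first component holds the coefficients of \(A_e\) and the second those of \(A_o\).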

Now, the core of the algorithm: a function that evaluates a polynomial at a given list of points on the unit circle. It tracks the number of performed arithmetic operations through a Writer monad over the Sum monoid.

evalFourier
  :: forall a . RealFloat a
  => [Complex a] -- ^ polynomial coefficients, starting from a_0
  -> [U] -- ^ points at which to evaluate the polynomial
  -> Writer (Sum Int) [Complex a]

If the polynomial is a constant, there’s not much to calculate. This is our base case.

evalFourier []  pts = return $ 0 <$ pts
evalFourier [c] pts = return $ c <$ pts

Otherwise, use the recursive algorithm outlined above.

evalFourier coeffs pts = do
  let
    squares = nub $ u_sqr <$> pts -- values of x^2
    (even_coeffs, odd_coeffs) = split coeffs
  even_values <- evalFourier even_coeffs squares
  odd_values <- evalFourier odd_coeffs squares

  let
    -- a mapping from x^2 to (A_e(x^2), A_o(x^2))
    square_map =
      Map.fromList
      . zip squares
      $ zip even_values odd_values

    -- evaluate the polynomial at a single point
    eval1 :: U -> Writer (Sum Int) (Complex a)
    eval1 x = do
      let (ye,yo) = square_map Map.! u_sqr x
          r = ye + toComplex x * yo
      tell $ Sum 2 -- this took two arithmetic operations
      return r

  mapM eval1 pts

The actual FFT function is a simple wrapper around evalFourier which substitutes the specific points and performs some simple conversions. It returns the result of the DFT and the number of operations performed.

fft :: RealFloat a => [Complex a] -> ([Complex a], Int)
fft coeffs =
  second getSum
  . runWriter 
  . evalFourier coeffs 
  . map (u_pow w)
  $ [0..n-1]
  where
    n = genericLength coeffs -- number of coefficients, as an Integer
    w = mkU (-1 % n) -- the primitive root exp(-2 pi i / n)
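
As a small sanity check of the counter, here is a hand trace for \(n=4\). The symbol \(T(n)\) for the reported count is introduced only for this illustration, and the recurrence assumes that \(n\) is a power of two. At every level of the recursion the number of points equals the number of coefficients, each point costs two operations, and the constant base case costs none, so

\[T(1)=0, \qquad T(n)=2\,T(n/2)+2n.\]

For four coefficients this gives

\[T(2)=2\cdot 0+2\cdot 2=4, \qquad T(4)=2\cdot 4+2\cdot 4=16,\]

so snd (fft [1,2,3,4]) should come out as 16, which agrees with \(2n\log_2 n\) at \(n=4\).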