Generic unification


The unification-fd package by wren gayle romano is the de-facto standard way to do unification in Haskell. You’d use it if you need to implement type inference for your DSL, for example.

To use unification-fd, we first need to express our Type type as the fixpoint of a functor (in categorical terms, the carrier of that functor's initial algebra).

For instance, let’s say we want to implement type inference for the simply typed lambda calculus (STLC). The types in STLC can be represented by the Haskell type:

data Type = BaseType String
          | Fun Type Type

Note that Type cannot represent type variables.

Type can be equivalently represented as a fixpoint of a functor, TypeF:

-- defined in Data.Functor.Fixedpoint in unification-fd
newtype Fix f = Fix { unFix :: f (Fix f) }

data TypeF a = BaseType String
             | Fun a a
  deriving (Functor, Foldable, Traversable)

type Type = Fix TypeF
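To make "equivalently represented" concrete, here is a minimal sketch of the isomorphism between the two forms. It assumes local copies of Fix and TypeF, so it compiles without unification-fd. The original Type is renamed to a hypothetical PlainType so that both versions can live in one module, and cata is a local helper, not something taken from the library.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copy of Fix (unification-fd defines its own in Data.Functor.Fixedpoint).
newtype Fix f = Fix { unFix :: f (Fix f) }

data TypeF a = BaseType String
             | Fun a a
  deriving (Functor)

-- Hypothetical stand-in for the direct (non-functor) representation.
data PlainType = PBaseType String
               | PFun PlainType PlainType
  deriving (Eq, Show)

-- Fold a fixpoint bottom-up (a catamorphism).
cata :: Functor f => (f r -> r) -> Fix f -> r
cata alg = alg . fmap (cata alg) . unFix

-- Fix TypeF -> PlainType: replace each layer with the matching constructor.
fromFix :: Fix TypeF -> PlainType
fromFix = cata alg
  where
    alg (BaseType s) = PBaseType s
    alg (Fun a b)    = PFun a b

-- PlainType -> Fix TypeF: wrap every layer in Fix.
toFix :: PlainType -> Fix TypeF
toFix (PBaseType s) = Fix (BaseType s)
toFix (PFun a b)    = Fix (Fun (toFix a) (toFix b))
```

The two functions are mutually inverse, so fromFix (toFix t) == t for every t.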

So Fix TypeF still cannot represent any type variables, but UTerm TypeF can. UTerm is another type defined in unification-fd that is similar to Fix except it includes another constructor for type variables:

-- defined in Control.Unification in unification-fd
data UTerm t v
    = UVar  !v               -- ^ A unification variable.
    | UTerm !(t (UTerm t v)) -- ^ Some structure containing subterms.

type PolyType = UTerm TypeF IntVar

UTerm, by the way, is the free monad over the functor t.
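In a free monad, >>= replaces every variable with a term, which is exactly substitution. Here is a minimal sketch of that idea. It uses a simplified local copy of UTerm (no strictness annotations) instead of the library's type, and the names substA and example are hypothetical.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data TypeF a = BaseType String
             | Fun a a
  deriving (Functor)

-- Simplified local copy of UTerm.
data UTerm t v = UVar v
               | UTerm (t (UTerm t v))

instance Functor t => Functor (UTerm t) where
  fmap f (UVar v)  = UVar (f v)
  fmap f (UTerm t) = UTerm (fmap (fmap f) t)

instance Functor t => Applicative (UTerm t) where
  pure = UVar
  mf <*> mx = mf >>= \f -> fmap f mx

-- Bind substitutes a term for each variable and keeps the structure intact.
instance Functor t => Monad (UTerm t) where
  UVar v  >>= k = k v
  UTerm t >>= k = UTerm (fmap (>>= k) t)

-- Hypothetical substitution: a := Int, every other variable stays as is.
substA :: String -> UTerm TypeF String
substA "a" = UTerm (BaseType "Int")
substA v   = UVar v

-- The type a -> b.
example :: UTerm TypeF String
example = UTerm (Fun (UVar "a") (UVar "b"))
```

Here example >>= substA evaluates to UTerm (Fun (UTerm (BaseType "Int")) (UVar "b")), that is, Int -> b.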

Unifiable

The Control.Unification module exposes several algorithms (unification, alpha equivalence) that work on any UTerm, provided that the underlying functor t (TypeF in our example) is an instance of the Unifiable class, i.e. implements its zipMatch method:

class (Traversable t) => Unifiable t where
    -- | Perform one level of equality testing for terms. If the
    -- term constructors are unequal then return @Nothing@; if they
    -- are equal, then return the one-level spine filled with
    -- resolved subterms and\/or pairs of subterms to be recursively
    -- checked.
    zipMatch :: t a -> t a -> Maybe (t (Either a (a,a)))

zipMatch essentially tells the algorithms which constructors of our TypeF functor are the same, which are different, and which fields are subterms that still have to be unified recursively (Right) as opposed to fields that are already resolved (Left). So for TypeF it could look like

instance Unifiable TypeF where
  zipMatch (BaseType a) (BaseType b) =
    if a == b
      then Just $ BaseType a
      else Nothing
  zipMatch (Fun a1 a2) (Fun b1 b2) =
    Just $ Fun (Right (a1, b1)) (Right (a2, b2))
  zipMatch _ _ = Nothing
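Whichever style zipMatch is written in, its result gets consumed the same way: Left fields are already settled, and each Right pair is a new unification problem. Here is a sketch of that consumer side. It assumes a local copy of the class and a hypothetical helper, subproblems. The real unifier in Control.Unification does more than this, for example binding variables.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

-- Local copy of the class from Control.Unification.
class Traversable t => Unifiable t where
  zipMatch :: t a -> t a -> Maybe (t (Either a (a, a)))

data TypeF a = BaseType String
             | Fun a a
  deriving (Functor, Foldable, Traversable)

instance Unifiable TypeF where
  zipMatch (BaseType a) (BaseType b) =
    if a == b
      then Just $ BaseType a
      else Nothing
  zipMatch (Fun a1 a2) (Fun b1 b2) =
    Just $ Fun (Right (a1, b1)) (Right (a2, b2))
  zipMatch _ _ = Nothing

-- Hypothetical helper: the subterm pairs a unifier would recurse on,
-- or Nothing if the outermost constructors clash.
subproblems :: Unifiable t => t a -> t a -> Maybe [(a, a)]
subproblems x y = foldr collect [] <$> zipMatch x y
  where
    collect (Right pair) acc = pair : acc  -- still to be unified
    collect (Left _)     acc = acc         -- already resolved
```

For example, subproblems (Fun 1 2) (Fun 3 4) gives Just [(1,3),(2,4)], and two different base types give Nothing.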

Now, I prefer the following style instead:

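instance Unifiable TypeF where
  -- guard comes from Control.Monad
  zipMatch a b =
    case a of
      BaseType a' -> do
        -- if b uses a different constructor, this pattern match fails
        -- and the whole Maybe computation evaluates to Nothing
        BaseType b' <- return b
        guard $ a' == b'
        return $ BaseType a'
      Fun a1 a2 -> do
        Fun b1 b2 <- return b
        return $ Fun (Right (a1, b1)) (Right (a2, b2))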

Why? First, I really don’t like multi-clause definitions. But the main reason is that the second definition behaves more reliably when we add new constructors to TypeF. Namely, if we enable GHC warnings (-Wall) and extend TypeF to include tuples:

data TypeF a = BaseType String
             | Fun a a
             | Tuple a a
  deriving (Functor, Foldable, Traversable)

… we’ll get a warning telling us not to forget to implement zipMatch for tuples:

warning: [-Wincomplete-patterns]
    Pattern match(es) are non-exhaustive
    In a case alternative: Patterns not matched: (Tuple _ _)

If we went with the first version, however, we would get no warning, because it contains a catch-all clause:

  zipMatch _ _ = Nothing

As a result, it is easy to forget to update zipMatch, in which case two tuple types will never unify.
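The failure mode is easy to reproduce. Here is a minimal sketch in which the multi-clause version is written as a standalone function with the hypothetical name zipMatchOld, so that it compiles without unification-fd.

```haskell
data TypeF a = BaseType String
             | Fun a a
             | Tuple a a

-- The multi-clause zipMatch, not yet updated for Tuple.
zipMatchOld :: TypeF a -> TypeF a -> Maybe (TypeF (Either a (a, a)))
zipMatchOld (BaseType a) (BaseType b) =
  if a == b then Just (BaseType a) else Nothing
zipMatchOld (Fun a1 a2) (Fun b1 b2) =
  Just (Fun (Right (a1, b1)) (Right (a2, b2)))
zipMatchOld _ _ = Nothing  -- silently swallows Tuple vs Tuple

-- True: two structurally identical tuple types are reported as a clash.
tuplesClash :: Bool
tuplesClash = case zipMatchOld (Tuple 'x' 'y') (Tuple 'x' 'y') of
  Nothing -> True
  Just _  -> False
```

The module compiles cleanly even with -Wall, which is exactly the problem.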

This is a common mistake people make when implementing binary operations in Haskell, so I just wanted to point it out. But other than that, both definitions are verbose and boilerplate-heavy.

And it goes without saying that in real-life situations, the types we want to unify tend to be bigger, and the boilerplate becomes even more tedious.

For instance, I’ve been working recently on implementing type inference for the nstack DSL, which includes tuples, records, sum types, optionals, arrays, the void type, and many primitive types. Naturally, I wasn’t eager to write zipMatch by hand.

Generic Unifiable

Generic programming is a set of techniques to avoid writing boilerplate such as our implementation of zipMatch above.

Over the years, Haskell has acquired a lot of different generic programming libraries.

For most of my generic programming needs, I pick uniplate. Uniplate is very simple to use and reasonably efficient. Occasionally I have a problem that requires something more sophisticated, like generics-sop to parse YAML or traverse-with-class to resolve names in a Haskell AST.

But none of these libraries can help us to implement a generic zipMatch.

Consider the following type:

data TypeF a = Foo a a
             | Bar Int String

A proper zipMatch implementation works very differently for Foo and Bar: Foo has two subterms to unify whereas Bar has none.

But most generics libraries don’t see this difference between Foo and Bar. They don’t distinguish between polymorphic fields (whose type is the type parameter) and non-polymorphic ones (whose type is fixed). Instead, they treat all fields as non-polymorphic. From their point of view, TypeF Bool is exactly equivalent to

data TypeF = Foo Bool Bool
           | Bar Int String

Luckily, there is a generic programming library that lets us “see” type parameters. Well, just one type parameter, but that’s exactly enough for zipMatch. In other words, this library provides a generic representation for type constructors of kind * -> *, whereas most other libraries only concern themselves with ordinary types of kind *.

What is that library called? base.

Seriously: the base library includes a module, GHC.Generics, and starting from GHC 7.6 (released in 2012) that module also supports types of kind * -> * through the Generic1 class. The module consists of:

  1. Several types (constants K1, parameters Par1, sums :+:, products :*:, compositions :.:) out of which we can build different algebraic types of kind * -> *.
  2. A class for representable algebraic data types, Generic1.

    Its associated type Rep1 maps an algebraic data type like TypeF to an isomorphic type composed out of primitives like K1 and :*:. The methods from1 and to1 convert between the two.

The compiler itself knows how to derive the Generic1 instance for eligible types. Here is what it looks like:

{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# LANGUAGE DeriveGeneric #-}

import GHC.Generics (Generic1)

data TypeF a = BaseType String
             | Fun a a
             | Tuple a a
  deriving (Functor, Foldable, Traversable, Generic1)
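Unlike the representations used by the libraries mentioned above, Rep1 keeps the distinction between parameter fields and fixed-type fields. A field of type a becomes Par1, while a field of a fixed type such as String becomes K1. To illustrate, here is a small generic function over Rep1 that counts the fields holding the type parameter. GParams, gparams and params are hypothetical names chosen for this sketch.

```haskell
{-# LANGUAGE DeriveGeneric, TypeOperators, FlexibleContexts #-}

import GHC.Generics

data TypeF a = BaseType String
             | Fun a a
             | Tuple a a
  deriving (Generic1)

-- Count the fields of a Rep1 value that hold the type parameter.
class GParams f where
  gparams :: f p -> Int

instance GParams U1 where
  gparams _ = 0                    -- constructor without fields
instance GParams Par1 where
  gparams _ = 1                    -- a field of type `a`
instance GParams (K1 i c) where
  gparams _ = 0                    -- a field of a fixed type, e.g. String
instance GParams f => GParams (M1 i m f) where
  gparams (M1 x) = gparams x       -- metadata wrapper, just recurse
instance (GParams f, GParams g) => GParams (f :+: g) where
  gparams (L1 x) = gparams x
  gparams (R1 x) = gparams x
instance (GParams f, GParams g) => GParams (f :*: g) where
  gparams (x :*: y) = gparams x + gparams y

params :: (Generic1 t, GParams (Rep1 t)) => t a -> Int
params = gparams . from1
```

Here params (Fun 'a' 'b') is 2, while params (BaseType "Int") is 0, because the String field is a K1 rather than a Par1.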

So, in order to have a generic Unifiable instance, all I had to do was:

  1. Implement Unifiable for the primitive types in GHC.Generics.
  2. Add a default zipMatch implementation to the Unifiable class.

You can see the details in the pull request.
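The two steps can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the code from the pull request. It uses a local copy of the class. It also assumes a base recent enough to provide Traversable instances for the GHC.Generics types (as far as I know, base 4.9 / GHC 8.0 or later). Fields of the form f a (Rec1) and compositions (:.:) are omitted for brevity.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
{-# LANGUAGE DeriveGeneric, DeriveAnyClass, DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts, TypeOperators #-}

import GHC.Generics

-- Local copy of Unifiable with a Generic1-based default (step 2):
-- convert both sides to Rep1, match there, convert the result back.
class Traversable t => Unifiable t where
  zipMatch :: t a -> t a -> Maybe (t (Either a (a, a)))
  default zipMatch
    :: (Generic1 t, Unifiable (Rep1 t))
    => t a -> t a -> Maybe (t (Either a (a, a)))
  zipMatch a b = to1 <$> zipMatch (from1 a) (from1 b)

-- Step 1: instances for the GHC.Generics building blocks.
instance Unifiable U1 where
  zipMatch U1 U1 = Just U1

-- A parameter field: a pair of subterms to unify recursively.
instance Unifiable Par1 where
  zipMatch (Par1 a) (Par1 b) = Just (Par1 (Right (a, b)))

-- A fixed-type field (e.g. String): the two values must be equal.
instance Eq c => Unifiable (K1 i c) where
  zipMatch (K1 a) (K1 b)
    | a == b    = Just (K1 a)
    | otherwise = Nothing

-- Datatype/constructor/selector metadata is passed through.
instance Unifiable f => Unifiable (M1 i m f) where
  zipMatch (M1 a) (M1 b) = M1 <$> zipMatch a b

-- Sums: both sides must use the same constructor. This catch-all is
-- harmless, since it is written once and covers every sum generically.
instance (Unifiable f, Unifiable g) => Unifiable (f :+: g) where
  zipMatch (L1 a) (L1 b) = L1 <$> zipMatch a b
  zipMatch (R1 a) (R1 b) = R1 <$> zipMatch a b
  zipMatch _      _      = Nothing

-- Products: every field must match.
instance (Unifiable f, Unifiable g) => Unifiable (f :*: g) where
  zipMatch (a1 :*: a2) (b1 :*: b2) =
    (:*:) <$> zipMatch a1 b1 <*> zipMatch a2 b2

data TypeF a = BaseType String
             | Fun a a
             | Tuple a a
  deriving (Show, Functor, Foldable, Traversable, Generic1, Unifiable)
```

With these instances, TypeF needs no hand-written zipMatch at all, and a constructor added to TypeF is handled automatically.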

Complete example

Here is a complete example that unifies a -> (c, d) with c -> (a, b -> a).

{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE DeriveFunctor, DeriveFoldable,
             DeriveTraversable, DeriveGeneric,
             DeriveAnyClass, FlexibleContexts
  #-}

import Data.Functor.Identity
import Control.Unification
import Control.Unification.IntVar
import Control.Unification.Types
import Control.Monad.Trans.Except
import Control.Monad.Trans.Class
import qualified Data.Map as Map
import GHC.Generics

data TypeF a = BaseType String
             | Fun a a
             | Tuple a a
  deriving (Functor, Foldable, Traversable, Show, Generic1, Unifiable)
  --                                              ^^^^^^^^^^^^^^^^^^^
  --                                           the magic happens here

unified :: IntBindingT TypeF Identity
  (Either
    (UFailure TypeF IntVar)
    (UTerm TypeF String))
unified = runExceptT $ do
  a_var <- lift freeVar
  b_var <- lift freeVar
  c_var <- lift freeVar
  d_var <- lift freeVar
  let
    a = UVar a_var
    b = UVar b_var
    c = UVar c_var
    d = UVar d_var
    
    term1 = UTerm (Fun a (UTerm $ Tuple c d))
    term2 = UTerm (Fun c (UTerm $ Tuple a (UTerm $ Fun b a)))

  result <- applyBindings =<< unify term1 term2

  -- replace integer variable identifiers with variable names
  let
    all_vars = Map.fromList
      [(getVarID a_var, "a")
      ,(getVarID b_var, "b")
      ,(getVarID c_var, "c")
      ,(getVarID d_var, "d")
      ]

  return $ fmap ((all_vars Map.!) . getVarID) result

main :: IO ()
main = print . runIdentity $ evalIntBindingT unified

Output:

Right (Fun "c" (Tuple "c" (Fun "b" "c")))