Why would you use ContT?

Published on

The ContT transformer is one of the more exotic and less used transformers provided by mtl. The Control.Monad.Cont module now even includes the following warning:

Before using the Continuation monad, be sure that you have a firm understanding of continuation-passing style and that continuations represent the best solution to your particular design problem. Many algorithms which require continuations in other languages do not require them in Haskell, due to Haskell’s lazy semantics. Abuse of the Continuation monad can produce code that is impossible to understand and maintain.

So what is ContT, and when does it represent the best solution to a problem?

Consider the following three functions from the base library:

openFile :: FilePath -> IOMode -> IO Handle
takeMVar :: MVar a             -> IO a
newArray :: Storable a => [a]  -> IO (Ptr a)

They are all “unsafe” in the sense that they return something to you but don’t clean up after you: don’t close the file you’ve just opened, put back the MVar you’ve just taken, or free the array’s memory. If you forget to perform the cleanup, it’s on you.

Of course, these functions can't clean up on their own: a function that closed the file right after opening it would be useless. A safer function therefore has to know what you intend to do with the file handle, so that it can run that action between opening and closing the file.
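As a minimal sketch of that idea (not how base actually implements withMVar), such a safe wrapper can be built with bracket from Control.Exception. bracket runs a release action after the caller's action finishes, even if that action throws. The name withTakenMVar is hypothetical; in real code you would simply use base's withMVar.

```haskell
import Control.Concurrent.MVar (MVar, newMVar, putMVar, readMVar, takeMVar)
import Control.Exception (bracket)

-- Hypothetical safe wrapper: take the value, pass it to the caller's action,
-- and always put it back afterwards, even if the action throws an exception.
withTakenMVar :: MVar a -> (a -> IO r) -> IO r
withTakenMVar v = bracket (takeMVar v) (putMVar v)

main :: IO ()
main = do
  v <- newMVar (41 :: Int)
  r <- withTakenMVar v (\x -> return (x + 1))
  print r               -- 42, the action's result
  readMVar v >>= print  -- 41, the original value is back in the MVar
```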

The base library provides the following safer versions of the above functions:

                                 {-~~~~~~~~~~~~~~~~~~~~~~~~-}
withFile :: FilePath -> IOMode  -> (Handle -> IO r) -> IO r
withMVar :: MVar a              -> (a      -> IO r) -> IO r
withArray :: Storable a => [a]  -> (Ptr a  -> IO r) -> IO r
                                 {-~~~~~~~~~~~~~~~~~~~~~~~~-}

Notice how these functions follow the same pattern and how they relate to their unsafe versions: if the unsafe function returned IO a, then the corresponding safe function takes an additional argument of the form (a -> IO r) and returns IO r. This style of writing functions is called the continuation-passing style (CPS), and the argument a -> IO r is called the continuation.
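To make the pattern concrete outside of IO, here is a small sketch with hypothetical pure functions: square returns its result directly, while squareCPS hands it to a continuation instead. Chaining CPS calls produces exactly the kind of nesting that withFile and friends lead to.

```haskell
-- Direct style: the function returns its result.
square :: Int -> Int
square x = x * x

-- CPS: the function passes its result to a continuation k instead.
squareCPS :: Int -> (Int -> r) -> r
squareCPS x k = k (x * x)

-- Sequencing CPS calls means nesting continuations, the same shape as
-- nesting withFile/withMVar calls.
sumOfSquaresCPS :: Int -> Int -> (Int -> r) -> r
sumOfSquaresCPS a b k =
  squareCPS a $ \a2 ->
  squareCPS b $ \b2 ->
  k (a2 + b2)

main :: IO ()
main = do
  print (square 3 + square 4)     -- 25
  print (sumOfSquaresCPS 3 4 id)  -- 25, with id as the final continuation
  sumOfSquaresCPS 3 4 print       -- the continuation can also be an IO action
```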

ContT gives this pattern its own type:

newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }

(In this article, the underlying monad m will always be IO.)
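To see why ContT removes the nesting, it helps to look at its Monad instance. The sketch below is a from-scratch copy renamed to the hypothetical MyContT, so that it doesn't clash with the real ContT from transformers (which mtl re-exports). Its >>= turns the rest of a do-block into the continuation of the current step.

```haskell
import Control.Concurrent.MVar (MVar, newMVar, withMVar)

-- A from-scratch copy of ContT, renamed to avoid clashing with the real one.
newtype MyContT r m a = MyContT { runMyContT :: (a -> m r) -> m r }

instance Functor (MyContT r m) where
  -- Transform the result before handing it to the continuation.
  fmap f (MyContT c) = MyContT $ \k -> c (k . f)

instance Applicative (MyContT r m) where
  pure x = MyContT ($ x)
  MyContT cf <*> MyContT cx = MyContT $ \k -> cf (\f -> cx (k . f))

instance Monad (MyContT r m) where
  -- The rest of the do-block (f x) becomes the body of the continuation,
  -- which is exactly the nesting we would otherwise write by hand.
  MyContT c >>= f = MyContT $ \k -> c (\x -> runMyContT (f x) k)

-- Both MVars stay taken until the final continuation given to runMyContT has run.
addBoth :: MVar Int -> MVar Int -> MyContT r IO Int
addBoth v w = do
  x <- MyContT (withMVar v)
  y <- MyContT (withMVar w)
  pure (x + y)

main :: IO ()
main = do
  v <- newMVar 20
  w <- newMVar 22
  result <- runMyContT (addBoth v w) return
  print result
```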

You can construct ContT functions from the functions written in the CPS:

(\v   -> ContT (withMVar v))   :: MVar a             -> ContT r IO a
(\f m -> ContT (withFile f m)) :: FilePath -> IOMode -> ContT r IO Handle
(\l   -> ContT (withArray l))  :: Storable a => [a]  -> ContT r IO (Ptr a)

Notice how the types of the safe functions, once wrapped in ContT, become even more similar to the types of their unsafe versions: the only difference is ContT r IO instead of IO in the return type. And because ContT turns out to be a monad transformer, we can use and combine the safe functions in all the same ways as we could the unsafe ones, getting the added safety essentially for free. Once you no longer need the resources you've allocated inside ContT, you turn the computation back into IO using runContT, passing it a final continuation such as return. By the time runContT finishes, all the cleanup actions have run, in reverse order of acquisition, of course.
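The following sketch demonstrates that order. withLogged is a hypothetical resource, built with bracket_, that records when it is acquired and released. ContT is imported from transformers' Control.Monad.Trans.Cont, which is where mtl's ContT comes from.

```haskell
import Control.Exception (bracket_)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Cont (ContT (..))
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)

-- Hypothetical resource: records its acquisition and release in a log.
withLogged :: IORef [String] -> String -> (String -> IO r) -> IO r
withLogged logRef name k =
  bracket_ (record ("acquire " ++ name)) (record ("release " ++ name)) (k name)
  where
    record msg = modifyIORef logRef (++ [msg])

-- Three resources acquired in sequence, with no visible nesting.
useThree :: IORef [String] -> ContT r IO ()
useThree logRef = do
  a <- ContT (withLogged logRef "a")
  b <- ContT (withLogged logRef "b")
  c <- ContT (withLogged logRef "c")
  liftIO (modifyIORef logRef (++ ["use " ++ a ++ b ++ c]))

runDemo :: IO [String]
runDemo = do
  logRef <- newIORef []
  runContT (useThree logRef) return
  readIORef logRef

main :: IO ()
main = runDemo >>= mapM_ putStrLn
```

Running runDemo produces the log acquire a, acquire b, acquire c, use abc, release c, release b, release a: the resources are released in reverse order, and only after the whole ContT block has finished.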

Let’s consider some practical examples.

llvm-hs

In the llvm-hs package’s API, the continuation-passing style is used to reliably deallocate the allocated LLVM structures and classes, which are written in C++.

Here’s an example from llvm-hs’s test suite, reformatted such that each nesting level gets its own indentation level:

testCase "eager compilation" $ do
  resolvers <- newIORef Map.empty
  withTestModule $ \mod ->
    withHostTargetMachine $ \tm ->
      withExecutionSession $ \es ->
        withObjectLinkingLayer es (\k -> fmap (\rs -> rs Map.! k) (readIORef resolvers)) $ \linkingLayer ->
          withIRCompileLayer linkingLayer tm $ \compileLayer -> do
            testFunc <- mangleSymbol compileLayer "testFunc"
            withModuleKey es $ \k ->
              withSymbolResolver es (SymbolResolver (resolver testFunc compileLayer)) $ \resolver -> do
                modifyIORef' resolvers (Map.insert k resolver)
                withModule compileLayer k mod $ do
                  mainSymbol <- mangleSymbol compileLayer "main"
                  Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
                  result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn))
                  result @?= 42
                  Right (JITSymbol mainFn _) <- CL.findSymbolIn compileLayer k mainSymbol True
                  result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn))
                  result @?= 42
                  unknownSymbol <- mangleSymbol compileLayer "unknownSymbol"
                  unknownSymbolRes <- CL.findSymbol compileLayer unknownSymbol True
                  unknownSymbolRes @?= Left (JITSymbolError mempty)

And here’s the same test rewritten with ContT:

testCase "eager compilation" $
  flip runContT return $ do
    resolvers <- liftIO $ newIORef Map.empty
    mod <- ContT withTestModule
    tm <- ContT withHostTargetMachine
    es <- ContT withExecutionSession
    linkingLayer <- ContT $ withObjectLinkingLayer es (\k -> fmap (\rs -> rs Map.! k) (readIORef resolvers))
    compileLayer <- ContT $ withIRCompileLayer linkingLayer tm
    testFunc <- liftIO $ mangleSymbol compileLayer "testFunc"
    k <- ContT $ withModuleKey es
    resolver <- ContT $ withSymbolResolver es (SymbolResolver (resolver testFunc compileLayer))
    liftIO $ modifyIORef' resolvers (Map.insert k resolver)
    ContT $ \c -> withModule compileLayer k mod (c ())
    -- All functions below are non-CPS, so we combine them into a single IO block
    -- instead of applying liftIO on every line.
    liftIO $ do
      mainSymbol <- mangleSymbol compileLayer "main"
      Right (JITSymbol mainFn _) <- CL.findSymbol compileLayer mainSymbol True
      result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn))
      result @?= 42
      Right (JITSymbol mainFn _) <- CL.findSymbolIn compileLayer k mainSymbol True
      result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn))
      result @?= 42
      unknownSymbol <- mangleSymbol compileLayer "unknownSymbol"
      unknownSymbolRes <- CL.findSymbol compileLayer unknownSymbol True
      unknownSymbolRes @?= Left (JITSymbolError mempty)

Most of the nesting and indentation is gone. The code is now completely linear, just as if we were using the unsafe, non-CPS functions. Wrapping every CPS call in ContT is a bit awkward, but that would go away if the library's API were designed so that its CPS functions already lived in the ContT monad.
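Note that withModule's continuation is a plain IO action rather than a function, which is why the ContT version needs the \c -> ... (c ()) trick. Below is a sketch of how a ContT-first API could hide both details. The names unitContT, withScope, scope, nested and nestedDemo are hypothetical; they are not part of llvm-hs or transformers.

```haskell
import Control.Exception (bracket_)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Cont (ContT (..))
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)

-- Hypothetical helper for CPS functions whose continuation is a plain action
-- (shape IO r -> IO r, like withModule): feed () to the ContT continuation.
unitContT :: (m r -> m r) -> ContT r m ()
unitContT f = ContT $ \c -> f (c ())

-- Hypothetical scoped action that logs when it is entered and exited.
withScope :: IORef [String] -> String -> IO r -> IO r
withScope logRef name =
  bracket_ (note ("enter " ++ name)) (note ("exit " ++ name))
  where
    note msg = modifyIORef logRef (++ [msg])

-- What a ContT-first API could export instead of withScope.
scope :: IORef [String] -> String -> ContT r IO ()
scope logRef name = unitContT (withScope logRef name)

nested :: IORef [String] -> ContT r IO ()
nested logRef = do
  scope logRef "outer"
  scope logRef "inner"
  liftIO (modifyIORef logRef (++ ["body"]))

nestedDemo :: IO [String]
nestedDemo = do
  logRef <- newIORef []
  runContT (nested logRef) return
  readIORef logRef

main :: IO ()
main = nestedDemo >>= mapM_ putStrLn
```

Running nestedDemo yields enter outer, enter inner, body, exit inner, exit outer, just as if withScope had been nested by hand.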

Reading from a list of files

In the above example, the nesting was merely deep: its depth was fixed and known when the code was written. But what if the nesting depth is not known in advance?

Consider a program that takes a list of files and prints them interleaved: a line from the first file, then a line from the second, and so on, printing line m+1 of the first file right after line m of the last file.

To implement such a program, we need to keep all n files open at the same time. If we want to use the safe withFile function, we need to nest it n times, where n is not known until the program runs. How is that even possible?

Well, it is possible with ContT.

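import Control.Monad (forM)
import Control.Monad.Cont
import Control.Monad.IO.Class (liftIO)
import System.IO
import System.Environment

main :: IO ()
main = flip runContT return $ do
  args <- liftIO getArgs
  -- One withFile per argument: how many calls get nested is decided
  -- at run time by the number of arguments.
  handles <- forM args $ \arg ->
    ContT $ withFile arg ReadMode
  liftIO $ print_interleaved_from_handles handles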

print_interleaved_from_handles :: [Handle] -> IO ()
print_interleaved_from_handles = ...
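Under the hood, forM in ContT builds the same nesting you could write by hand with a recursive helper; base even ships such a helper as withMany in Foreign.Marshal.Utils. The sketch below compares the three, using a hypothetical logging resource, withRes, in place of withFile so that it runs without real files. withAll, withAllC and runWith are hypothetical names as well.

```haskell
import Control.Exception (bracket_)
import Control.Monad.Trans.Cont (ContT (..))
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)
import Foreign.Marshal.Utils (withMany)

-- Hypothetical stand-in for withFile that logs opening and closing.
withRes :: IORef [String] -> String -> (String -> IO r) -> IO r
withRes logRef name k =
  bracket_ (note ("open " ++ name)) (note ("close " ++ name)) (k name)
  where
    note msg = modifyIORef logRef (++ [msg])

-- By hand: one nested with-call per list element, via recursion.
withAll :: (a -> (b -> IO r) -> IO r) -> [a] -> ([b] -> IO r) -> IO r
withAll _ [] k = k []
withAll withX (x : xs) k = withX x $ \b -> withAll withX xs $ \bs -> k (b : bs)

-- With ContT: mapM (or forM) produces the same nesting.
withAllC :: (a -> (b -> IO r) -> IO r) -> [a] -> ([b] -> IO r) -> IO r
withAllC withX xs = runContT (mapM (ContT . withX) xs)

-- Run one implementation over three resources and return the resulting log.
runWith ::
  ((String -> (String -> IO ()) -> IO ()) -> [String] -> ([String] -> IO ()) -> IO ()) ->
  IO [String]
runWith impl = do
  logRef <- newIORef []
  impl (withRes logRef) ["a", "b", "c"] $ \names ->
    modifyIORef logRef (++ ["use " ++ concat names])
  readIORef logRef

main :: IO ()
main = mapM_ (\impl -> runWith impl >>= print) [withAll, withAllC, withMany]
```

All three produce the same log: open a, open b, open c, use abc, close c, close b, close a.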

Allocating a linked list

This example is similar to the previous one, except that it came up in my own practice. See Way 2 in 6 ways to manage allocated memory in Haskell.
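Here is a minimal sketch of the idea, not necessarily the code from the article linked above. allocList allocates one node per list element with `with` from Foreign.Marshal.Utils, so the number of nested allocations is decided by the length of the list at run time. The memory is only guaranteed to stay valid until runContT returns. The Node type, its C-like memory layout, and the helper names are hypothetical.

```haskell
import Control.Monad.Trans.Cont (ContT (..))
import Data.Int (Int32)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable

-- Hypothetical node layout, mimicking the C struct
--   struct node { int32_t value; struct node *next; };
data Node = Node Int32 (Ptr Node)

-- Offset of the next field: the Int32 size rounded up to pointer alignment.
nextOffset :: Int
nextOffset = ((valueSize + ptrAlign - 1) `div` ptrAlign) * ptrAlign
  where
    valueSize = sizeOf (0 :: Int32)
    ptrAlign = alignment (nullPtr :: Ptr Node)

instance Storable Node where
  sizeOf _ = nextOffset + sizeOf (nullPtr :: Ptr Node)
  alignment _ = alignment (nullPtr :: Ptr Node)
  peek p = Node <$> peekByteOff p 0 <*> peekByteOff p nextOffset
  poke p (Node v next) = pokeByteOff p 0 v >> pokeByteOff p nextOffset next

-- One nested `with` per element; the list's length sets the nesting depth.
allocList :: [Int32] -> ContT r IO (Ptr Node)
allocList [] = return nullPtr
allocList (x : xs) = do
  rest <- allocList xs
  ContT (with (Node x rest))

-- Walk the list through the raw pointers.
sumList :: Ptr Node -> IO Int32
sumList p
  | p == nullPtr = return 0
  | otherwise = do
      Node v next <- peek p
      (v +) <$> sumList next

-- sumList runs as the final continuation, while every node is still allocated.
listSum :: [Int32] -> IO Int32
listSum xs = runContT (allocList xs) sumList

main :: IO ()
main = listSum [1 .. 10] >>= print
```

With this sketch, listSum [1 .. 10] returns 55.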

The withX-style safe functions are far from the only possible use for CPS and ContT, but they are the use I encounter most often. So if you were wondering why you'd use a continuation monad, hopefully this gives you some ideas.