Wednesday, November 18, 2015

Error Handling in Menhir

My favorite category of paper is the "cute trick" paper, where the author describes a programming trick that is both easy to implement and much better than the naive approach. One of my favorite cute trick papers is Clinton L. Jeffery's 2003 ACM TOPLAS paper, Generating LR Syntax Error Messages from Examples. The abstract explains the idea:
LR parser generators are powerful and well-understood, but the parsers they generate are not suited to provide good error messages. Many compilers incur extensive modifications to the source grammar to produce useful syntax error messages. Interpreting the parse state (and input token) at the time of error is a nonintrusive alternative that does not entangle the error recovery mechanism in error message production. Unfortunately, every change to the grammar may significantly alter the mapping from parse states to diagnostic messages, creating a maintenance problem.

Merr is a tool that allows a compiler writer to associate diagnostic messages with syntax errors by example, avoiding the need to add error productions to the grammar or interpret integer parse states. From a specification of errors and messages, Merr runs the compiler on each example error to obtain the relevant parse state and input token, and generates a yyerror() function that maps parse states and input tokens to diagnostic messages. Merr enables useful syntax error messages in LR-based compilers in a manner that is robust in the presence of grammar changes.

Basically, you can take a list of syntactically incorrect programs and error messages for them, and the tool will run the parser to figure out which state the LR automaton will get stuck in for each broken program, and then ensure that whenever you hit that state, you report the error message you wrote. I like this idea an awful lot, because it means that even if you use the full power of LR parsing, you still have a better error message story than any other approach, including hand-coded recursive descent!

Recently, Francois Pottier has studied a generalization of this idea, in his new draft Reachability and error diagnosis in LR(1) automata.

Given an LR(1) automaton, what are the states in which an error can be detected? For each such “error state”, what is a minimal input sentence that causes an error in this state? We propose an algorithm that answers these questions. Such an algorithm allows building a collection of pairs of an erroneous input sentence and a diagnostic message, ensuring that this collection covers every error state, and maintaining this property as the grammar evolves. We report on an application of this technique to the CompCert ISO C99 parser, and discuss its strengths and limitations.

Basically, in Jeffery's approach, you write broken programs and their error messages by hand; Pottier has worked out an algorithm to generate a covering set of broken programs. Even more excitingly, he has implemented this idea in a production parser generator --- the most recent version of his Menhir parser generator.

So now there's no excuse for using parser combinators, since they offer worse runtime performance and worse error messages! :)

Friday, July 31, 2015

FRP without Space Leaks

The dataflow engine I gave in my last post can be seen as an implementation of self-adjusting computation, in the style of Acar, Blelloch and Harper's original POPL 2002 paper Adaptive Functional Programming. (Since then, state of the art implementation techniques have improved a lot, so don't take my post as indicative of what modern libraries do.)

Many people have seen resemblances between self-adjusting computation and functional reactive programming --- a good example of this is Jake Donham's Froc library for Ocaml. Originally, I was one of those people, but that's no longer true: I think SAC and FRP are completely orthogonal.

I now think that FRP libraries can be very minimalistic --- my ICFP 2013 paper Higher-Order Reactive Programming without Spacetime Leaks gives a type system, implementation, and correctness proof for an FRP language with full support for higher-order constructions like higher-order functions and streams of streams, while at the same time statically ruling out space and time leaks.

The key idea is to distinguish between stable values (like ints and bools) whose representation doesn't change over time from dynamic values (like streams) whose representation is time-varying. Stable values are the usual datatypes, and can be used whenever we like. But dynamic values have a scheduling constraint: we can only use them at certain times. For example, with a stream, we want to look at the head at time 0, the head of the tail at time 1, the head of the tail of the tail at time 2, and so on. It's a mistake to look at the head of the tail of a stream at time 0, because that value might not be available yet.

With an appropriate type discipline, it's possible to ensure scheduling correctness statically, but unfortunately many people are put off by modal types and Kripke logical relations. This is a shame, because the payoff of all this is that the implementation strategy is super-simple -- we can just use plain-vanilla lazy evaluation to implement FRP.

Recently, though, I've figured out how to embed this kind of FRP library into standard functional languages like Ocaml. Since we can't define modal type operators in standard functional languages, we have to give up some static assurance, and replace the static checks of time-correctness with dynamic checks, but we are still able to rule out space leaks by construction, and still get a runtime error if we mis-schedule a program. Essentially, we can replace type checking with contract checking. As usual, you can find the code on Github here.

Let's look at an Ocaml signature for this library:

module type NEXT =
sig
  type 'a t
  exception Timing_error of int * int 

  val delay  : (unit -> 'a) -> 'a t
  val map    : ('a -> 'b) -> 'a t -> 'b t
  val zip    : 'a t * 'b t -> ('a * 'b) t
  val unzip  : ('a * 'b) t -> 'a t * 'b t
  val ($)    : ('a -> 'b) t -> 'a t -> 'b t   (* This op is redundant but convenient *)
  val fix    : ('a t -> 'a) -> 'a

  (* Use these operations to implement an event loop *)

  module Runtime : sig
    val tick : unit -> unit
    val force : 'a t -> 'a
  end
end

The NEXT signature introduces a single type constructor 'a t, which can be thought of as the type of computations scheduled to be evaluated on precisely the next tick of the clock. The elements of 'a t are dynamic in the sense mentioned above: we are only permitted to evaluate them on the next tick of the clock, and evaluating them at any other time is an error.

To model this kind of error, we also have a Timing_error exception, whose first argument carries the time a thunk was scheduled to be evaluated, and whose second argument carries the actual time at which it was forced.

Elements of 'a t are the only primitive way to create dynamic values -- other values (like function closures) can be dynamic, but only if they end up capturing a next-step thunk.

The delay function lets us create a next value from a thunk, and the map function maps a function over a thunk. The zip and unzip operations are used for pairing, and the $ operation is the McBride/Paterson idiomatic application operator. (Technically, it's derivable from zip, but it's easiest to throw it into the basic API.)
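For instance (a quick sketch, assuming some module Next implementing NEXT), these combinators compose without evaluating anything before its time:

  (* Build a value scheduled for the next tick, transform it, and pair it.
     Nothing is evaluated until the thunks are forced on the right tick. *)
  let tomorrow : int Next.t = Next.delay (fun () -> 41)
  let tomorrow' : int Next.t = Next.map (fun x -> x + 1) tomorrow
  let both : (int * int) Next.t = Next.zip (tomorrow, tomorrow')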

The fix operation is the one that really makes reactive programming possible -- it says that guarded recursion is allowed. So if we have a function which takes an 'a next and returns an 'a, then we can take a fixed point. This fixed point can never block the event loop, because its type ensures that we always delay by a tick before making a recursive call.
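As a tiny sketch of what guarded recursion buys us (again assuming a module Next : NEXT; the ticks type is invented just for this example), here is an infinite clock whose recursive occurrence is guarded by a tick:

  (* Constructing clock does only one step of work now; the recursive
     occurrence is delayed, so the event loop is never blocked. *)
  type ticks = Tick of ticks Next.t
  let clock : ticks = Next.fix (fun later -> Tick later)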

This raw interface is, honestly, not so useful as is, but the slightly miraculous fact is that this is the complete API we need to build all the higher-level abstractions --- like events and streams --- that we need to do real reactive programming.

Now, let's see what an implementation of this library could look like.

module Next : NEXT = struct
  let time = ref 0

We can keep track of the current time in a reference cell.

  type 'a t = {
    time : int;
    mutable code : 'a Lazy.t
  }

The type of a thunk is a record consisting of a lazy thunk, and the time when it is safe to force it.
  type s = Hide : 'a t -> s 
  let thunks : s list ref = ref []

We also have a list that stores all of the thunks we've allocated. We'll use this list to enforce space-safety, by mutating any thunk that gets too old.

  exception Timing_error of int * int 

  let delay t =
    let t = { time = 1 + !time; code = Lazy.from_fun t} in
    thunks := (Hide t) :: !thunks;
    t

When we create a thunk with the delay function, we are creating a thunk to be forced on the next time tick. So we dereference time to find out the current time, and add 1 to get the scheduled execution time for the thunk. We also add it to the list thunks, so that we can remember that we created it.
  

  let force t =
     if t.time <> !time then
       raise (Timing_error(t.time, !time))
     else
       Lazy.force t.code

Forcing a thunk just forces the code thunk, if the current time matches the scheduled time for the thunk. Otherwise, we raise a Timing_error. Note that memoization is handled by Ocaml's built-in 'a Lazy.t type.

   let map f r = delay (fun () -> f (force r))
   let zip (r, r') = delay (fun () -> (force r, force r'))
   let unzip r = (map fst r, map snd r)
   let ($) f x = delay (fun () -> force f (force x))
 
The map, zip, unzip, and ($) operators just force and delay things in the obvious places.
 
   let rec fix f = f (delay (fun () -> fix f))
 
The fixed point operation looks exactly like the standard lazy fixed point.
 
   module Runtime = struct
     let force = force
 
The runtime exposes the force function to the event loop.
 
     let cleanup (Hide t) =
       let expired = t.time < !time in
       (* Zero out stale thunks so they drop their data; return true exactly
          for the thunks that List.filter should keep, i.e. the live ones. *)
       (if expired then t.code <- lazy (assert false));
       not expired
 
     let tick () =
       time := !time + 1;
       thunks := List.filter cleanup !thunks
  end
end

The tick function advances time by doing two things. First, it increments the current time. Then, it filters the list of thunks using the cleanup function. For each thunk, cleanup returns true just when the thunk is still live --- that is, when it can be forced now or in the future --- so the filter retains only those thunks.

Second, if the argument thunk to cleanup is stale, cleanup replaces its code body with an assertion failure, since no time-correct program should ever force it. Updating the code ensures that, by construction, next-step thunks always lose the reference to their data once they age out: every thunk is placed onto thunks when it is created, and when the clock is ticked past its time, it is guaranteed to drop its payload.

This guarantees that spacetime leaks are impossible, since we dynamically zero out any thunks that get too old! So here we see how essential data abstraction is for imperative programming, and not just functional programming.

As you can see, the implementation of the Next library is pretty straightforward. The only mildly clever thing we do is to keep track of the next-tick computations so we can null them out when they get too old.
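Here's a quick sketch of the dynamic discipline in action, using the Runtime operations:

  (* Forcing a thunk at the wrong time raises Timing_error; forcing it on
     the right tick succeeds (and is memoized by Lazy). *)
  let x = Next.delay (fun () -> 3)              (* scheduled for time 1 *)
  let () =
    (try ignore (Next.Runtime.force x)          (* too early: time is still 0 *)
     with Next.Timing_error (scheduled, now) ->
       Printf.printf "scheduled for %d, now %d\n" scheduled now);
    Next.Runtime.tick ();                       (* time is now 1 *)
    Printf.printf "%d\n" (Next.Runtime.force x) (* prints 3 *)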

You should be wondering now how we can actually write reactive programs, when the only primitive the API provides lets you schedule a computation to run on the next tick, and nothing more. The answer is datatype declarations. Now that we have a type that lets us talk about time, we can reuse our host language's facilities to define types which say more interesting things about time.

Let's start with the classic datatype of functional reactive programming: streams. Streams are a kind of lazy sequence, which recursively give you a value now, and a stream starting tomorrow, thereby giving you a value on every time step.

module Stream :
  sig
    type 'a stream = Cons of 'a * 'a stream Next.t
    val head : 'a stream -> 'a
    val tail : 'a stream -> 'a stream Next.t
    val unfold : ('a -> 'b * 'a Next.t) -> 'a -> 'b stream
    val map : ('a -> 'b) -> 'a stream -> 'b stream
    val zip : 'a stream * 'b stream -> ('a * 'b) stream
    val unzip : ('a * 'b) stream -> 'a stream * 'b stream
  end = struct

We give a simple signature for streams above: a datatype exactly following the English description, together with a collection of accessor and constructor functions like head, tail, map, and unfold. All of these have pretty much the expected types.

The only difference from the usual stream types is that sometimes we need a Next.t to tell us when a value needs to be available. Now let's look at the implementation.


  open Next

  type 'a stream = Cons of 'a * 'a stream Next.t


  let head (Cons(x, xs)) = x
  let tail (Cons(x, xs)) = xs
We can write accessor functions for streams, for convenience.

  let map f = fix (fun loop (Cons(x, xs)) -> Cons(f x, loop $ xs))

The map function uses the fix fixed point operator in our API, because we want to call the recursive function at a later time.
  
  let unfold f = fix (fun loop seed ->
      let (x, seed) = f seed in
      Cons(x, loop $ seed))

The unfold function uses a function f and an initial seed value to incrementally produce a sequence of values. This is exactly like the usual unfold, except we have to use the applicative interface to the 'a Next.t type to apply the function.
  

  let zip pair =
    unfold (fun (Cons(x, xs), Cons(y, ys)) -> ((x, y), Next.zip (xs, ys)))
           pair

  let unzip xys = fix (fun loop (Cons((x,y), xys')) ->
                         let (xs', ys') = Next.unzip (loop $ xys') in
                         (Cons(x, xs'), Cons(y, ys')))
                      xys
  end
zip and unzip work about the way we'd expect: we use Next.zip and Next.unzip to put together and take apart the delayed pairs, which gives us the ability to put together and take apart streams.
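For example (a sketch, assuming the Next and Stream modules above), we can build derived streams with these combinators:

  (* The naturals via unfold, their squares via map, and the two zipped.
     Sharing nats between two consumers is fine: Lazy memoizes each thunk. *)
  let nats : int Stream.stream =
    Stream.unfold (fun i -> (i, Next.delay (fun () -> i + 1))) 0
  let squares : int Stream.stream = Stream.map (fun n -> n * n) nats
  let pairs : (int * int) Stream.stream = Stream.zip (nats, squares)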

This is all very nice, but the real power of giving a reactive API based on a next-step type is that we can build types which aren't streams. For example, let's give a datatype of events, which is the type of values which will become available at some point in the future, but we don't know exactly when.

module Event :
  sig
    type 'a event = Now of 'a | Wait of 'a event Next.t
    val map : ('a -> 'b) -> 'a event -> 'b event
    val return : 'a -> 'a event
    val bind : 'a event -> ('a -> 'b event) -> 'b event
    val select : 'a event -> 'a event -> 'a event
  end
= struct
  open Next

  type 'a event = Now of 'a | Wait of 'a event Next.t

We represent this with a datatype 'a event, which has two constructors. We say that an 'a event is either a value of type 'a available Now, or we have to Wait to get another event tomorrow. So this is a single value of type 'a that could come at any time --- and we don't know when!

  let map f = fix (fun loop e ->
      match e with
      | Now x -> Now (f x)
      | Wait e' -> Wait (loop $ e'))

We can map over events, by waiting until the value becomes available and then applying a function to the result.

  let return x = Now x

  let bind m f =
    fix (fun bind m ->
        match m with 
        | Now v -> f v
        | Wait e' -> Wait (bind $ e'))
      m
Events also form a monad, which corresponds to the ability to sequence promises or futures in the promise libraries you'll find in Javascript or Scala. (The bind here is a bit like the promise.then() method in JS.)

  let select e1 e2 =
    fix (fun loop e1 e2 -> 
          match e1, e2 with 
          | Now a1, _ -> Now a1
          | _, Now a2 -> Now a2
          | Wait e1, Wait e2 -> Wait (loop $ e1 $ e2))
      e1
      e2
end

The really cool thing is that we can also select between two events, waiting for whichever one completes first! This can be extended to lists of events, if desired, but the pattern is easiest to see in the binary case.
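For instance (a sketch, assuming the modules above; the helper at is invented for this example), we can race two timed events:

  (* An event that fires n ticks from now, built by guarded recursion. *)
  let rec at n v =
    if n = 0 then Event.Now v
    else Event.Wait (Next.delay (fun () -> at (n - 1) v))

  (* select yields whichever fires first: here, "early" at time 2. *)
  let first = Event.select (at 2 "early") (at 5 "late")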

Finally, here's a small example of how to put this together and actually run a program. The run function gives an event loop that runs for k steps and halts, printing the first k elements of the stream it gets passed as an argument.

module Test =
struct
  open Next
  open Stream

  let ints n = unfold (fun i -> (i, delay(fun () -> i+1))) n 

  let rec run k xs =
    if k = 0
    then ()
    else
      let (x, xs) = (head xs, tail xs) in
      Printf.printf "%d\n" x;
      Runtime.tick();
      run (k-1) (Runtime.force xs)
end
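Under these definitions, a three-step run prints one integer per tick --- a quick sketch of the expected behavior:

  let () = Test.run 3 (Test.ints 0)   (* prints 0, 1, and 2, then halts *)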

Wednesday, July 22, 2015

How to implement a spreadsheet

My friend Lindsey Kuper recently remarked on Twitter that spreadsheets were commonly understood to be the most widely used dataflow programming model, and asked if there was a simple implementation of them.

As chance would have it, this was one of the subjects of my thesis work -- as part of it, I wrote and proved the correctness of a small dataflow programming library. This program has always been one of my favorite little higher-order imperative programs, and in this post I'll walk through the implementation. (You can find the code here.)

As for the proof, you can look at this TLDI paper for some idea of the complexities involved. These days it could all be done more simply, but the pressure of proving everything correct did have a very salutary effect in keeping everything as simple as possible.

The basic idea of a spreadsheet (or other dataflow engine) is that you have a collection of places called cells, each of which contains an expression. An expression is basically a small program, which has the special ability to ask other cells what their value is. The reason cells are interesting is that they do memoization: if you ask a cell for its value twice, it will only evaluate its expression the first time. Furthermore, it's also possible for the user to modify the expression a cell contains (though we don't want cells to modify their code as they execute).

So let's turn this into code. I'll use Ocaml, because ML modules make describing the interface particularly pretty, but it should all translate into Scala or Haskell easily enough. In particular, we'll start by giving a module signature writing down the interface.

 module type CELL = sig

We start by declaring two abstract types, the type 'a cell of cells containing a value of type 'a, and the type 'a exp of expressions returning a value of type 'a.

   type 'a cell
   type 'a exp

Now, the trick we are using in implementing expressions is that we treat them as a monadic type. By re-using our host language as the language of terms that lives inside of a cell, we don't have to implement parsers or interpreters or anything like that. This is a familiar trick to Haskell programmers, but it's still a good trick! So we first give the monadic bind and return operators:

 
   val return : 'a -> 'a exp
   val (>>=) : 'a exp -> ('a -> 'b exp) -> 'b exp

And then we can specify the two operations that are unique to our monadic DSL: reading a cell (which we call get), and creating a new cell (which we call cell). It's a bit unusual to be able to create new cells as a program executes, but it's rather handy.

   val cell : 'a exp -> 'a cell exp 
   val get :  'a cell -> 'a exp

Aside from that, there are no other operations in the monadic expression DSL. Now we can give the operations that don't live in the monad. First is the set operation, which modifies the contents of a cell. This should not be called from within an 'a exp term --- in Haskell, that might be enforced by giving set an IO type.

 
   val set : 'a cell -> 'a exp -> unit 

Finally, there's the run operation, which we use to run an expression. This is useful mainly for looking at the values of cells from the outside.

   val run : 'a exp -> 'a 
 end

Now, we can move on to the implementation.

 
 module Cell : CELL = struct
The implementation of cells is at the heart of the dataflow engine, and is worth discussing in detail. A cell is a record with five fields:
   type 'a cell = {
     mutable code      : 'a exp;
     mutable value     : 'a option;
     mutable reads     : ecell list;
     mutable observers : ecell list;
     id                : int
   }
  • The code field of this record is the pointer to the expression that the cell contains. This field is mutable because we can alter the contents of a cell!
  • The value field is an option type, which is None if the cell has not been evaluated yet, and Some v if the code evaluated to v.
  • The reads field is a list containing all of the cells that were read when the code in the code field was executed. If the cell hasn't been evaluated yet, then this is the empty list.
  • The observers field is a list containing all of the cells that have read this cell when they were evaluated. So the reads field lists all the cells this cell depends on, and the observers field lists all the cells which depend on this cell. If this cell hasn't been evaluated yet, then observers will of course be empty.
  • The id field contains an integer which uniquely identifies each cell.

Both reads and observers store lists of dependent cells, and dependent cells can be of any type. In order to build a list of heterogeneous cells, we need to introduce a type ecell, which just hides the cell type under an existential (using Ocaml's new GADT syntax):

   and ecell = Pack : 'a cell -> ecell

We can now also give the concrete type of expressions. We define an element of expression type 'a exp to be a thunk, which when forced returns (a) a value of type 'a, and (b) the list of cells that it read while evaluating:

   and 'a exp = unit -> ('a * ecell list)
 

Next, let's define a couple of helper functions. The id function just returns the id of an existentially packed ecell, and the union function merges two lists of ecells while removing duplicates.

   let id (Pack c) = c.id

   let rec union xs ys =
     match xs with
     | [] -> ys
     | x :: xs' ->
       if List.exists (fun y -> id x = id y) ys then
         union xs' ys
       else
         x :: (union xs' ys)

The return function just produces a thunk which returns a value and an empty list of read dependencies, and the monadic bind (>>=) sequentially composes two computations, and returns the union of their read dependencies.

   let return v () = (v, [])
   let (>>=) cmd f () =
     let (a, cs) = cmd () in
     let (b, ds) = f a () in
     (b, union cs ds)

To implement the cell operator, we need a source of fresh ids. So we create an integer reference, and new_id bumps the counter before returning a fresh id.

   let r = ref 0
   let new_id () = incr r; !r 

Now we can implement cell. This function takes an expression exp, uses new_id to create a unique id, and then initializes a cell with the appropriate values -- the code field is exp, the value field is None (because the cell is created in an unevaluated state), the reads and observers fields are empty (again because the cell is unevaluated), and the id field is set to the value we generated.

This is returned with an empty list of read dependencies because we didn't read anything to construct a fresh cell!

   let cell exp () =
     let n = new_id () in
     let cell = {
       code = exp;
       value = None;
       reads = [];
       observers = [];
       id = n;
     } in
     (cell, [])

To read a cell, we need to implement the get operation. This works a bit like memoization. First, we check to see if the value field already has a value. If it does, then we can return that. If it is None, then we have a bit more work to do.

First, we have to evaluate the expression in the code field, which returns a value v and a list of read dependencies ds. We can update the value field to Some v, and then set the reads field to ds. Then, we add this cell to the observers field of every read dependency in ds, because this cell is observing them now.

Finally, we return the value v as well as a list containing the current cell (which is the only dependency of reading the cell).

  let get c () =
    match c.value with
    | Some v -> (v, [Pack c])
    | None ->
      let (v, ds) = c.code ()  in
      c.value <- Some v;
      c.reads <- ds;
      List.iter (fun (Pack d) -> d.observers <- (Pack c) :: d.observers) ds;
      (v, [Pack c])

This concludes the implementation of the monadic expression language, but our API also includes an operation to modify the code in a cell. This requires more code than just updating a field -- we have to invalidate everything which depends on the cell, too. So we need some helper functions to do that.

The first helper is remove_observer o o'. This removes the cell o from the observers field of o'. It does this by comparing the id field (which was in fact put in for this very purpose).

  let remove_observer o (Pack c) =
    c.observers <- List.filter (fun o' -> id o <> id o') c.observers

This function is used to implement invalidate, which takes a cell, marks it as invalid, and then marks everything which transitively depends on it invalid too. It does this by saving the reads and observers fields into the variables rs and os. Then it marks the current cell as invalid, by setting the value field to None and the observers and reads fields to the empty list. Next, it removes the current cell from the observers list of every cell in the old read set rs, and finally it calls invalidate recursively on every observer in os.

  let rec invalidate (Pack c) =
    let os = c.observers in
    let rs = c.reads in 
    c.observers <- [];
    c.value <- None;
    c.reads <- [];
    List.iter (remove_observer (Pack c)) rs;
    List.iter invalidate os

This then makes it easy to implement set -- we just update the code, and then invalidate the cell (since the memoized value is no longer valid).

    
  let set c exp =
    c.code <- exp;
    invalidate (Pack c)

Finally, we can implement the run function by forcing the thunk and throwing away the read dependencies.

  let run cmd = fst (cmd ())
end

That's pretty much it. I think it's quite pleasant to see how little code it takes to implement such an engine.
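To see the engine at work, here is a small sketch of a two-cell "spreadsheet" written against the CELL interface above:

  (* Cell b depends on cell a; updating a invalidates b's memoized value. *)
  open Cell

  let (a, b) = run (
    cell (return 1)                          >>= fun a ->
    cell (get a >>= fun x -> return (x + 1)) >>= fun b ->
    return (a, b))

  let v1 = run (get b)           (* evaluates b: 1 + 1 = 2; b now observes a *)
  let () = set a (return 10)     (* invalidates a, and transitively b *)
  let v2 = run (get b)           (* re-evaluates from the new code: 11 *)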

One thing I like about this program is that it also shows off how gc-unfriendly dataflow is: we track dependencies in both directions, and as a result the whole graph is always reachable, so the usual gc heuristics will collect nothing as long as anything is reachable. You can fix the problem by using weak references for the observers, but weak references are also horribly gc-unfriendly (usually there's a traversal of every weak reference on every collection).

So I think it's very interesting that there is a natural class of programs for which the reachability heuristic just doesn't work, and this indicates that some interesting semantics remains to be done to explain what the correct memory management strategy for these kinds of programs is.

Tuesday, March 17, 2015

Abstract Binding Trees, an addendum

It struck me after the last post that it might be helpful to give an example using abstract binding trees in a more nontrivial way. The pure lambda calculus has a very simple binding structure, and pretty much anything you do will work out. So I decided to show how ABTs can be used to easily support a much more involved form of binding -- namely, pattern matching.

This makes for a nice example, because it is a very modular and well-structured language feature that naturally has a rather complex binding structure. However, as we will see, it fits very easily into an ABT-based API. I will assume the implementation in the previous post, focusing only on the use of ABTs.

As is usual, we will need to introduce a signature module for the language.

 module Lambda =
 struct
Since pattern matching really only makes sense with a richer set of types, let's start by adding sums and products to the type language.
   type tp = Arrow of tp * tp | One | Prod of tp * tp | Sum of tp * tp 

Next, we'll give a datatype of patterns. The PWild constructor is the wildcard pattern $\_$, the PVar constructor is a variable pattern, PUnit is the unit pattern, the PPair(p, p') constructor is the pair pattern, and the PInl p and PInr p constructors are the patterns for the left and right injections of the sum type.

   type pat = PWild | PVar | PUnit | PPair of pat * pat 
            | PInl of pat | PInr of pat 

Now, we can give the datatype for the language signature itself. We add pairs, units, and sums to the language, as well as a case expression which takes a scrutinee, and a list of branches (a list of patterns and the corresponding case arms).

   type 'a t = Lam of 'a | App of 'a * 'a | Annot of tp * 'a
             | Unit | Pair of 'a * 'a | Inl of 'a | Inr of 'a
             | Case of 'a * (pat * 'a) list

One thing worth noting is that the datatype of patterns has no binding information in it at all. The basic idea is that we will represent (say) a pair elimination let (x,y) = e in e' as the constructor Case(e, [(PPair(PVar, PVar), Abs(x, Abs(y, e')))]) (suppressing useless Tm constructors). So the PVar constructor is merely an indication that the corresponding arm must be wrapped in an abstraction, with the number of abstractions determined by the shape of the pattern. This representation is documented in Rob Simmons's paper Structural Focalization.

This is really the key thing that gives the ABT interface its power: binding trees have only one binding form, and we never introduce any others.

Now we can give the map and join operations for this signature.

 
   let map f = function
     | Lam x        -> Lam (f x)
     | App (x, y)   -> App(f x, f y)
     | Annot(t, x)  -> Annot(t, f x)
     | Unit         -> Unit 
     | Pair(x, y)   -> Pair(f x, f y)
     | Inl x        -> Inl (f x)
     | Inr x        -> Inr (f x)
     | Case(x, pys) -> Case(f x, List.map (fun (p, y) -> (p, f y)) pys)

   let join m = function
     | Lam x -> x
     | App(x, y) -> m.join x y
     | Annot(_, x) -> x
     | Unit        -> m.unit
     | Pair(x,y)   -> m.join x y
     | Inl x       -> x
     | Inr x       -> x
     | Case(x, pys) -> List.fold_right (fun (_, y) -> m.join y) pys x 
 end

As usual, we construct the syntax by applying the Abt functor.

 
 module Syntax = Abt(Lambda)
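
With the syntax in hand, the pattern encoding described above can be written down directly. Here is a sketch (hedged: into, the inverse of the out function used below, plus the Tm and Abs constructors, all come from the previous post's ABT interface):

  (* "let (x,y) = e in e'" is a one-branch Case whose arm is doubly
     abstracted; the two PVars record how many Abs wrappers the arm has. *)
  open Lambda
  open Syntax
  let let_pair x y e e' =
    into (Tm (Case (e, [ (PPair (PVar, PVar),
                          into (Abs (x, into (Abs (y, e'))))) ])))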

We can now define a bidirectional typechecker for this language. Much of the infrastructure is the same as in the previous post.

 
 module Bidir = struct
   open Lambda
   open Syntax
  type ctx = (var * tp) list

We do, however, extend the is_check and is_synth functions to handle the new operations in the signature. Note that case statements are viewed as checking forms.

  let is_check = function 
    | Tm (Lam _) | Tm Unit | Tm (Pair(_, _))
    | Tm (Inl _) | Tm (Inr _) | Tm (Case(_, _)) -> true
    | _ -> false
  let is_synth e = not(is_check e)

  let unabs e =
    match out e with
    | Abs(x, e) -> (x, e)
    | _ -> assert false

When we reach the typechecker itself, most of it --- the check and synth functions --- is unchanged. We have to add new cases to check to handle the new value forms (injections, pairs, and units), but they are pretty straightforward.

  let rec check ctx e tp =
    match out e, tp with
    | Tm (Lam t), Arrow(tp1, tp') ->
      let (x, e') = unabs t in
      check ((x, tp1) :: ctx) e' tp'
    | Tm (Lam _), _               -> failwith "expected arrow type"
    | Tm Unit, One  -> ()
    | Tm Unit, _ -> failwith "expected unit type"
    | Tm (Pair(e, e')), Prod(tp, tp') ->
      let () = check ctx e tp in
      let () = check ctx e' tp' in
      ()
    | Tm (Pair(_, _)), _ -> failwith "expected product type"
    | Tm (Inl e), Sum(tp, _) -> check ctx e tp
    | Tm (Inr e), Sum(_, tp) -> check ctx e tp
    | Tm (Inl _), _
    | Tm (Inr _), _          -> failwith "expected sum type"

The big difference is in checking the Case form. Now, we need to synthesize a type for the scrutinee, and then check that each branch is well-typed. This works by calling the check_branch function on each branch, passing it the pattern, the arm, the scrutinee's type, and the result type as arguments. For reasons that will become apparent, the pattern and its type are passed in as singleton lists. (I don't do coverage checking here, only because it doesn't interact with binding in any way.)

    | Tm (Case(e, branches)), tp ->
      let tp' = synth ctx e in
      List.iter (fun (p,e) -> check_branch ctx [p] e [tp'] tp) branches 
    | body, _ when is_synth body ->
      if tp = synth ctx e then () else failwith "Type mismatch"
    | _ -> assert false

  and synth ctx e =
    match out e with
    | Var x -> (try List.assoc x ctx with Not_found -> failwith "unbound variable")
    | Tm(Annot(tp, e)) -> let () = check ctx e tp in tp 
    | Tm(App(f, e))  ->
      (match synth ctx f with
       | Arrow(tp, tp') -> let () = check ctx e tp in tp'
       | _ -> failwith "Applying a non-function!")
    | body when is_check body -> failwith "Cannot synthesize type for checking term"
    | _ -> assert false
The way that branch checking works is by steadily deconstructing a list of patterns (and their types) into smaller lists.
  and check_branch ctx ps e tps tp_result =
    match ps, tps with

If there are no more patterns and their types, we are done, and can check that the arm has the right type.

    | [], [] -> check ctx e tp_result

If we have a variable pattern, we unabstract the arm, and bind that variable to type, and recur on the smaller list of patterns and types. Note that this is the only place we have to do anything at all with binding, and it's trivial!

    | (PVar :: ps), (tp :: tps)
      -> let (x, e) = unabs e in
         check_branch ((x, tp) :: ctx) ps e tps tp_result

Wildcards and unit patterns work the same way, except that they don't bind anything.

    | (PWild :: ps), (tp :: tps)
      -> check_branch ctx ps e tps tp_result
    | (PUnit :: ps), (One :: tps)
      -> check_branch ctx ps e tps tp_result
    | (PUnit :: ps), (tp :: tps)
      -> failwith "expected term of unit type"

Pair patterns are deconstructed into their two subpatterns, and the product type they are checked against is broken into its two components; both lists are then extended with the resulting subproblems.

    | (PPair(p, p') :: ps), (Prod(tp, tp') :: tps) 
      -> check_branch ctx (p :: p' :: ps) e (tp :: tp' :: tps) tp_result
    | (PPair(p, p') :: ps), (_ :: tps) 
      -> failwith "expected term of product type"

Sum patterns work by recurring on the sub-pattern, dropping the left or right part of the sum type, as appropriate.

    | (PInl p :: ps), (Sum(tp, _) :: tps)
      -> check_branch ctx (p :: ps) e (tp :: tps) tp_result
    | (PInl p :: ps), (_ :: tps)
      -> failwith "expected sum type"
    | (PInr p :: ps), (Sum(_, tp) :: tps)
      -> check_branch ctx (p :: ps) e (tp :: tps) tp_result
    | (PInr p :: ps), (_ :: tps)
      -> failwith "expected sum type"
    | _ -> assert false
end

That's it! So I hope it's clear that ABTs can handle complex binding forms very gracefully.