Generic recursion in Ruby with trampolines and thunks

Posted on April 18, 2015

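The original idea is here (the link was not preserved). The names `TailRec`, `Done`, `Call`, `Cont` and `result` mirror Scala's `scala.util.control.TailCalls`, which this code ports to Ruby.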

# Base class of a trampolined computation. A TailRec value describes work that
# is still to be done:
# - a finished value (Done);
# - a suspended thunk (Call);
# - a computation bound to a continuation (Cont).
# #result evaluates that description in a flat loop, one rewrite at a time,
# instead of through nested Ruby method calls.
class TailRec
  # Runs the computation to completion.
  #
  # @return [Object] the value wrapped by the Done the computation reduces to
  # @raise [TypeError] if a thunk or continuation produces something other than
  #   a Done, Call or Cont (instead of looping forever on the bad step)
  def result
    step = self
    # `until` rather than Kernel#loop: loop rescues StopIteration. With loop, a
    # StopIteration raised inside a user thunk would become a silent nil result.
    step = trampoline_step(step) until step.is_a?(Done)
    step.a
  end

  private

  # Performs one reduction of a step that is not yet Done and returns the next step.
  def trampoline_step(step)
    case step
    when Call
      # Force the suspended thunk to obtain the next step.
      step.rest
    when Cont
      bind_step(step.a, step.f)
    else
      raise TypeError, "trampoline step is a #{step.class}, expected Done, Call or Cont"
    end
  end

  # Reduces "run +inner+, then feed its value to +cont+" by one step,
  # dispatching on the shape of +inner+.
  def bind_step(inner, cont)
    case inner
    when Done
      # The value is already available: apply the continuation.
      cont[inner.a]
    when Call
      # Force the thunk, then attach the continuation to whatever it produced.
      inner.rest.flat_map(&cont)
    when Cont
      # Reassociate: (inner.a then inner.f) then cont
      #          ==> inner.a then (x -> inner.f[x] then cont).
      # The flat_map methods never build this shape. It only appears when a Cont
      # has been constructed directly around another Cont.
      inner.a.flat_map { |x| inner.f[x].flat_map(&cont) }
    else
      raise TypeError, "Cont wraps a #{inner.class}, expected Done, Call or Cont"
    end
  end
end

# A finished computation carrying its final value.
class Done < TailRec
  # The computed value.
  attr_reader :a

  def initialize(value)
    @a = value
  end

  # Sequences the given block after this value. The block is not run here:
  # it is suspended inside a Call, so TailRec#result runs it from its loop.
  def flat_map
    Call { yield(a) }
  end
end

# Shorthand constructor: Done(x) is the same as Done.new(x).
def Done(value)
  Done.new(value)
end

# A suspended step: holds a thunk that produces the next step when forced.
class Call < TailRec
  def initialize
    @rest = lambda { yield }
  end

  # Forces the thunk and returns whatever it produces.
  def rest
    @rest.call
  end

  # Attaches a continuation to this step without forcing the thunk.
  def flat_map
    Cont(self) { |value| yield(value) }
  end
end

# Shorthand constructor: Call { ... } suspends the block inside a Call.
def Call(&thunk)
  Call.new(&thunk)
end

# A computation +a+ whose result is to be fed to the continuation +f+.
class Cont < TailRec
  attr_reader :a, :f

  def initialize(a)
    @a = a
    @f = lambda { |x| yield(x) }
  end

  # Appends another continuation by nesting it inside +f+, so the new Cont
  # still wraps the same inner computation +a+.
  def flat_map
    Cont(a) do |x|
      f.call(x).flat_map { |y| yield(y) }
    end
  end
end

# Shorthand constructor: Cont(m) { |x| ... } binds the block as m's continuation.
def Cont(a, &continuation)
  Cont.new(a, &continuation)
end

Now recursion becomes so much easier:

# Factorial of +n+ as a trampolined computation; call #result to evaluate it.
# Any n that is not greater than 0 yields Done(1).
def fact(n)
  return Done(1) unless n > 0

  Call { fact(n - 1) }.flat_map { |prev| Done(prev * n) }
end

fact(10_000).result # ten thousand nested steps, driven by TailRec#result
#=> 28462596809170545189064132...

# n-th Fibonacci number as a trampolined computation; call #result to evaluate it.
# For n < 2 the value is n itself.
def fib(n)
  return Done(n) if n < 2

  Call { fib(n - 1) }.flat_map do |first|
    Call { fib(n - 2) }.flat_map do |second|
      Done(first + second)
    end
  end
end

0.upto(9).map { |k| fib(k).result } # the first ten Fibonacci numbers
#=> [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]