State monad for object-oriented programmers

Posted on June 17, 2016

Let’s start by building a simple stack-based calculator using an object-oriented approach. The calculator is represented by a stack of integers (List[Int]) and a few operations that modify the stack and may return a result.

class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = {
    stack = x::stack
  }

  def pop(): Int = {
    val x::tail = stack
    stack = tail
    x
  }

  // s0 + s1
  def add(): Int = {
    val a = pop()
    val b = pop()
    a + b
  }

  // s0 * s1
  def mul(): Int = {
    val a = pop()
    val b = pop()
    a * b
  }
}

val c = new Calc(List[Int]())
c.push(2)
c.push(3)
c.add() //> 5

Looks good so far. Now I want to add slightly more complex logic: compute (a * b + c * d), where a, b, c, d are the top four values on the stack. I have two options. One is to add a method to the Calc class like this:

  // s0 * s1 + s2 * s3
  def m1(): Int = {
    val ab = mul()
    val cd = mul()
    push(ab)
    push(cd)
    add()
  }

But I don’t want to pollute my clean and nice Calc class, since this method may be specific to one part of my program. Instead I want a standalone definition (preferably a function) that I can pass around and combine with other functions:

object Calc {
  import Calc._

  // s0 * s1 + s2 * s3
  val m1: Calc => Int = 
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
}

Later I decided that I also needed (a * b + c * d) * e, and since I wanted to reuse my work I simply added more functions like the one above. I also added m0 to push four values onto the stack in one shot:

object Calc {
  import Calc._

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Calc, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d) => {
      calc.push(d)
      calc.push(c)
      calc.push(b)
      calc.push(a)
      0
    }

  // s0 * s1 + s2 * s3
  val m1: Calc => Int = 
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

  // x * e 
  val m2: (Calc, Int, Int) => Int =
    (calc, x, e) => {
      calc.push(x)
      calc.push(e)
      calc.mul()
    }

  // (a * b + c * d) * e
  val m3: (Calc, Int, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d, e) => {
      m0(calc, a, b, c, d)
      val r1 = m1(calc)
      m2(calc, r1, e)
    }

}

Calc.m3(new Calc(List[Int]()), 1, 2, 3, 4, 5) //> 70

Not too bad, but there is some annoying boilerplate: I need to repeat calc. every time I want to use a Calc method, and I have to explicitly lug the state around. It would be great if I could use just the method names and skip the state altogether. Let’s do some refactoring:

object Calc {
  import Calc._

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d) =>
      calc => {
        calc.push(d)
        calc.push(c)
        calc.push(b)
        calc.push(a)
        0
      }

  // (s0 * s1 + s2 * s3)
  val m1: () => (Calc => Int) = 
    () =>
      calc => {
        val ab = calc.mul()
        val cd = calc.mul()
        calc.push(ab)
        calc.push(cd)
        calc.add()
      }

  // x * e
  val m2: (Int, Int) => (Calc => Int) =
    (x, e) =>
      calc => {
        calc.push(x)
        calc.push(e)
        calc.mul()
      }

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d, e) =>
      calc => {
        m0(a, b, c, d)(calc)
        val r1 = m1()(calc)
        m2(r1, e)(calc)
      }
}

Calc.m3(1, 2, 3, 4, 5)(new Calc(List[Int]())) //> 70

Not much better, but at least the state is passed separately now, so maybe there is a chance to conceal it. The goal is to be able to write something like this:

  val m3: (Int, Int, Int, Int, Int) => ??? = {
    (a, b, c, d, e) =>
      bind( m0(a, b, c, d), { _ =>
        bind( m1(), { r1 => 
          m2(r1, e)
        })
      })
  }

Looking at the last refactoring it seems that ??? could be Calc => Int. Let’s try to define bind as:

  def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int = 
    calc => {
      val r1 = f1(calc)
      f2(r1)(calc)
    }
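Nothing in this definition depends on Calc or Int specifically. As a side sketch, here is the same bind with type parameters; the names S, A, B, the Counter class and the two helper functions are assumptions made up for illustration and are not part of the calculator:

```scala
object BindSketch {
  // Generic form of bind: S is the state type, A and B are result types.
  // (The post fixes S = Calc and A = B = Int; the type parameters are this sketch's assumption.)
  def bind[S, A, B](f1: S => A, f2: A => (S => B)): S => B =
    s => {
      val r1 = f1(s) // run the first step against the state
      f2(r1)(s)      // feed its result to the next step, using the same state object
    }

  // A hypothetical mutable state, standing in for Calc.
  final class Counter(var n: Int)

  // Increment the counter and return the new value.
  val incr: Counter => Int =
    c => { c.n += 1; c.n }

  // Add k to the counter and describe the outcome as a String.
  val addAndShow: Int => (Counter => String) =
    k => c => { c.n += k; s"counter is ${c.n}" }

  // incr runs first; its result decides how much addAndShow adds.
  val program: Counter => String = bind(incr, addAndShow)
}
```

Running BindSketch.program(new BindSketch.Counter(0)) first bumps the counter to 1, then adds that 1 again, so the two steps share one state without either of them mentioning it at the call site.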

Now that we have defined bind we can write:

  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) = 
    (a, b, c, d, e) => {
      bind(m0(a, b, c, d), { _ =>
        bind(m1(), { r1 => 
          m2(r1, e)
        })
      })
    }

Still ugly, but at least the state is concealed. The notation can be improved with Scala’s for comprehensions. However, for comprehensions require flatMap and map to be defined as methods on the value being sequenced, and we can’t add methods to the function type Calc => Int directly. The solution is to create a wrapper:

  case class State[S, A](runState: S => A) {
    def flatMap(f: A => State[S, A]): State[S, A] = {
      State(
        s => {
          val r = runState(s)
          f(r).runState(s)
        }
      )
    }

    def map(f: A => A): State[S, A] = {
      State(
        s => {
          val r = runState(s)
          f(r)
        }
      )
    }
  }

  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] = 
    (a, b, c, d, e) => {
      State(m0(a, b, c, d)).flatMap { _ =>
        State(m1()).flatMap { r => 
          State(m2(r, e))
        }
      }
    }

Which can also be written as:


  ...

  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      for {
        _ <-  State(m0(a, b, c, d))
        r1 <- State(m1())
        r2 <- State(m2(r1, e))
      } yield(r2)
    }

This looks much better. Time to get rid of the explicit State(...) wrapping here, for which we need to change m0 .. m2 so that they return a State object instead of Calc => Int:

object Calc {
  import Calc._

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d) =>
      State(
        calc => {
          calc.push(d)
          calc.push(c)
          calc.push(b)
          calc.push(a)
          0
        }
      )

  // s0 * s1 + s2 * s3
  val m1: () => State[Calc, Int] = 
    () =>
      State(
        calc => {
          val ab = calc.mul()
          val cd = calc.mul()
          calc.push(ab)
          calc.push(cd)
          calc.add()
        }
      )

  // x * e
  val m2: (Int, Int) => State[Calc, Int] =
    (x, e) =>
      State(
        calc => {
          calc.push(x)
          calc.push(e)
          calc.mul()
        }
      )

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      for {
        _ <-  m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}

Calc.m3(1, 2, 3, 4, 5).runState(new Calc(List[Int]())) //> 70
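For readers less used to for comprehensions, it may help to see roughly how the compiler rewrites them into flatMap and map calls. The sketch below reuses the mutable-state wrapper from this section; the Log class and the record function are hypothetical names invented for the example:

```scala
object DesugarSketch {
  // Same wrapper as in the post: the state S is mutable and shared between steps.
  case class State[S, A](runState: S => A) {
    def flatMap(f: A => State[S, A]): State[S, A] =
      State[S, A](s => f(runState(s)).runState(s))

    def map(f: A => A): State[S, A] =
      State[S, A](s => f(runState(s)))
  }

  // Hypothetical mutable state, used only in this sketch.
  final class Log(var entries: List[String])

  // Append a message and return the number of entries so far.
  def record(msg: String): State[Log, Int] =
    State[Log, Int] { log =>
      log.entries = log.entries :+ msg
      log.entries.length
    }

  // For comprehension form.
  val sugared: State[Log, Int] =
    for {
      _ <- record("a")
      n <- record("b")
      m <- record("c" * n)
    } yield m

  // Roughly the equivalent chain: every generator except the last
  // becomes flatMap, and the last one together with yield becomes map.
  val desugared: State[Log, Int] =
    record("a").flatMap { _ =>
      record("b").flatMap { n =>
        record("c" * n).map { m => m }
      }
    }
}
```

Both values perform the same three steps in the same order when run against a fresh Log, which is why flatMap and map are the only methods the wrapper needs here.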

At this point it’s easy to notice that we drag mutable state around, and there is no particular reason why it should be mutable: each step can return an updated state and pass it forward. Also, the methods push, pop, add and mul don’t look special anymore; they are not that different from m0, m1, m2 and m3. So let’s create an improved version of class State and implement those four methods on top of it. We can then abandon the class Calc completely, since there is nothing left in it except a List[Int]:

object Calc {

  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] = {
      State(
        s => {
          val (r, newState) = runState(s)
          f(r).runState(newState)
        }
      )
    }

    def map[B](f: A => B): State[S, B] =
      flatMap { x => State.unit(f(x)) }

    def getResult(s: S) = runState(s)._1
  }

  object State {
    def unit[S, A](x: A): State[S, A] =
      State((x, _))
  }

  val push: Int => State[List[Int], Unit] = 
    x => State(s => ((), x::s))

  val pop: () => State[List[Int], Int] =
    () => State { case (x::xs) => (x, xs) }

  val add: () => State[List[Int], Int] = 
    () => State { case (a::b::xs) => (a + b, xs) }

  val mul: () => State[List[Int], Int] = 
    () => State { case (a::b::xs) => (a * b, xs) }

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
    (a, b, c, d) =>
      for {
        _ <- push(d)
        _ <- push(c)
        _ <- push(b)
        _ <- push(a)
      } yield(())

  // s0 * s1 + s2 * s3
  val m1: () => State[List[Int], Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _ <- push(ab)
        _ <- push(cd)
        s <- add()
      } yield(s)

  // x * e
  val m2: (Int, Int) => State[List[Int], Int] =
    (x, e) =>
      for {
        _ <- push(x)
        _ <- push(e)
        p <- mul()
      } yield(p)

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[List[Int], Int] =
    (a, b, c, d, e) => {
      for {
        _ <-  m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}

Calc.m3(1, 2, 3, 4, 5).getResult(Nil) //> 70

Class State[S, A] is known as the State monad in functional programming. Let’s look at the benefits:

  1. New methods are easily composable, good for code reuse.

  2. We started from an attempt to get rid of boilerplate:

// Before
val m1: () => (Calc => Int) = 
  () =>
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

// After
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _ <- push(ab)
      _ <- push(cd)
      s <- add()
    } yield(s)

I would argue the latter is more readable and less error-prone; repeating calc. in the former snippet simply creates visual noise. So the State monad provides a good way to build DSL-like composition.

  3. Immutable state. I’m not sure it brings much here, since the state is concealed and already hard to mess up.
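One concrete effect of the immutable version is that a State value is only a description of a computation, and running it never alters the stack it was given. Here is a minimal sketch of that, with a State definition that mirrors the one above; the Stack alias, the program value and the fallback error case in add are assumptions of this sketch:

```scala
object ImmutableSketch {
  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s =>
        val (a, next) = runState(s) // run this step
        f(a).runState(next)         // hand the new state to the next step
      }

    def map[B](f: A => B): State[S, B] =
      flatMap(a => State[S, B](s => (f(a), s)))
  }

  type Stack = List[Int]

  def push(x: Int): State[Stack, Unit] =
    State[Stack, Unit](s => ((), x :: s))

  // Needs two values on the stack; the explicit error case is added for this sketch.
  val add: State[Stack, Int] =
    State[Stack, Int] {
      case a :: b :: rest => (a + b, rest)
      case _              => sys.error("add needs two values on the stack")
    }

  // push 2, push 3, add
  val program: State[Stack, Int] =
    for {
      _   <- push(2)
      _   <- push(3)
      sum <- add
    } yield sum

  val initial: Stack = List(10)

  // runState returns both the result and the final stack.
  val (result, finalStack) = program.runState(initial)
}
```

Because initial is never modified, the same program can be run again against initial, or against any other stack, without the runs interfering with each other. With the mutable Calc, each run needs its own fresh instance.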

We can easily apply other monadic concepts here, e.g. traverse:

  object State {
    ...

    def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
      State(
        s => {
          // Thread the state through every element, collecting the results in order.
          val (bs, newS) =
            la.foldLeft((Vector.empty[B], s)) { case ((acc, cur), a) =>
              val (r, next) = f(a).runState(cur)
              (acc :+ r, next)
            }
          (bs.toList, newS)
        }
      )
  }

With traverse the code becomes even shorter:

  // Before

  val push: Int => State[List[Int], Unit] = 
    x => State(s => ((), x::s))

  val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
    (a, b, c, d) =>
      for {
        _ <- push(d)
        _ <- push(c)
        _ <- push(b)
        _ <- push(a)
      } yield(())

  // s0 * s1 + s2 * s3
  val m1: () => State[List[Int], Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _ <- push(ab)
        _ <- push(cd)
        s <- add()
      } yield(s)

  ...


  // After
  val push: (Int*) => State[List[Int], List[Unit]] = 
    (xs) =>
      State.traverse(xs.toList)(x => State(s => ((), x::s)))

  // s0 * s1 + s2 * s3
  val m1: () => State[List[Int], Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _ <- push(ab, cd)
        s <- add()
      } yield(s)
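To see traverse collect results rather than only a list of Units, here is a small self-contained sketch. It uses a varargs def for push instead of a function value, and popN and program are hypothetical names added for illustration, so treat it as one possible arrangement rather than the definitive one:

```scala
object TraverseSketch {
  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s =>
        val (a, next) = runState(s)
        f(a).runState(next)
      }

    def map[B](f: A => B): State[S, B] =
      flatMap(a => State[S, B](s => (f(a), s)))
  }

  object State {
    // Run f for every element, threading the state and collecting results in order.
    def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
      State[S, List[B]] { s =>
        val (bs, last) =
          la.foldLeft((Vector.empty[B], s)) { case ((acc, cur), a) =>
            val (b, next) = f(a).runState(cur)
            (acc :+ b, next)
          }
        (bs.toList, last)
      }
  }

  type Stack = List[Int]

  // Push several values, left to right (so the last argument ends up on top).
  def push(xs: Int*): State[Stack, List[Unit]] =
    State.traverse(xs.toList)(x => State[Stack, Unit](s => ((), x :: s)))

  // Pop the top value; the explicit error case is an addition of this sketch.
  val pop: State[Stack, Int] =
    State[Stack, Int] {
      case x :: rest => (x, rest)
      case Nil       => sys.error("pop on an empty stack")
    }

  // Hypothetical helper: pop n values, returned in the order they were popped.
  def popN(n: Int): State[Stack, List[Int]] =
    State.traverse(List.fill(n)(()))(_ => pop)

  val program: State[Stack, List[Int]] =
    for {
      _    <- push(1, 2, 3)
      top2 <- popN(2)
    } yield top2
}
```

Running TraverseSketch.program.runState(Nil) pushes 1, 2 and 3, then pops 3 and 2, leaving List(1) as the final stack.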