Posted on June 17, 2016

Let’s start by building a simple stack-based calculator using an object-oriented approach. The calculator is represented by a stack of integers (`List[Int]`) and a set of operations that modify the stack and can return a result.

``````
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = {
    stack = x::stack
  }

  def pop(): Int = {
    val x::tail = stack
    stack = tail
    x
  }

  // s0 + s1
  def add(): Int = {
    val a = pop()
    val b = pop()
    a + b
  }

  // s0 * s1
  def mul(): Int = {
    val a = pop()
    val b = pop()
    a * b
  }
}

val c = new Calc(List[Int]())
c.push(2)
c.push(3)
``````
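
The snippet above stops right after the two pushes. As a sketch of the next step, the self-contained version below also calls `add`. The `CalcDemo` object and the explicit empty-stack error in `pop` are illustrative additions, not part of the original class.

``````scala
// Mutable stack calculator: the head of `stack` is the top of the stack (s0).
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = {
    stack = x :: stack
  }

  // Removes and returns s0; fails loudly on an empty stack.
  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  // s0 + s1
  def add(): Int = {
    val a = pop()
    val b = pop()
    a + b
  }

  // s0 * s1
  def mul(): Int = {
    val a = pop()
    val b = pop()
    a * b
  }
}

// Hypothetical usage example: push 2 and 3, then add them.
object CalcDemo {
  val c = new Calc(List[Int]())
  c.push(2)
  c.push(3)
  val afterPushes: List[Int] = c.stack // List(3, 2): 3 is on top
  val sum: Int = c.add()               // pops 3 and 2, returns 5
  val afterAdd: List[Int] = c.stack    // the stack is empty again
}
``````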

So far so good. Now I want to add slightly more complex logic: computing `(a * b + c * d)`, where `a, b, c, d` are the top 4 values on my stack. I have two options. One is to add a method to the `Calc` class like this:

``````
  // s0 * s1 + s2 * s3
  def m1(): Int = {
    val ab = mul()
    val cd = mul()
    push(ab)
    push(cd)
    add()
  }
``````
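
To make the `s0`..`s3` naming concrete: the head of the list is the top of the stack, so the operands have to be pushed in reverse order. The sketch below assumes that `m1` finishes with `add()`, so it actually returns `a * b + c * d`. The `M1Demo` object and the extra value `99` sitting under the operands are made up for illustration.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
  def mul(): Int = { val a = pop(); val b = pop(); a * b } // s0 * s1

  // s0 * s1 + s2 * s3
  def m1(): Int = {
    val ab = mul() // consumes s0 and s1
    val cd = mul() // consumes s2 and s3
    push(ab)
    push(cd)
    add()          // ab + cd; anything below s3 is left untouched
  }
}

object M1Demo {
  val calc = new Calc(List(99))
  // Push d, c, b, a so that a = 1 ends up on top as s0.
  List(4, 3, 2, 1).foreach(x => calc.push(x))
  val beforeM1: List[Int] = calc.stack // List(1, 2, 3, 4, 99)
  val result: Int = calc.m1()          // 1 * 2 + 3 * 4 = 14
  val afterM1: List[Int] = calc.stack  // List(99)
}
``````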

But I don’t want to pollute my nice, clean `Calc` class, since this method may be specific to one part of my program. Instead, I want to create an ad hoc method (preferably a function) that I can pass around and combine with other functions:

``````
object Calc {
  // s0 * s1 + s2 * s3
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
}
``````
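
Because `m1` is now an ordinary value of type `Calc => Int`, it can be stored in a collection or handed to another function. Here is a minimal sketch. It assumes a hypothetical helper `runAll` that applies a list of such functions to one calculator in order. It also assumes that `m1` combines its two products with `calc.add()`. The eight starting values are made up.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
  def mul(): Int = { val a = pop(); val b = pop(); a * b } // s0 * s1
}

object Calc {
  // s0 * s1 + s2 * s3, kept outside the class as a function value
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
}

object PassAround {
  // Hypothetical helper: applies each operation to the same calculator,
  // left to right, and collects the results.
  def runAll(calc: Calc, ops: List[Calc => Int]): List[Int] =
    ops.map(op => op(calc))

  val calc = new Calc(List(1, 2, 3, 4, 5, 6, 7, 8))
  // The same function value used twice: 1*2 + 3*4, then 5*6 + 7*8
  val results: List[Int] = runAll(calc, List(Calc.m1, Calc.m1))
}
``````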

Later I decided that I also needed `(a * b + c * d) * e`, and since I wanted to reuse my work, I simply added more functions like the one above. I also added `m0` to push 4 values onto the stack in one shot:

``````
object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Calc, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d) => {
      calc.push(d)
      calc.push(c)
      calc.push(b)
      calc.push(a)
      0
    }

  // s0 * s1 + s2 * s3
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

  // x * e
  val m2: (Calc, Int, Int) => Int =
    (calc, x, e) => {
      calc.push(x)
      calc.push(e)
      calc.mul()
    }

  // (a * b + c * d) * e
  val m3: (Calc, Int, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d, e) => {
      m0(calc, a, b, c, d)
      val r1 = m1(calc)
      m2(calc, r1, e)
    }
}

Calc.m3(new Calc(List[Int]()), 1, 2, 3, 4, 5) //> 70
``````
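
To see where `70` comes from, the sketch below calls `m0`, `m1` and `m2` one at a time on a single calculator and records the stack in between. The arithmetic is (1 * 2 + 3 * 4) * 5 = 14 * 5 = 70. The `Trace` object is a hypothetical addition, and `m1` is assumed to end with `calc.add()`.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
  def mul(): Int = { val a = pop(); val b = pop(); a * b } // s0 * s1
}

object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Calc, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d) => {
      calc.push(d); calc.push(c); calc.push(b); calc.push(a)
      0
    }

  // s0 * s1 + s2 * s3
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

  // x * e
  val m2: (Calc, Int, Int) => Int =
    (calc, x, e) => {
      calc.push(x)
      calc.push(e)
      calc.mul()
    }
}

// Hypothetical step-by-step run of m3's body with a = 1, b = 2, c = 3, d = 4, e = 5
object Trace {
  import Calc._

  val calc = new Calc(Nil)
  m0(calc, 1, 2, 3, 4)
  val afterM0: List[Int] = calc.stack // List(1, 2, 3, 4): a is on top
  val r1: Int = m1(calc)              // 1 * 2 + 3 * 4 = 14
  val afterM1: List[Int] = calc.stack // all four operands consumed
  val r2: Int = m2(calc, r1, 5)       // 14 * 5 = 70
}
``````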

Not too bad, but there is some annoying boilerplate: I need to repeat `calc.` every time I use a `Calc` method, and I have to lug the state around explicitly. It would be great if I could use just the method names and skip the state altogether. Let’s do some refactoring:

``````
object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d) =>
      calc => {
        calc.push(d)
        calc.push(c)
        calc.push(b)
        calc.push(a)
        0
      }

  // (s0 * s1 + s2 * s3)
  val m1: () => (Calc => Int) =
    () =>
      calc => {
        val ab = calc.mul()
        val cd = calc.mul()
        calc.push(ab)
        calc.push(cd)
        calc.add()
      }

  // x * e
  val m2: (Int, Int) => (Calc => Int) =
    (x, e) =>
      calc => {
        calc.push(x)
        calc.push(e)
        calc.mul()
      }

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d, e) =>
      calc => {
        m0(a, b, c, d)(calc)
        val r1 = m1()(calc)
        m2(r1, e)(calc)
      }
}

Calc.m3(1, 2, 3, 4, 5)(new Calc(List[Int]())) //> 70
``````
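
One consequence of this split is that `Calc.m3(1, 2, 3, 4, 5)` is now just a value of type `Calc => Int`. Building it touches no calculator, and it can be run later, more than once, against different stacks. The sketch below shows this using the companion object from above, with `m1` assumed to end in `calc.add()`. The `Staged` object and the extra value `7` are illustrative.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
  def mul(): Int = { val a = pop(); val b = pop(); a * b } // s0 * s1
}

object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d) =>
      calc => {
        calc.push(d); calc.push(c); calc.push(b); calc.push(a)
        0
      }

  // s0 * s1 + s2 * s3
  val m1: () => (Calc => Int) =
    () =>
      calc => {
        val ab = calc.mul()
        val cd = calc.mul()
        calc.push(ab)
        calc.push(cd)
        calc.add()
      }

  // x * e
  val m2: (Int, Int) => (Calc => Int) =
    (x, e) =>
      calc => {
        calc.push(x)
        calc.push(e)
        calc.mul()
      }

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d, e) =>
      calc => {
        m0(a, b, c, d)(calc)
        val r1 = m1()(calc)
        m2(r1, e)(calc)
      }
}

object Staged {
  // Building the program runs nothing yet.
  val program: Calc => Int = Calc.m3(1, 2, 3, 4, 5)

  val first  = new Calc(Nil)
  val second = new Calc(List(7)) // 7 sits below the operands

  val r1: Int = program(first)   // 70
  val r2: Int = program(second)  // 70; the 7 is left on the stack
}
``````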

Not much better, but at least the state is passed separately now, so maybe there is a chance to conceal it. The goal is to be able to write something like this:

``````
  val m3: (Int, Int, Int, Int, Int) => ??? = {
    (a, b, c, d, e) =>
      bind( m0(a, b, c, d), { _ =>
        bind( m1(), { r1 =>
          m2(r1, e)
        })
      })
  }
``````

Looking at the last refactoring, it seems that `???` could be `Calc => Int`. Let’s try to define `bind` as:

``````
  def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int =
    calc => {
      val r1 = f1(calc)
      f2(r1)(calc)
    }
``````
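
Nothing in `bind` actually depends on `Calc` or `Int`. As a sketch, here is a hypothetical generic variant `bindG` that takes the state type and both result types as parameters. This also lets the first step return something other than `Int`, here `Unit`. This generalisation is an illustration and does not come from the original code.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
}

object BindDemo {
  // Hypothetical generic bind: run f1 on the state, feed its result to f2,
  // then run the step that f2 returns on the same state.
  def bindG[S, A, B](f1: S => A, f2: A => (S => B)): S => B =
    s => {
      val r1 = f1(s)
      f2(r1)(s)
    }

  // The first step returns Unit, the second returns Int.
  val pushTwoThenAdd: Calc => Int =
    bindG[Calc, Unit, Int](
      calc => { calc.push(2); calc.push(3) },
      _ => calc => calc.add()
    )

  val result: Int = pushTwoThenAdd(new Calc(Nil)) // 3 + 2
}
``````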

Now that we have defined `bind`, we can write:

``````
  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d, e) => {
      bind(m0(a, b, c, d), { _ =>
        bind(m1(), { r1 =>
          m2(r1, e)
        })
      })
    }
``````

Still ugly, but at least the state is concealed. This notation can be improved with Scala `for` comprehensions. The catch is that `for` comprehensions require `flatMap` and `map` to be defined as methods on the value being composed, here `Calc => Int`, and we can’t add methods to a function type directly. The solution is to create a wrapper:

``````
  case class State[S, A](runState: S => A) {
    def flatMap(f: A => State[S, A]): State[S, A] = {
      State(
        s => {
          val r = runState(s)
          f(r).runState(s)
        }
      )
    }

    def map(f: A => A): State[S, A] = {
      State(
        s => {
          val r = runState(s)
          f(r)
        }
      )
    }
  }

  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      State(m0(a, b, c, d)).flatMap { _ =>
        State(m1()).flatMap { r =>
          State(m2(r, e))
        }
      }
    }
``````

Which can also be written as:

``````
  ...

  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      for {
        _  <- State(m0(a, b, c, d))
        r1 <- State(m1())
        r2 <- State(m2(r1, e))
      } yield(r2)
    }
``````
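
For reference, the compiler rewrites that `for` comprehension into nested `flatMap` calls with a final `map` for the `yield`, which is exactly the shape of the earlier `flatMap` version. The self-contained sketch below checks that both spellings produce `70`. It assumes the `State` wrapper shown above, `m0`..`m2` as `Calc => Int` functions, and `m1` ending with `calc.add()`. The names `Desugar`, `m3For` and `m3Desugared` are hypothetical.

``````scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }

  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on empty stack")
  }

  def add(): Int = { val a = pop(); val b = pop(); a + b } // s0 + s1
  def mul(): Int = { val a = pop(); val b = pop(); a * b } // s0 * s1
}

// The wrapper from the text: every step sees the same (mutable) Calc.
case class State[S, A](runState: S => A) {
  def flatMap(f: A => State[S, A]): State[S, A] =
    State[S, A] { s =>
      val r = runState(s)
      f(r).runState(s)
    }

  def map(f: A => A): State[S, A] =
    State[S, A](s => f(runState(s)))
}

object Desugar {
  val m0: (Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d) => calc => {
      calc.push(d); calc.push(c); calc.push(b); calc.push(a)
      0
    }

  val m1: () => (Calc => Int) =
    () => calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

  val m2: (Int, Int) => (Calc => Int) =
    (x, e) => calc => {
      calc.push(x)
      calc.push(e)
      calc.mul()
    }

  // What we write...
  def m3For(a: Int, b: Int, c: Int, d: Int, e: Int): State[Calc, Int] =
    for {
      _  <- State(m0(a, b, c, d))
      r1 <- State(m1())
      r2 <- State(m2(r1, e))
    } yield r2

  // ...and essentially what the compiler turns it into.
  def m3Desugared(a: Int, b: Int, c: Int, d: Int, e: Int): State[Calc, Int] =
    State(m0(a, b, c, d)).flatMap { _ =>
      State(m1()).flatMap { r1 =>
        State(m2(r1, e)).map(r2 => r2)
      }
    }
}
``````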

This looks much better. Time to get rid of the explicit `State(...)` wrapping here. For that, we need to change `m0`..`m2` so that they return a `State` object instead of `Calc => Int`:

``````
object Calc {
  // case class State[S, A] as defined above

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d) =>
      State(
        calc => {
          calc.push(d)
          calc.push(c)
          calc.push(b)
          calc.push(a)
          0
        }
      )

  // s0 * s1 + s2 * s3
  val m1: () => State[Calc, Int] =
    () =>
      State(
        calc => {
          val ab = calc.mul()
          val cd = calc.mul()
          calc.push(ab)
          calc.push(cd)
          calc.add()
        }
      )

  // x * e
  val m2: (Int, Int) => State[Calc, Int] =
    (x, e) =>
      State(
        calc => {
          calc.push(x)
          calc.push(e)
          calc.mul()
        }
      )

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      for {
        _  <- m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}

Calc.m3(1, 2, 3, 4, 5).runState(new Calc(List[Int]())) //> 70
``````
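
The mutable `Calc` hidden inside `State` still leaks: running the same `State` value twice against one calculator leaves a different stack each time. The small sketch below shows that effect. It restates the `State` wrapper reduced to its field and introduces a hypothetical `Leak` object.

``````scala
// Minimal mutable calculator: only push is needed for this demo.
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }
}

// The wrapper from the text, reduced to its field (flatMap/map are not needed here).
case class State[S, A](runState: S => A)

object Leak {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d) =>
      State[Calc, Int](calc => {
        calc.push(d)
        calc.push(c)
        calc.push(b)
        calc.push(a)
        0
      })

  val calc = new Calc(Nil)
  val program: State[Calc, Int] = m0(1, 2, 3, 4)

  program.runState(calc)
  val afterFirstRun: List[Int] = calc.stack  // List(1, 2, 3, 4)

  program.runState(calc)
  val afterSecondRun: List[Int] = calc.stack // List(1, 2, 3, 4, 1, 2, 3, 4)
}
``````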

At this point it’s easy to notice that we drag mutable state around, and there is no particular reason for it to be mutable: instead of modifying the state in place, we can always build an updated state and pass it forward. Also, the methods `push`, `pop`, `add` and `mul` don’t look special anymore; they are not that different from `m0`, `m1`, `m2`, `m3`. So let’s create an improved version of the `State` class, whose `runState` returns the new state together with the result (`S => (A, S)`), and implement those 4 methods on top of it. We can then abandon the `Calc` class completely, since nothing is left in it except for the `List[Int]`:

``````
object Calc {

  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] = {
      State(
        s => {
          val (r, newState) = runState(s)
          f(r).runState(newState)
        }
      )
    }

    def map[B](f: A => B): State[S, B] =
      flatMap { x => State.unit(f(x)) }

    def getResult(s: S) = runState(s)._1
  }

  object State {
    def unit[S, A](x: A): State[S, A] =
      State((x, _))
  }

  val push: Int => State[List[Int], Unit] =
    x => State(s => ((), x::s))

  val pop: () => State[List[Int], Int] =
    () => State { case (x::xs) => (x, xs) }

  val add: () => State[List[Int], Int] =
    () => State { case (a::b::xs) => (a + b, xs) }

  val mul: () => State[List[Int], Int] =
    () => State { case (a::b::xs) => (a * b, xs) }

  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
    (a, b, c, d) =>
      for {
        _ <- push(d)
        _ <- push(c)
        _ <- push(b)
        _ <- push(a)
      } yield(())

  // s0 * s1 + s2 * s3
  val m1: () => State[List[Int], Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _  <- push(ab)
        _  <- push(cd)
        s  <- add()
      } yield(s)

  // x * e
  val m2: (Int, Int) => State[List[Int], Int] =
    (x, e) =>
      for {
        _ <- push(x)
        _ <- push(e)
        p <- mul()
      } yield(p)

  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[List[Int], Int] =
    (a, b, c, d, e) => {
      for {
        _  <- m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}

Calc.m3(1, 2, 3, 4, 5).getResult(Nil) //> 70
``````
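
Because `runState` now returns the final state alongside the result, nothing is mutated and the leftover stack can be inspected directly. The condensed, self-contained sketch below mirrors the final version. It assumes that `m1` binds its result with `s <- add()`. It introduces a hypothetical `Stack` type alias and an `ImmutableCalc` object, and it omits `pop`, which these functions don't use.

``````scala
object ImmutableCalc {
  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s =>
        val (r, newState) = runState(s) // run this step
        f(r).runState(newState)         // hand the new state to the next step
      }

    def map[B](f: A => B): State[S, B] =
      flatMap(x => State.unit[S, B](f(x)))

    def getResult(s: S): A = runState(s)._1
  }

  object State {
    def unit[S, A](x: A): State[S, A] = State[S, A](s => (x, s))
  }

  type Stack = List[Int]

  val push: Int => State[Stack, Unit] =
    x => State[Stack, Unit](s => ((), x :: s))

  val add: () => State[Stack, Int] =
    () => State[Stack, Int] {
      case a :: b :: xs => (a + b, xs)
      case _            => throw new NoSuchElementException("add needs two values")
    }

  val mul: () => State[Stack, Int] =
    () => State[Stack, Int] {
      case a :: b :: xs => (a * b, xs)
      case _            => throw new NoSuchElementException("mul needs two values")
    }

  val m0: (Int, Int, Int, Int) => State[Stack, Unit] =
    (a, b, c, d) =>
      for {
        _ <- push(d)
        _ <- push(c)
        _ <- push(b)
        _ <- push(a)
      } yield ()

  val m1: () => State[Stack, Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _  <- push(ab)
        _  <- push(cd)
        s  <- add()
      } yield s

  val m2: (Int, Int) => State[Stack, Int] =
    (x, e) =>
      for {
        _ <- push(x)
        _ <- push(e)
        p <- mul()
      } yield p

  val m3: (Int, Int, Int, Int, Int) => State[Stack, Int] =
    (a, b, c, d, e) =>
      for {
        _  <- m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield r2
}
``````

Running `ImmutableCalc.m3(1, 2, 3, 4, 5).runState(List(9))` returns `(70, List(9))`. The operands are consumed, and the value that was already on the stack comes back untouched, without any object being modified.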

The class `State[S, A]` is what functional programming calls the State monad. Let’s look at the benefits:

1. New methods are easily composable, which is good for code reuse.

2. We started with an attempt to get rid of boilerplate:

``````
// Before
val m1: () => (Calc => Int) =
  () =>
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }

// After
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab)
      _  <- push(cd)
      s  <- add()
    } yield(s)
``````

I would argue that the latter is more readable and less error-prone; repeating `calc.` in the former snippet simply creates visual noise. So the State monad provides a good way to build DSL-like composition.

3. Immutable state. I’m not sure it brings much here, since the state is already concealed and hard to mess up.
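
Calling `State[S, A]` a monad implies that `unit` and `flatMap` satisfy the three monad laws: left identity, right identity and associativity. Functions can’t be compared with `==`, so the sketch below only spot-checks the laws by running both sides on a few sample stacks. It is an illustration, not a proof. The helper `sameOn` and the sample operations `popOrZero`, `pushAndDouble` and `addToTop` are hypothetical.

``````scala
object MonadLaws {
  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s =>
        val (r, next) = runState(s)
        f(r).runState(next)
      }
  }

  object State {
    def unit[S, A](x: A): State[S, A] = State[S, A](s => (x, s))
  }

  type Stack = List[Int]

  // Pops the top value, or yields 0 on an empty stack (total, so every sample runs).
  val popOrZero: State[Stack, Int] =
    State[Stack, Int] {
      case x :: xs => (x, xs)
      case Nil     => (0, Nil)
    }

  // Pushes x and returns x * 2.
  val pushAndDouble: Int => State[Stack, Int] =
    x => State[Stack, Int](s => (x * 2, x :: s))

  // Pops the top value y (if any) and returns x + y.
  val addToTop: Int => State[Stack, Int] =
    x => State[Stack, Int] {
      case y :: ys => (x + y, ys)
      case Nil     => (x, Nil)
    }

  // Hypothetical helper: both State values give the same (result, state) on every sample.
  def sameOn[A](samples: List[Stack])(l: State[Stack, A], r: State[Stack, A]): Boolean =
    samples.forall(s => l.runState(s) == r.runState(s))

  val samples: List[Stack] = List(Nil, List(1), List(3, 4, 5))

  // unit(x).flatMap(f) behaves like f(x)
  val leftIdentity: Boolean =
    sameOn(samples)(State.unit[Stack, Int](7).flatMap(pushAndDouble), pushAndDouble(7))

  // m.flatMap(unit) behaves like m
  val rightIdentity: Boolean =
    sameOn(samples)(popOrZero.flatMap(x => State.unit[Stack, Int](x)), popOrZero)

  // m.flatMap(f).flatMap(g) behaves like m.flatMap(x => f(x).flatMap(g))
  val associativity: Boolean =
    sameOn(samples)(
      popOrZero.flatMap(pushAndDouble).flatMap(addToTop),
      popOrZero.flatMap(x => pushAndDouble(x).flatMap(addToTop))
    )
}
``````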

We can easily apply other monadic concepts here, e.g. `traverse`:

``````
  object State {
    ...

    def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
      State(
        s => {
          // Run one step per element, threading the state through
          // and collecting the results in order.
          val (bs, finalState) =
            la.foldLeft((Vector.empty[B], s)) { case ((acc, st), a) =>
              val (r, nextState) = f(a).runState(st)
              (acc :+ r, nextState)
            }
          (bs.toList, finalState)
        }
      )
  }
``````
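
Here is a quick sketch of what `traverse` does. It runs one `State` step per list element, threads the stack from each step into the next, and collects the results in order. The code restates `State` and `traverse` so that it compiles on its own, using an immutable `Vector` accumulator. The `TraverseDemo` object, the `Stack` alias and the sample values are illustrative.

``````scala
object TraverseDemo {
  case class State[S, A](runState: S => (A, S))

  object State {
    def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
      State[S, List[B]] { s =>
        val (bs, finalState) =
          la.foldLeft((Vector.empty[B], s)) { case ((acc, st), a) =>
            val (b, next) = f(a).runState(st) // run one step on the current stack
            (acc :+ b, next)                 // keep its result, pass the new stack on
          }
        (bs.toList, finalState)
      }
  }

  type Stack = List[Int]

  val push: Int => State[Stack, Unit] =
    x => State[Stack, Unit](s => ((), x :: s))

  val pop: () => State[Stack, Int] =
    () => State[Stack, Int] {
      case x :: xs => (x, xs)
      case Nil     => throw new NoSuchElementException("pop on empty stack")
    }

  // Push 1, 2, 3 in that order: 3 ends up on top.
  val pushed: (List[Unit], Stack) =
    State.traverse(List(1, 2, 3))(push).runState(Nil)

  // Pop three times, collecting the values from the top down.
  val popped: (List[Int], Stack) =
    State.traverse(List(1, 2, 3))(_ => pop()).runState(List(10, 20, 30))
}
``````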

With `traverse`, the code becomes even shorter:

``````
// Before

val push: Int => State[List[Int], Unit] =
x => State(s => ((), x::s))

val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
(a, b, c, d) =>
for {
_ <- push(d)
_ <- push(c)
_ <- push(b)
_ <- push(a)
} yield(())

// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab)
_ <- push(cd)
s <- add()
} yield(s)

...

// After
// A function value cannot take varargs, so push becomes a method here
def push(xs: Int*): State[List[Int], List[Unit]] =
State.traverse[List[Int], Int, Unit](xs.toList)(x => State(s => ((), x::s)))

// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab, cd)