State monad for object-oriented programmers
Let’s start by building a simple stack-based calculator using an object-oriented approach. The calculator is represented by a stack of integers (List[Int]) and certain operations which modify the stack and can return a result.
class Calc(var stack: List[Int]) {
def push(x: Int): Unit = {
stack = x::stack
}
def pop(): Int = {
val x::tail = stack
stack = tail
x
}
// s0 + s1
def add(): Int = {
val a = pop()
val b = pop()
a + b
}
// s0 * s1
def mul(): Int = {
val a = pop()
val b = pop()
a * b
}
}
val c = new Calc(List[Int]())
c.push(2)
c.push(3)
c.add() //> 5
Looks good so far. Now I want to add slightly more complex logic: I want to compute (a * b + c * d), where a, b, c, d are the top 4 values in my stack. I have two options. One is to add a method to the Calc class, but I don’t want to pollute my clean and nice Calc class, as this method can be specific to some part of my program. Instead I want to create a temporary method (preferably a function), so I can pass it around and combine it with other functions:
object Calc {
import Calc._
// s0 * s1 + s2 * s3
val m1: Calc => Int =
calc => {
val ab = calc.mul()
val cd = calc.mul()
calc.push(ab)
calc.push(cd)
calc.add()
}
}
Later I decided that I would need (a * b + c * d) * e too, and since I want to reuse my work I simply added more functions like the one above. I also added m0 to push 4 values onto the stack in one shot:
object Calc {
import Calc._
// (a, b, c, d) => (s0, s1, s2, s3)
val m0: (Calc, Int, Int, Int, Int) => Int =
(calc, a, b, c, d) => {
calc.push(d)
calc.push(c)
calc.push(b)
calc.push(a)
0
}
// s0 * s1 + s2 * s3
val m1: Calc => Int =
calc => {
val ab = calc.mul()
val cd = calc.mul()
calc.push(ab)
calc.push(cd)
calc.add()
}
// x * e
val m2: (Calc, Int, Int) => Int =
(calc, x, e) => {
calc.push(x)
calc.push(e)
calc.mul()
}
// (a * b + c * d) * e
val m3: (Calc, Int, Int, Int, Int, Int) => Int =
(calc, a, b, c, d, e) => {
m0(calc, a, b, c, d)
val r1 = m1(calc)
m2(calc, r1, e)
}
}
Calc.m3(new Calc(List[Int]()), 1, 2, 3, 4, 5) //> 70
Not too bad, but there is some annoying boilerplate: I need to repeat calc. every time I want to use Calc methods and explicitly lug state around. It would be great if I could use only method names and skip the state altogether. Let’s do some refactoring:
object Calc {
import Calc._
// (a, b, c, d) => (s0, s1, s2, s3)
val m0: (Int, Int, Int, Int) => (Calc => Int) =
(a, b, c, d) =>
calc => {
calc.push(d)
calc.push(c)
calc.push(b)
calc.push(a)
0
}
// (s0 * s1 + s2 * s3)
val m1: () => (Calc => Int) =
() =>
calc => {
val ab = calc.mul()
val cd = calc.mul()
calc.push(ab)
calc.push(cd)
calc.add()
}
// x * e
val m2: (Int, Int) => (Calc => Int) =
(x, e) =>
calc => {
calc.push(x)
calc.push(e)
calc.mul()
}
// (a * b + c * d) * e
val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
(a, b, c, d, e) =>
calc => {
m0(a, b, c, d)(calc)
val r1 = m1()(calc)
m2(r1, e)(calc)
}
}
Calc.m3(1, 2, 3, 4, 5)(new Calc(List[Int]())) //> 70
Not much better, but at least state is passed separately now, so maybe there is a chance to conceal it. The goal is to be able to write something like this:
val m3: (Int, Int, Int, Int, Int) => ??? = {
(a, b, c, d, e) =>
bind( m0(a, b, c, d), { _ =>
bind( m1(), { r1 =>
m2(r1, e)
})
})
}
Looking at the last refactoring, it seems that ??? could be Calc => Int. Let’s try to define bind as:
def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int =
calc => {
val r1 = f1(calc)
f2(r1)(calc)
}
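To see what bind does in isolation, here is a minimal self-contained sketch. It assumes a cut-down Calc with only push, pop and mul; the helpers pushBoth, multiply and times are hypothetical names invented for this illustration and are not part of the calculator above.

```scala
object BindSketch {
  // Cut-down mutable calculator, mirroring the Calc class from the article.
  final class Calc(var stack: List[Int]) {
    def push(x: Int): Unit = { stack = x :: stack }
    def pop(): Int = {
      val x = stack.head // throws on an empty stack, like the original pattern match
      stack = stack.tail
      x
    }
    def mul(): Int = pop() * pop()
  }

  // bind runs f1, feeds its result to f2, then runs the action f2 returns on the same Calc.
  def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int =
    calc => {
      val r1 = f1(calc)
      f2(r1)(calc)
    }

  // Hypothetical helper actions, used only in this sketch.
  def pushBoth(a: Int, b: Int): Calc => Int =
    calc => { calc.push(b); calc.push(a); 0 }
  val multiply: Calc => Int = calc => calc.mul()
  def times(k: Int): Int => (Calc => Int) =
    x => calc => { calc.push(x); calc.push(k); calc.mul() }

  // (2 * 3) * 4: the composition itself never mentions calc.
  val program: Calc => Int =
    bind(pushBoth(2, 3), _ => bind(multiply, times(4)))

  val result: Int = program(new Calc(Nil)) // 24
}
```

The only place a Calc instance appears is the final call, which is the property the rest of the refactoring builds on.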
Now that we have defined bind, we can write:
val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
(a, b, c, d, e) => {
bind(m0(a, b, c, d), { _ =>
bind(m1(), { r1 =>
m2(r1, e)
})
})
}
Still ugly, but at least the state is concealed. This notation can be improved by using Scala for comprehensions. However, for comprehensions require flatMap and map to be defined as methods of Calc => Int, which we can’t do. The solution is to create a wrapper:
case class State[S, A](runState: S => A) {
def flatMap(f: A => State[S, A]): State[S, A] = {
State(
s => {
val r = runState(s)
f(r).runState(s)
}
)
}
def map(f: A => A): State[S, A] = {
State(
s => {
val r = runState(s)
f(r)
}
)
}
}
val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
(a, b, c, d, e) => {
State(m0(a, b, c, d)).flatMap { _ =>
State(m1()).flatMap { r =>
State(m2(r, e))
}
}
}
Which can also be written as:
...
val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
(a, b, c, d, e) => {
for {
_ <- State(m0(a, b, c, d))
r1 <- State(m1())
r2 <- State(m2(r1, e))
} yield(r2)
}
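The two forms are equivalent because the compiler rewrites a for comprehension into nested flatMap calls ending in a map. Here is a small self-contained sketch of that rewriting, assuming the same wrapper shape as above; Counter and next are hypothetical stand-ins for Calc and its methods.

```scala
object ForDesugarSketch {
  // Same shape as the wrapper above: an action over a (mutable) S that returns an A.
  final case class State[S, A](runState: S => A) {
    def flatMap(f: A => State[S, A]): State[S, A] =
      State[S, A](s => f(runState(s)).runState(s))
    def map(f: A => A): State[S, A] =
      State[S, A](s => f(runState(s)))
  }

  // Hypothetical mutable state standing in for Calc.
  final class Counter(var n: Int)

  // Increments the counter and returns its new value.
  val next: State[Counter, Int] =
    State[Counter, Int]((c: Counter) => { c.n += 1; c.n })

  // For-comprehension form.
  val viaFor: State[Counter, Int] =
    for {
      a <- next
      b <- next
      c <- next
    } yield a + b + c

  // What the compiler generates for it: nested flatMap calls, with map for the last generator.
  val viaFlatMap: State[Counter, Int] =
    next.flatMap(a => next.flatMap(b => next.map(c => a + b + c)))
}
```

Run against a fresh Counter(0), both values produce 1 + 2 + 3 = 6.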
This looks much better. Time to get rid of the explicit State(...) wrapping here, for which we need to change m0 .. m2 so they return a State object instead of Calc => Int:
object Calc {
import Calc._
// (a, b, c, d) => (s0, s1, s2, s3)
val m0: (Int, Int, Int, Int) => State[Calc, Int] =
(a, b, c, d) =>
State(
calc => {
calc.push(d)
calc.push(c)
calc.push(b)
calc.push(a)
0
}
)
// s0 * s1 + s2 * s3
val m1: () => State[Calc, Int] =
() =>
State(
calc => {
val ab = calc.mul()
val cd = calc.mul()
calc.push(ab)
calc.push(cd)
calc.add()
}
)
// e * s0
val m2: (Int, Int) => State[Calc, Int] =
(x, e) =>
State(
calc => {
calc.push(x)
calc.push(e)
calc.mul()
}
)
// (a * b + c * d) * e
val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
(a, b, c, d, e) => {
for {
_ <- m0(a, b, c, d)
r1 <- m1()
r2 <- m2(r1, e)
} yield(r2)
}
}
Calc.m3(1, 2, 3, 4, 5).runState(new Calc(List[Int]())) //> 70
At this point it’s easy to notice that we drag mutable state around, and there is no particular reason why it should be mutable: we can always compute an updated state and pass the new state forward. Also, the methods push, pop, add, mul don’t look special anymore; they are not that different from m0, m1, m2, m3. So let’s create an improved version of class State and also implement those 4 methods as State values. We can then completely abandon the class Calc, since there is nothing left there except for List[Int]:
object Calc {
case class State[S, A](runState: S => (A, S)) {
def flatMap[B](f: A => State[S, B]): State[S, B] = {
State(
s => {
val (r, newState) = runState(s)
f(r).runState(newState)
}
)
}
def map[B](f: A => B): State[S, B] =
flatMap { x => State.unit(f(x)) }
def getResult(s: S) = runState(s)._1
}
object State {
def unit[S, A](x: A): State[S, A] =
State((x, _))
}
val push: Int => State[List[Int], Unit] =
x => State(s => ((), x::s))
val pop: () => State[List[Int], Int] =
() => State { case (x::xs) => (x, xs) }
val add: () => State[List[Int], Int] =
() => State { case (a::b::xs) => (a + b, xs) }
val mul: () => State[List[Int], Int] =
() => State { case (a::b::xs) => (a * b, xs) }
// (a, b, c, d) => (s0, s1, s2, s3)
val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
(a, b, c, d) =>
for {
_ <- push(d)
_ <- push(c)
_ <- push(b)
_ <- push(a)
} yield(())
// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab)
_ <- push(cd)
s <- add()
} yield(s)
// e * s0
val m2: (Int, Int) => State[List[Int], Int] =
(x, e) =>
for {
_ <- push(x)
_ <- push(e)
p <- mul()
} yield(p)
// (a * b + c * d) * e
val m3: (Int, Int, Int, Int, Int) => State[List[Int], Int] =
(a, b, c, d, e) => {
for {
_ <- m0(a, b, c, d)
r1 <- m1()
r2 <- m2(r1, e)
} yield(r2)
}
}
Calc.m3(1, 2, 3, 4, 5).getResult(Nil) //> 70
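Since runState now returns a pair, it can be useful to look at the final stack as well as the result. The following is a minimal self-contained sketch under the same definitions, trimmed to push and mul; the names Stack, program, initial and finalStack are assumptions made for this illustration.

```scala
object PureStateSketch {
  // Same idea as the improved State: a function from a state to (result, next state).
  final case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B] { s =>
        val (a, s1) = runState(s)
        f(a).runState(s1) // the next step sees the updated state
      }
    def map[B](f: A => B): State[S, B] =
      flatMap[B](a => State[S, B](s => (f(a), s)))
  }

  type Stack = List[Int]

  def push(x: Int): State[Stack, Unit] =
    State[Stack, Unit](s => ((), x :: s))

  val mul: State[Stack, Int] =
    State[Stack, Int] {
      case a :: b :: rest => (a * b, rest)
      case other          => sys.error(s"need two operands, got $other")
    }

  val program: State[Stack, Int] =
    for {
      _ <- push(3)
      _ <- push(2)
      p <- mul
    } yield p

  // The initial stack is never modified; runState hands back a new one.
  val initial: Stack = List(10)
  val (result, finalStack) = program.runState(initial)
  // result == 6, finalStack == List(10), initial is still List(10)
}
```

Nothing is mutated here: each step receives a stack and returns a new one, and flatMap is the only place where the stack is handed from one step to the next.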
Class State[S, A] is called the State monad in functional programming. Let’s look at the benefits:
- New methods are easily composable, which is good for code reuse.
We started from an attempt to get rid of boilerplate:
// Before
val m1: () => (Calc => Int) =
() =>
calc => {
val ab = calc.mul()
val cd = calc.mul()
calc.push(ab)
calc.push(cd)
calc.add()
}
// After
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab)
_ <- push(cd)
s <- add()
} yield(s)
I would argue the latter is more readable and less error prone; repeating calc. in the former snippet simply creates visual noise. So the State monad provides a good way to build DSL-like composition.
- Immutable state. Not sure it brings much here, since the state is concealed and harder to mess up.
- We can easily apply other monadic concepts here, e.g. traverse:
object State {
...
def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
State(
s => {
val (bs, newS) =
// accumulate results in an immutable Vector, which has cheap appends
la.foldLeft((Vector.empty[B], s)) { case ((bs, s), a) =>
val (r, newS) = f(a).runState(s)
(bs :+ r, newS)
}
(bs.toList, newS)
}
)
}
With traverse the code becomes even shorter:
// Before
val push: Int => State[List[Int], Unit] =
x => State(s => ((), x::s))
val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
(a, b, c, d) =>
for {
_ <- push(d)
_ <- push(c)
_ <- push(b)
_ <- push(a)
} yield(())
// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab)
_ <- push(cd)
s <- add()
} yield(s)
...
// After
val push: (Int*) => State[List[Int], List[Unit]] =
(xs) =>
State.traverse(xs.toList)(x => State(s => ((), x::s)))
// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
() =>
for {
ab <- mul()
cd <- mul()
_ <- push(ab, cd)
s <- add()
} yield(s)
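To check the order in which traverse applies its effects, here is a small self-contained sketch. It assumes the pair-returning State from above, reduced to its constructor; pushAll is a hypothetical name for the variadic push, written as a method rather than a function value.

```scala
object TraverseSketch {
  final case class State[S, A](runState: S => (A, S))

  // Runs f on each element from left to right, threading the state through,
  // and collects the results in the same order.
  def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
    State[S, List[B]] { s0 =>
      val (results, finalState) =
        la.foldLeft((Vector.empty[B], s0)) { case ((acc, s), a) =>
          val (b, next) = f(a).runState(s)
          (acc :+ b, next)
        }
      (results.toList, finalState)
    }

  // Hypothetical variadic push built on traverse.
  def pushAll(xs: Int*): State[List[Int], List[Unit]] =
    traverse(xs.toList)(x => State[List[Int], Unit](s => ((), x :: s)))

  val (units, stack) = pushAll(1, 2, 3).runState(Nil)
  // stack == List(3, 2, 1): values are pushed left to right, so the last argument ends up on top
}
```

This matches the rewritten m1 above: push(ab, cd) pushes ab first and cd second, the same order as the two separate push calls it replaces.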