### State monad for object-oriented programmers

Let’s start by building a simple stack-based calculator using an object-oriented approach. The calculator is represented by a stack of integers (`List[Int]`) and a few operations that modify the stack and can return a result.

```
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = {
    stack = x :: stack
  }
  def pop(): Int = {
    val x :: tail = stack
    stack = tail
    x
  }
  // s0 + s1
  def add(): Int = {
    val a = pop()
    val b = pop()
    a + b
  }
  // s0 * s1
  def mul(): Int = {
    val a = pop()
    val b = pop()
    a * b
  }
}
val c = new Calc(List[Int]())
c.push(2)
c.push(3)
c.add() //> 5
```

Looks good so far. Now I want to add slightly more complex logic: I want to compute `(a * b + c * d)`, where `a, b, c, d` are the top 4 values in my stack. I have two options. One is to add the method to the `Calc` class.
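A sketch of what that first option could look like, with the logic living in a new method on `Calc`. The method name `mulAdd` is hypothetical and its body simply mirrors the `m1` function shown next; `add` and `mul` are condensed, and `pop` reports an empty stack explicitly instead of failing with a `MatchError`.

```scala
// Sketch of the "add it to the class" option; `mulAdd` is a hypothetical name.
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = {
    stack = x :: stack
  }
  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil =>
      throw new NoSuchElementException("pop on an empty stack")
  }
  // s0 + s1
  def add(): Int = pop() + pop()
  // s0 * s1
  def mul(): Int = pop() * pop()
  // s0 * s1 + s2 * s3, now baked into the class itself
  def mulAdd(): Int = {
    val ab = mul()
    val cd = mul()
    push(ab)
    push(cd)
    add()
  }
}

val c = new Calc(List(1, 2, 3, 4))
val r = c.mulAdd() // 1 * 2 + 3 * 4 = 14, and the stack is left empty
```

It works, but every such one-off helper widens the class interface.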

But I don’t want to pollute my clean and nice `Calc` class, as this method may be specific to one part of my program. Instead, I want to create a temporary method (preferably a function) that I can pass around and combine with other functions:

```
object Calc {
  // s0 * s1 + s2 * s3
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
}
```

Later I decided that I would also need `(a * b + c * d) * e`, and since I want to reuse my work, I simply added more functions like the one above. I also added `m0` to push 4 values onto the stack in one shot:

```
object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Calc, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d) => {
      calc.push(d)
      calc.push(c)
      calc.push(b)
      calc.push(a)
      0 // placeholder result: m0 only changes the stack
    }
  // s0 * s1 + s2 * s3
  val m1: Calc => Int =
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
  // x * e
  val m2: (Calc, Int, Int) => Int =
    (calc, x, e) => {
      calc.push(x)
      calc.push(e)
      calc.mul()
    }
  // (a * b + c * d) * e
  val m3: (Calc, Int, Int, Int, Int, Int) => Int =
    (calc, a, b, c, d, e) => {
      m0(calc, a, b, c, d)
      val r1 = m1(calc)
      m2(calc, r1, e)
    }
}
Calc.m3(new Calc(List[Int]()), 1, 2, 3, 4, 5) //> 70
```

Not too bad, but there is some annoying boilerplate: I need to repeat `calc.` every time I want to use a `Calc` method, and I have to lug the state around explicitly. It would be great if I could use just the method names and skip the state altogether. Let’s do some refactoring:

```
object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d) =>
      calc => {
        calc.push(d)
        calc.push(c)
        calc.push(b)
        calc.push(a)
        0
      }
  // (s0 * s1 + s2 * s3)
  val m1: () => (Calc => Int) =
    () =>
      calc => {
        val ab = calc.mul()
        val cd = calc.mul()
        calc.push(ab)
        calc.push(cd)
        calc.add()
      }
  // x * e
  val m2: (Int, Int) => (Calc => Int) =
    (x, e) =>
      calc => {
        calc.push(x)
        calc.push(e)
        calc.mul()
      }
  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
    (a, b, c, d, e) =>
      calc => {
        m0(a, b, c, d)(calc)
        val r1 = m1()(calc)
        m2(r1, e)(calc)
      }
}
Calc.m3(1, 2, 3, 4, 5)(new Calc(List[Int]())) //> 70
```

Not much better, but at least state is passed separately now, so maybe there is a chance to conceal it. The goal is to be able to write something like this:

```
val m3: (Int, Int, Int, Int, Int) => ??? = {
  (a, b, c, d, e) =>
    bind(m0(a, b, c, d), { _ =>
      bind(m1(), { r1 =>
        m2(r1, e)
      })
    })
}
```

Looking at the last refactoring, it seems that `???` could be `Calc => Int`. Let’s try to define `bind` as:

```
def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int =
  calc => {
    val r1 = f1(calc)
    f2(r1)(calc)
  }
```
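To make `bind` concrete, here is a minimal sketch that rebuilds the very first example (push 2, push 3, add) as a single `Calc => Int` value. `Calc` is trimmed to the methods needed, and the names `pushTwoThree` and `program` are hypothetical. Nothing runs until `program` is applied to a `Calc`.

```scala
// Minimal Calc with just the methods this example needs.
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }
  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil => throw new NoSuchElementException("pop on an empty stack")
  }
  def add(): Int = pop() + pop()
}

// bind as defined above: run f1, then run whatever f2 returns on the same Calc.
def bind(f1: Calc => Int, f2: Int => (Calc => Int)): Calc => Int =
  calc => {
    val r1 = f1(calc)
    f2(r1)(calc)
  }

// Push 2 and 3 (0 is a placeholder result), then add.
val pushTwoThree: Calc => Int = calc => { calc.push(2); calc.push(3); 0 }
val program: Calc => Int = bind(pushTwoThree, _ => calc => calc.add())

val result = program(new Calc(Nil)) // 5, same as c.add() in the first snippet
```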

Now that we have defined `bind`, we can write:

```
val m3: (Int, Int, Int, Int, Int) => (Calc => Int) =
  (a, b, c, d, e) => {
    bind(m0(a, b, c, d), { _ =>
      bind(m1(), { r1 =>
        m2(r1, e)
      })
    })
  }
```

Still ugly, but at least the state is concealed. This notation can be improved with Scala `for` comprehensions. However, `for` comprehensions require `flatMap` and `map` to be available as methods on `Calc => Int`, and a plain function type has no such methods. The solution is to create a wrapper:

```
case class State[S, A](runState: S => A) {
  def flatMap(f: A => State[S, A]): State[S, A] = {
    State(
      s => {
        // the same (mutable) state s is handed to both steps
        val r = runState(s)
        f(r).runState(s)
      }
    )
  }
  def map(f: A => A): State[S, A] = {
    State(
      s => {
        val r = runState(s)
        f(r)
      }
    )
  }
}
val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
  (a, b, c, d, e) => {
    State(m0(a, b, c, d)).flatMap { _ =>
      State(m1()).flatMap { r =>
        State(m2(r, e))
      }
    }
  }
```

The same thing can also be written as a `for` comprehension:

```
...
val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
  (a, b, c, d, e) => {
    for {
      _  <- State(m0(a, b, c, d))
      r1 <- State(m1())
      r2 <- State(m2(r1, e))
    } yield(r2)
  }
```
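It helps to see what the compiler does with such a `for` comprehension: it rewrites it into nested `flatMap` calls, with a final `map` for the `yield`. The sketch below restates `Calc`, the first `State` wrapper and the `Calc => Int` versions of `m0`..`m2` so it runs on its own, and writes both forms side by side (`viaFor` and `desugared` are hypothetical names).

```scala
class Calc(var stack: List[Int]) {
  def push(x: Int): Unit = { stack = x :: stack }
  def pop(): Int = stack match {
    case x :: tail =>
      stack = tail
      x
    case Nil => throw new NoSuchElementException("pop on an empty stack")
  }
  def add(): Int = pop() + pop()
  def mul(): Int = pop() * pop()
}

// The first State wrapper: flatMap and map keep the result type fixed at A.
case class State[S, A](runState: S => A) {
  def flatMap(f: A => State[S, A]): State[S, A] =
    State((s: S) => f(runState(s)).runState(s))
  def map(f: A => A): State[S, A] =
    State((s: S) => f(runState(s)))
}

val m0: (Int, Int, Int, Int) => (Calc => Int) =
  (a, b, c, d) => calc => { calc.push(d); calc.push(c); calc.push(b); calc.push(a); 0 }
val m1: () => (Calc => Int) =
  () => calc => {
    val ab = calc.mul()
    val cd = calc.mul()
    calc.push(ab)
    calc.push(cd)
    calc.add()
  }
val m2: (Int, Int) => (Calc => Int) =
  (x, e) => calc => { calc.push(x); calc.push(e); calc.mul() }

// What we write...
def viaFor(a: Int, b: Int, c: Int, d: Int, e: Int): State[Calc, Int] =
  for {
    _  <- State(m0(a, b, c, d))
    r1 <- State(m1())
    r2 <- State(m2(r1, e))
  } yield r2

// ...and roughly what the compiler turns it into.
def desugared(a: Int, b: Int, c: Int, d: Int, e: Int): State[Calc, Int] =
  State(m0(a, b, c, d)).flatMap { _ =>
    State(m1()).flatMap { r1 =>
      State(m2(r1, e)).map { r2 => r2 }
    }
  }

val forResult   = viaFor(1, 2, 3, 4, 5).runState(new Calc(Nil))
val chainResult = desugared(1, 2, 3, 4, 5).runState(new Calc(Nil))
```

Both produce 70. The trailing `map { r2 => r2 }` is why the wrapper needs `map` at all, even though it only passes the value through here.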

This looks much better. Time to get rid of the explicit `State(...)` wrapping here, for which we need to change `m0` .. `m2` so they return a `State` object instead of `Calc => Int`:

```
object Calc {
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d) =>
      State(
        calc => {
          calc.push(d)
          calc.push(c)
          calc.push(b)
          calc.push(a)
          0
        }
      )
  // s0 * s1 + s2 * s3
  val m1: () => State[Calc, Int] =
    () =>
      State(
        calc => {
          val ab = calc.mul()
          val cd = calc.mul()
          calc.push(ab)
          calc.push(cd)
          calc.add()
        }
      )
  // x * e
  val m2: (Int, Int) => State[Calc, Int] =
    (x, e) =>
      State(
        calc => {
          calc.push(x)
          calc.push(e)
          calc.mul()
        }
      )
  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[Calc, Int] =
    (a, b, c, d, e) => {
      for {
        _  <- m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}
Calc.m3(1, 2, 3, 4, 5).runState(new Calc(List[Int]())) //> 70
```

At this point it’s easy to notice that we drag mutable state around, and there is no particular reason why it should be mutable: we can always compute an updated state and pass the new state forward. Also, the methods `push`, `pop`, `add` and `mul` don’t look special anymore; they are not that different from `m0`, `m1`, `m2` and `m3`. So let’s create an improved version of class `State` and implement those 4 methods on top of it. We can then abandon class `Calc` completely, since there is nothing left in it except for `List[Int]`:

```
object Calc {
  case class State[S, A](runState: S => (A, S)) {
    def flatMap[B](f: A => State[S, B]): State[S, B] = {
      State(
        s => {
          val (r, newState) = runState(s)
          f(r).runState(newState)
        }
      )
    }
    def map[B](f: A => B): State[S, B] =
      flatMap { x => State.unit(f(x)) }
    def getResult(s: S) = runState(s)._1
  }
  object State {
    def unit[S, A](x: A): State[S, A] =
      State((x, _))
  }
  val push: Int => State[List[Int], Unit] =
    x => State(s => ((), x::s))
  val pop: () => State[List[Int], Int] =
    () => State { case (x::xs) => (x, xs) }
  val add: () => State[List[Int], Int] =
    () => State { case (a::b::xs) => (a + b, xs) }
  val mul: () => State[List[Int], Int] =
    () => State { case (a::b::xs) => (a * b, xs) }
  // (a, b, c, d) => (s0, s1, s2, s3)
  val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
    (a, b, c, d) =>
      for {
        _ <- push(d)
        _ <- push(c)
        _ <- push(b)
        _ <- push(a)
      } yield(())
  // s0 * s1 + s2 * s3
  val m1: () => State[List[Int], Int] =
    () =>
      for {
        ab <- mul()
        cd <- mul()
        _  <- push(ab)
        _  <- push(cd)
        s  <- add()
      } yield(s)
  // x * e
  val m2: (Int, Int) => State[List[Int], Int] =
    (x, e) =>
      for {
        _ <- push(x)
        _ <- push(e)
        p <- mul()
      } yield(p)
  // (a * b + c * d) * e
  val m3: (Int, Int, Int, Int, Int) => State[List[Int], Int] =
    (a, b, c, d, e) => {
      for {
        _  <- m0(a, b, c, d)
        r1 <- m1()
        r2 <- m2(r1, e)
      } yield(r2)
    }
}
Calc.m3(1, 2, 3, 4, 5).getResult(Nil) //> 70
```
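To see what `flatMap` now does behind the scenes, the sketch below runs the monadic `m1` next to a hand-threaded equivalent, `m1ByHand` (a hypothetical name), on the same immutable stack. `binOp` is also a hypothetical helper standing in for the pattern-matching `add`/`mul`. Unlike `getResult`, `runState` returns the final stack along with the result, so leftover values below the four operands stay visible.

```scala
// The improved State from above, trimmed to what this example needs.
case class State[S, A](runState: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s: S) =>
      val (r, next) = runState(s)
      f(r).runState(next)
    }
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State((s: S) => (f(a), s)))
  def getResult(s: S): A = runState(s)._1
}

// Hypothetical helper: pop two values and combine them; fails on a short stack.
def binOp(op: (Int, Int) => Int): State[List[Int], Int] =
  State((s: List[Int]) => s match {
    case a :: b :: rest => (op(a, b), rest)
    case _              => throw new IllegalStateException("need two values on the stack")
  })

val push: Int => State[List[Int], Unit] = x => State((s: List[Int]) => ((), x :: s))
val add: () => State[List[Int], Int] = () => binOp(_ + _)
val mul: () => State[List[Int], Int] = () => binOp(_ * _)

// s0 * s1 + s2 * s3, written with the monad
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab)
      _  <- push(cd)
      s  <- add()
    } yield s

// The same computation with every intermediate stack named and passed by hand.
def m1ByHand(s0: List[Int]): (Int, List[Int]) = {
  val (ab, s1) = mul().runState(s0)
  val (cd, s2) = mul().runState(s1)
  val (_, s3)  = push(ab).runState(s2)
  val (_, s4)  = push(cd).runState(s3)
  add().runState(s4)
}

val start  = List(1, 2, 3, 4, 9)
val full   = m1().runState(start)  // result plus leftover stack: (14, List(9))
val result = m1().getResult(start) // result only: 14
val manual = m1ByHand(start)
```

Because the stack is an immutable `List`, `start` is unchanged after all three runs, so the same `State` value can be run again on the same input and gives the same answer.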

The class `State[S, A]` is known as the State monad in functional programming. Let’s look at the benefits:

- New methods are easily composable, which is good for code reuse.

- We started from an attempt to get rid of boilerplate:

```
// Before
val m1: () => (Calc => Int) =
  () =>
    calc => {
      val ab = calc.mul()
      val cd = calc.mul()
      calc.push(ab)
      calc.push(cd)
      calc.add()
    }
// After
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab)
      _  <- push(cd)
      s  <- add()
    } yield(s)
```

I would argue the latter is more readable and less error-prone; repeating `calc.` in the former snippet simply creates visual noise. So the State monad provides a good way to build DSL-like composition.

- Immutable state. I’m not sure it brings much here, since the state is already concealed and therefore hard to mess up.

- We can easily apply other monadic concepts here, e.g. `traverse`:

```
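object State {
  ...
  def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
    State(
      (s: S) => {
        // thread the state through f for each element, collecting results in order;
        // Vector gives cheap appends (:+ on a ListBuffer copies the whole buffer)
        val (bs, newS) =
          la.foldLeft((Vector.empty[B], s)) { case ((acc, s), a) =>
            val (r, next) = f(a).runState(s)
            (acc :+ r, next)
          }
        (bs.toList, newS)
      }
    )
}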
```
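A short usage sketch: pushing a whole list with a one-element push (`pushOne`, a hypothetical helper). It also includes `traverseM`, a hypothetical alternative that builds the same thing from `unit`, `flatMap` and `map` alone instead of calling `runState` directly. Both run the effects left to right, so the last value pushed ends up on top.

```scala
// The improved State from above, with unit and traverse in its companion.
case class State[S, A](runState: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s: S) =>
      val (r, next) = runState(s)
      f(r).runState(next)
    }
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State.unit[S, B](f(a)))
}

object State {
  def unit[S, A](x: A): State[S, A] = State((s: S) => (x, s))

  // Fold-based traverse: one runState call per element inside a loop.
  def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
    State { (s0: S) =>
      val (bs, last) = la.foldLeft((Vector.empty[B], s0)) { case ((acc, s), a) =>
        val (b, next) = f(a).runState(s)
        (acc :+ b, next)
      }
      (bs.toList, last)
    }

  // Hypothetical alternative expressed only with unit, flatMap and map.
  def traverseM[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
    la.foldRight(unit[S, List[B]](Nil)) { (a, rest) =>
      f(a).flatMap(b => rest.map(bs => b :: bs))
    }
}

// Hypothetical single-value push.
def pushOne(x: Int): State[List[Int], Unit] = State((s: List[Int]) => ((), x :: s))

val (units, stack) = State.traverse(List(1, 2, 3))(x => pushOne(x)).runState(Nil)
val viaFlatMap     = State.traverseM(List(1, 2, 3))(x => pushOne(x)).runState(Nil)
```

Here `units` is `List((), (), ())` and `stack` is `List(3, 2, 1)`. The `flatMap`-based version reads more like the definition of `traverse`, but running it nests one `runState` call inside the next for every element, so for very long lists it can run out of stack; the fold-based version does not have that problem.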

With `traverse`, the code becomes even shorter:

```
// Before
val push: Int => State[List[Int], Unit] =
  x => State(s => ((), x::s))
val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
  (a, b, c, d) =>
    for {
      _ <- push(d)
      _ <- push(c)
      _ <- push(b)
      _ <- push(a)
    } yield(())
// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab)
      _  <- push(cd)
      s  <- add()
    } yield(s)
...
// After
// a function value can't take varargs (Int*), so push becomes a def
def push(xs: Int*): State[List[Int], List[Unit]] =
  State.traverse(xs.toList)(x => State((s: List[Int]) => ((), x :: s)))
// s0 * s1 + s2 * s3
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab, cd)
      s  <- add()
    } yield(s)
```
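To check that the shorter version still computes the same thing, here is an end-to-end sketch. Rewriting `m0` as `push(d, c, b, a)` and `m2` with `push(x, e)` is an extrapolation, not something shown above, and `binOp` is a hypothetical helper for `add` and `mul`. Since `traverse` pushes its arguments left to right, the last argument ends up on top of the stack.

```scala
case class State[S, A](runState: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State { (s: S) =>
      val (r, next) = runState(s)
      f(r).runState(next)
    }
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State.unit[S, B](f(a)))
}

object State {
  def unit[S, A](x: A): State[S, A] = State((s: S) => (x, s))
  // Thread the state through f for each element, collecting results in order.
  def traverse[S, A, B](la: List[A])(f: A => State[S, B]): State[S, List[B]] =
    State { (s0: S) =>
      val (bs, last) = la.foldLeft((Vector.empty[B], s0)) { case ((acc, s), a) =>
        val (b, next) = f(a).runState(s)
        (acc :+ b, next)
      }
      (bs.toList, last)
    }
}

// Varargs push: values go on left to right, so the last one ends up on top.
def push(xs: Int*): State[List[Int], List[Unit]] =
  State.traverse(xs.toList)(x => State((s: List[Int]) => ((), x :: s)))

// Hypothetical helper: pop two values and combine them.
def binOp(op: (Int, Int) => Int): State[List[Int], Int] =
  State((s: List[Int]) => s match {
    case a :: b :: rest => (op(a, b), rest)
    case _              => throw new IllegalStateException("need two values on the stack")
  })
val add: () => State[List[Int], Int] = () => binOp(_ + _)
val mul: () => State[List[Int], Int] = () => binOp(_ * _)

// Hypothetical: m0 and m2 rewritten with the varargs push.
val m0: (Int, Int, Int, Int) => State[List[Int], Unit] =
  (a, b, c, d) => push(d, c, b, a).map(_ => ())
val m1: () => State[List[Int], Int] =
  () =>
    for {
      ab <- mul()
      cd <- mul()
      _  <- push(ab, cd)
      s  <- add()
    } yield s
val m2: (Int, Int) => State[List[Int], Int] =
  (x, e) =>
    for {
      _ <- push(x, e)
      p <- mul()
    } yield p
val m3: (Int, Int, Int, Int, Int) => State[List[Int], Int] =
  (a, b, c, d, e) =>
    for {
      _  <- m0(a, b, c, d)
      r1 <- m1()
      r2 <- m2(r1, e)
    } yield r2

val pushed = push(1, 2, 3).runState(Nil)._2     // List(3, 2, 1)
val result = m3(1, 2, 3, 4, 5).runState(Nil)._1 // 70, as before
```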