scalaz-plugin
ST fusion
Even the best-performing IO or ST implementation is still tremendously slower than equivalent impure code.
From Twitter:
For each `def foo(args: ...): ST[S, A]`, create an impure `def foo$ST(args...): A`. Whenever you see `ST { body }.flatMap(a => foo(args))`, rewrite it as `ST { val a = body; foo$ST(args) }`. This way you can remove all, or almost all, trampoline jumps in non-higher-order ST code. Finally, rewrite `runST { ST { foo } }` into `foo$ST`. Cross-module fusion will only work for non-higher-order (non-HO) code, but arguably that's the majority of IO and ST code anyway.
Note that this cannot be done with a simple term-rewriting system, since we need to generate new methods.
Here is how I think the encoding should look:
import scala.reflect.ClassTag
// Sketch of the proposed ST-fusion encoding. It is illustrative only: it defines `foo`,
// `baz` and `baz$ST` more than once to show the same code before and after the rewrite,
// so the object as a whole does not compile.
object Test {
// Original, un-fused code: a monadic ST program executed through ST.runST.
def foo[A: ClassTag](l: List[A]): List[A] = ST.runST(new ForallST[List[A]] {
def apply[S]: ST[S, List[A]] = for {
b <- Buf(l)
a <- b.get(0)
_ <- b.set(1, a)
l <- b.toList
} yield l
})
// Original `baz`: adds 1 to the buffer's size inside ST.
def baz[S, A](buf: Buf[S, A]): ST[S, Int] = buf.size map (_ + 1)
// Same module
// Fused form when Buf's implementation is visible: each ST step is inlined as a direct
// (impure) statement, and the runST call disappears entirely.
// NOTE(review): `b` here is the `Array[A]` returned by `l.toArray`, so `b.data(...)` does
// not type-check; presumably `b(0)`, `b(1) = a` and `b.toList` were intended.
def foo[A: ClassTag](l: List[A]): List[A] = {
val b = l.toArray
val a = b.data(0)
b.data(1) = a
b.data.toList
}
// Impure `$ST` twin of `baz`. The ST-returning `baz` is kept as a thin wrapper around it,
// so existing callers still get an ST value.
// NOTE(review): reads Buf's `private` field `data` from outside Buf, which the Buf class
// in this file does not allow. Confirm how the plugin is meant to expose it.
def baz$ST[S, A](buf: Buf[S, A]): Int = buf.data.size + 1
def baz[S, A](buf: Buf[S, A]): ST[S, Int] = ST { baz$ST(buf) }
// Cross-module or implementations invoking private[this] functions
// Fused form when only Buf's API is visible: each ST call becomes a call to the impure
// `$ST` twin (Buf.apply$ST, get$ST, set$ST, ...). This sketch assumes the plugin generates
// those twins for every ST-returning method; they do not exist in the Buf defined below.
def foo[A: ClassTag](l: List[A]): List[A] = {
val b = Buf.apply$ST(l)
val a = b.get$ST(0)
b.set$ST(1, a)
b.toList$ST
}
// Cross-module `baz`: calls the generated `size$ST` twin instead of reading `data` directly.
def baz$ST[S, A](buf: Buf[S, A]): Int = buf.size$ST() + 1
def baz[S, A](buf: Buf[S, A]): ST[S, Int] = ST { baz$ST(buf) }
}
//////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////
/** A mutable, array-backed buffer whose operations are suspended in `ST[S, _]`.
  *
  * Every operation wraps its array access in `ST { ... }`. The array is only touched
  * when the returned computation is run, for example via `ST.runST`.
  *
  * The backing reference is a `val`: it is never reassigned. Only the array's
  * contents are mutated, through `set`.
  *
  * @tparam S the ST state/region type this buffer is tagged with
  * @tparam A the element type
  */
class Buf[S, A] private (private val data: Array[A]) {

  /** Number of elements in the backing array. */
  def size: ST[S, Int] = ST { data.size }

  /** Reads element `i`. When run, this throws `ArrayIndexOutOfBoundsException` if `i` is out of range. */
  def get(i: Int): ST[S, A] = ST { data(i) }

  /** Overwrites element `i` with `a`. When run, this throws `ArrayIndexOutOfBoundsException` if `i` is out of range. */
  def set(i: Int, a: A): ST[S, Unit] = ST { data(i) = a }

  /** Copies the current contents into a new immutable `List`. */
  def toList: ST[S, List[A]] = ST { data.toList }
}
/** Factory for [[Buf]] values. */
object Buf {

  /** Builds a buffer that holds a copy of `l`'s elements.
    *
    * The copy into a fresh array is inside `ST { ... }`, so it happens when the returned
    * computation is run, not when `apply` is called. The `ClassTag` context bound is
    * required by `l.toArray`.
    */
  def apply[S, A: ClassTag](l: List[A]): ST[S, Buf[S, A]] = ST {
    val backing: Array[A] = l.toArray
    new Buf[S, A](backing)
  }
}
/** An ST computation that is polymorphic in its state type: `apply` must work for every
  * `S`. `ST.runST` picks the concrete `S` itself. Presumably this is done so callers cannot
  * choose or leak the state type; TODO confirm intent.
  */
trait ForallST[A] { def apply[S]: ST[S, A] }
/** A suspended stateful computation producing an `A`, tagged with a state type `S`.
  *
  * The computation is a plain thunk. `map` and `flatMap` each wrap the previous thunk in a
  * new closure, so nothing runs until `unsafeRun` is invoked.
  * NOTE(review): running a long `flatMap` chain nests these calls with no trampoline, so
  * very deep chains can overflow the stack. Confirm this is acceptable for the prototype.
  */
final class ST[S, A] private (private val unsafeRun: () => A) {

  /** Applies `f` to this computation's result when run. */
  def map[B](f: A => B): ST[S, B] = {
    val thunk: () => B = () => f(unsafeRun())
    new ST[S, B](thunk)
  }

  /** Runs this computation, feeds its result to `f`, then runs the computation `f` returns. */
  def flatMap[B](f: A => ST[S, B]): ST[S, B] = {
    val thunk: () => B = () => {
      val next: ST[S, B] = f(unsafeRun())
      next.unsafeRun()
    }
    new ST[S, B](thunk)
  }
}
/** Constructors and the runner for [[ST]] computations. */
object ST {

  /** Suspends `a`. Because `a` is by-name, it is re-evaluated each time the result is run. */
  def apply[S, A](a: => A): ST[S, A] = {
    val suspended: () => A = () => a
    new ST[S, A](suspended)
  }

  /** Lifts an already-computed value. The argument is evaluated once, at the call site. */
  def pure[S, A](a: A): ST[S, A] = {
    val constant: () => A = () => a
    new ST[S, A](constant)
  }

  /** Runs a state-polymorphic computation by instantiating its state type to `Unit`. */
  def runST[A](s: ForallST[A]): A = {
    val program: ST[Unit, A] = s.apply[Unit]
    program.unsafeRun()
  }
}