scalaz-plugin

ST fusion

Open • sir-wabbit opened this issue 7 years ago • 1 comment

Even the best-performing IO or ST implementation is still tremendously slower than equivalent impure code.

From twitter:

For each def foo(args: ...): ST[S, A], generate an impure def foo$ST(args...): A. Whenever you see ST { body } flatMap (a => foo(args)), rewrite it as ST { val a = body; foo$ST(args) }. This way you can remove all, or almost all, trampoline jumps in non-higher-order ST code. Finally, at the outermost level, rewrite runST { ST { foo } } into foo$ST. Cross-module fusion will only work for non-higher-order code, but arguably that is the majority of IO and ST code anyway.

Note that this cannot be done with a simple rewriting system, since we need to generate new methods (the foo$ST variants).

— sir-wabbit, Jun 28 '18 20:06

Here is how I think the encoding should look:

import scala.reflect.ClassTag

object Test {
  // Sketch of the ST-fusion encoding proposed in this issue. The three sections
  // below are ALTERNATIVE versions of the same two methods (`foo` and `baz`),
  // placed side by side for comparison. They are not meant to coexist in one
  // object; because of the duplicate definitions, this object does not compile
  // as written.
  //
  // Every version of `foo` reads index 0 and writes index 1, so it assumes `l`
  // has at least two elements. Otherwise the array access throws.

  // Original, monadic code: each step is an ST action chained with flatMap.
  def foo[A: ClassTag](l: List[A]): List[A] = ST.runST(new ForallST[List[A]] {
    def apply[S]: ST[S, List[A]] = for {
      b   <- Buf(l)
      a   <- b.get(0)
      _   <- b.set(1, a)
      res <- b.toList
    } yield res
  })
  def baz[S, A](buf: Buf[S, A]): ST[S, Int] = buf.size map (_ + 1)

  // Same module: the bodies of Buf's operations are visible, so the whole chain
  // collapses into direct, impure code. The Buf created in `foo` never escapes,
  // so its wrapper disappears and the backing array is used directly.
  // NOTE(review): `Buf.data` is private, so `baz$ST` assumes the plugin may
  // bypass that access restriction when it inlines.
  def foo[A: ClassTag](l: List[A]): List[A] = {
    val b = l.toArray
    val a = b(0)
    b(1) = a
    b.toList
  }
  def baz$ST[S, A](buf: Buf[S, A]): Int = buf.data.size + 1
  def baz[S, A](buf: Buf[S, A]): ST[S, Int] = ST { baz$ST(buf) }

  // Cross-module (or when implementations invoke private[this] members): the
  // bodies are not inlined. Instead, the plugin calls the generated impure
  // `$ST` counterparts (`Buf.apply$ST`, `get$ST`, ...), which it is expected to
  // produce for every ST-returning method.
  def foo[A: ClassTag](l: List[A]): List[A] = {
    val b = Buf.apply$ST(l)
    val a = b.get$ST(0)
    b.set$ST(1, a)
    b.toList$ST
  }
  def baz$ST[S, A](buf: Buf[S, A]): Int = buf.size$ST() + 1
  def baz[S, A](buf: Buf[S, A]): ST[S, Int] = ST { baz$ST(buf) }
}

//////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////

/** A mutable, array-backed buffer whose operations are ST actions tagged with
  * the state-thread type `S`.
  *
  * The constructor is private, so instances come from the companion's `apply`.
  * The backing array itself is never handed out: `toList` returns a copy.
  *
  * @tparam S state-thread type; it appears only in the ST result types
  * @tparam A element type
  */
class Buf[S, A] private (private val data: Array[A]) {
  // `data` is never reassigned: the array contents are mutated in place by
  // `set`, so the reference itself can be a `val`.

  /** Number of elements in the backing array. */
  def size: ST[S, Int] = ST { data.length }

  /** Reads element `i`. Running the action throws
    * `ArrayIndexOutOfBoundsException` when `i` is out of range.
    */
  def get(i: Int): ST[S, A] = ST { data(i) }

  /** Overwrites element `i` with `a`. Running the action throws
    * `ArrayIndexOutOfBoundsException` when `i` is out of range.
    */
  def set(i: Int, a: A): ST[S, Unit] = ST { data(i) = a }

  /** Copies the current contents into an immutable `List`. */
  def toList: ST[S, List[A]] = ST { data.toList }
}
object Buf {
  /** Creates an action that, each time it is run, copies `l` into a fresh
    * array and wraps it in a new `Buf`.
    *
    * The copy happens inside the suspended block, not at call time.
    */
  def apply[S, A: ClassTag](l: List[A]): ST[S, Buf[S, A]] =
    ST {
      val backing: Array[A] = l.toArray
      new Buf[S, A](backing)
    }
}

/** An ST computation that is polymorphic in its state type.
  *
  * Implementations must provide `apply` for every `S`. This lets `ST.runST`
  * choose the state type itself when running the computation.
  *
  * @tparam A result type of the computation
  */
trait ForallST[A] { def apply[S]: ST[S, A] }

/** A suspended computation in state thread `S` that yields an `A`.
  *
  * The computation is a plain thunk; nothing executes until `unsafeRun` is
  * invoked.
  *
  * Caveat: running a `map`/`flatMap` result calls the inner thunk from inside
  * the outer one, and nothing trampolines these calls. Very deep chains
  * therefore nest JVM stack frames and can overflow the stack.
  */
final class ST[S, A] private (private val unsafeRun: () => A) {

  /** Transforms the eventual result with `f`, without running anything yet. */
  def map[B](f: A => B): ST[S, B] = {
    val mapped: () => B = () => f(unsafeRun())
    new ST[S, B](mapped)
  }

  /** Sequences this action with the action chosen by `f` from its result. */
  def flatMap[B](f: A => ST[S, B]): ST[S, B] = {
    val chained: () => B = () => {
      val next: ST[S, B] = f(unsafeRun())
      next.unsafeRun()
    }
    new ST[S, B](chained)
  }
}
object ST {
  /** Suspends `a`. The by-name argument is re-evaluated every time the
    * resulting action is run.
    */
  def apply[S, A](a: => A): ST[S, A] = {
    val suspended: () => A = () => a
    new ST[S, A](suspended)
  }

  /** Lifts an already-computed value; `a` is evaluated once, at call time. */
  def pure[S, A](a: A): ST[S, A] = {
    val value: A = a
    new ST[S, A](() => value)
  }

  /** Runs a state-polymorphic computation, instantiating its state type at
    * `Unit`, and returns its result.
    */
  def runST[A](s: ForallST[A]): A = {
    val action: ST[Unit, A] = s.apply[Unit]
    action.unsafeRun()
  }
}

— sir-wabbit, Jun 29 '18 18:06