
Implement "pico" GPT example

Open hmf opened this issue 2 years ago • 56 comments

Response to request in issue https://github.com/sbrunk/storch/issues/44.

Attempt to rewrite the "pico" example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" in storch.

hmf avatar Aug 17 '23 09:08 hmf

@sbrunk or anyone else: I need some assistance with this work. To code the "pico" example, I need the Embedding operator. In my branch I have added this here. I have also added comments and made sure the ScalaDoc is OK (minus the math expressions).

The code I am working on now is the BiGram class. If I understand the code correctly, I have to pass a Tensor of shape/size (B,T) and get a Float back. According to the native code, that seems to be a call to the forward method. So I am using this in the embedding class, as per the other modules:

  def apply(t: Tensor[Int64]): Tensor[D] = Tensor(nativeModule.forward(t.native))

And this is a problem because I get the error:

[error] 101 |final class Embedding[D <: DType: Default](
[error]     |            ^
[error]     |class Embedding needs to be abstract, since def apply(v1: T1): R in trait Function1 in package scala is not defined 
[error]     |(Note that
[error]     | parameter T1 in def apply(v1: T1): R in trait Function1 in package scala does not match
[error]     | parameter torch.Tensor[torch.Int64] in def apply(t: torch.Tensor[torch.Int64]): torch.Tensor[D] in class Embedding in package torch.nn.modules.embed
[error]     | )

I think this is because we extend from TensorModule :

trait TensorModule[D <: DType] extends Module with (Tensor[D] => Tensor[D]):

In other words, the apply from (Tensor[D] => Tensor[D]) assumes the input and output are of the same type. Do we have other operators where this is not true? If not, how should we handle this?

On a related note, is it possible to constrain the Tensor by its shape?

TIA

hmf avatar Aug 17 '23 09:08 hmf

In order to keep going I have used the following solution:

  def apply(t: Tensor[D]): Tensor[D] = Tensor(nativeModule.forward(t.native))
  @targetName("apply_T_D")
  def apply[T<:DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))

Is this ok for a final solution?

hmf avatar Aug 17 '23 09:08 hmf

@hmf You're right Embedding is an example where the input type might be different from the output, so we can't inherit from TensorModule.

Note that @davoclavo has also added Embedding and a few other modules in #36 (haven't been able to finish and merge that yet, unfortunately) and added a more generic TensorModuleBase to tackle this issue:

https://github.com/sbrunk/storch/blob/05f7dbdca35daa0589447ad0d4eadbefe38e1aeb/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala#L58-L68

https://github.com/sbrunk/storch/blob/05f7dbdca35daa0589447ad0d4eadbefe38e1aeb/core/src/main/scala/torch/nn/modules/Module.scala#L125-L127

So eventually we need to merge your solutions, but for now you could also just inherit from nn.Module and then use your apply method:

  def apply[T<:DType](t: Tensor[T]): Tensor[D] = Tensor(nativeModule.forward(t.native))
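
To make the typing problem concrete, here is a minimal sketch using stand-in types. The `DType`, `Tensor`, `TensorModuleBase`, and `ToyEmbedding` definitions below are simplified illustrations written for this example, not storch's real classes; the point is only that a base trait with separate input and output type parameters accepts an embedding-like module, while the single-parameter trait becomes a special case of it.

```scala
// Stand-in types for illustration only; these are NOT storch's real classes.
sealed trait DType
final class Int64 extends DType
final class Float32 extends DType

// The dtype is tracked purely at the type level (a phantom type parameter).
final case class Tensor[D <: DType](values: Seq[Double])

// Hypothetical generalisation: input and output dtypes are independent.
trait TensorModuleBase[In <: DType, Out <: DType] extends (Tensor[In] => Tensor[Out])

// The same-dtype module is just a special case of the general one.
trait TensorModule[D <: DType] extends TensorModuleBase[D, D]

// An embedding-like lookup: integer indices in, floating point values out.
final class ToyEmbedding(table: Vector[Double]) extends TensorModuleBase[Int64, Float32]:
  def apply(indices: Tensor[Int64]): Tensor[Float32] =
    Tensor[Float32](indices.values.map(i => table(i.toInt)))
```

With these definitions, `ToyEmbedding(Vector(0.1, 0.2, 0.3))(Tensor[Int64](Seq(2.0, 0.0)))` type-checks as a `Tensor[Float32]`, whereas the same class extending `TensorModule[Float32]` would be forced to accept a `Tensor[Float32]` as input.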

On a related note, is it possible to constrain the Tensor by its shape?

Right now, we're tracking only the dtype at compile time. We might add that in the future though.

sbrunk avatar Aug 17 '23 10:08 sbrunk

@sbrunk I have looked at the embedding class and my version is pretty close to it. I currently cannot search @davoclavo's branch, but I think I can copy and use that code (minimum set of classes with updated docs). That might be easier on your side.

In the meantime, if you do merge into the main branch, I will update accordingly. OK with you?

hmf avatar Aug 17 '23 11:08 hmf

Sounds good to me 👍

sbrunk avatar Aug 17 '23 11:08 sbrunk

Question about cross entropy functions. The original code uses something like:

import torch
import torch.nn as nn
from torch.nn import functional as F

...
loss = F.cross_entropy(logits, targets)
...
            probs = F.softmax(logits, dim=-1) # (B, C)

I see that we have 2 options, a function in the Loss package (does not exist yet, only binary version available) and the torch.nn.loss.CrossEntropyLoss version. The storch examples use the latter.

What are the advantages/disadvantages of using one or the other?

hmf avatar Aug 17 '23 13:08 hmf

I see that we have 2 options, a function in the Loss package (does not exist yet, only binary version available) and the torch.nn.loss.CrossEntropyLoss version. The storch examples use the latter.

What are the advantages/disadvantages of using one or the other?

PyTorch has a functional and a class/module variant for most of its nn operations. See torch.nn.functional.cross_entropy and torch.nn.CrossEntropyLoss. The class variant usually inherits from Module, so it's easy to put it into containers expecting modules.

The functional variant does not contain any state, you call it directly with the tensor inputs and other arguments. The class/module variant can be initialized first with init parameters, and then later reused for different inputs. If you have modules with learnable weights/parameters, the module variant also helps you manage that state (makes it easier to update all weights of your model etc.).

For stateless ops without weights, like cross_entropy the class variant doesn't have much advantage except for reuse, so you can also just use the functional variant but it doesn't make much of a difference after all.
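
As a rough illustration of the two styles, here is a sketch in plain Scala on nested sequences rather than tensors. The names `crossEntropy` and `CrossEntropyLoss` mirror the PyTorch ones, but the signatures are simplified assumptions made for this example and are not storch's API.

```scala
// Functional variant: stateless, every option is an argument of the call.
// logits has one row per sample; targets holds the correct class index per sample.
def crossEntropy(logits: Seq[Seq[Double]], targets: Seq[Int], mean: Boolean = true): Double =
  val losses = logits.zip(targets).map { (row, target) =>
    val maxLogit = row.max // subtract the max for numerical stability
    val logSumExp = maxLogit + math.log(row.map(l => math.exp(l - maxLogit)).sum)
    logSumExp - row(target) // equals -log(softmax(row)(target))
  }
  if mean then losses.sum / losses.size else losses.sum

// Module-style variant: options are fixed once at construction, then reused.
final class CrossEntropyLoss(mean: Boolean = true)
    extends ((Seq[Seq[Double]], Seq[Int]) => Double):
  def apply(logits: Seq[Seq[Double]], targets: Seq[Int]): Double =
    crossEntropy(logits, targets, mean)
```

Both variants compute the same value; the class form only saves repeating the options and can be passed around wherever a function or module is expected.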

sbrunk avatar Aug 17 '23 13:08 sbrunk

Hello @hmf! Awesome work on implementing Karpathy's examples. I have made some progress as well, but last month I got sidetracked with some things at work, so I wasn't able to prepare the code to share it.

I'll leave my progress implementing some of the model building blocks here in case it is helpful in any way to you. As @sbrunk mentioned, there are some new modules implemented in PR #36 - such as Embedding, LayerNorm, ModuleList, etc. - and this code expects those modules to exist in storch.

(Btw, you should be able to access my branch via the PR, or via this direct link)

final case class Head[D <: FloatNN: Default](
    numEmbeddings: Int,
    headSize: Int,
    blockSize: Int,
    dropoutProb: Float
) extends TensorModule[D] {
  val query = register(nn.Linear(numEmbeddings, headSize))
  val key = register(nn.Linear(numEmbeddings, headSize))
  val value = register(nn.Linear(numEmbeddings, headSize))
  val tril = register(torch.tril(torch.ones(Seq(blockSize, blockSize))))
  val dropout = register(Dropout(dropoutProb))

  override def apply(input: Tensor[D]): Tensor[D] =
      val Seq(batch, timeStep, channels) = input.shape // (B, T, C) (64, 256, 384) [Float32]
      assert(blockSize == timeStep, "Block size must be equal to time step")

      val k: Tensor[D] = key(input) // (64, 256, 64) [Float32]
      val q: Tensor[D] = query(input) // (64, 256, 64) [Float32]
      val v: Tensor[D] = value(input) // (64, 256, 64) [Float32]

      // TODO Get rid of the `.to(dtype = q.dtype)`
      val weight =
        torch.matmul(q, torch.transpose(k, -2, -1)) / Tensor(Math.sqrt(channels)).to(dtype = q.dtype) // (64, 256, 256) [Float32]
      val weightMasked =
        weight.maskedFill(
          tril(Slice(0, timeStep), Slice(0, timeStep)) == 0,
          Float.NegativeInfinity
        ) // (64, 256, 256) [Float32]
      val attention =
        torch.nn.functional.softmax(weightMasked, dim = 2)(
          weightMasked.dtype
        ) // (64, 256, 256) [Float32]
      val attentionDropout = dropout(attention) // (64, 256, 256) [Float32]
      // Use the masked, normalized attention weights (after dropout), not the raw scores
      val output = attentionDropout.matmul(v) // (64, 256, 64) [Float32]
      output
}

final case class MultiHeadAttention[D <: FloatNN: Default](
    numHeads: Int,
    numEmbeddings: Int,
    headSize: Int,
    blockSize: Int,
    dropoutProb: Float
) extends TensorModule[D] {
  // Multiple heads of self-attention in parallel

  val heads = register(nn.ModuleList(Range(0, numHeads).map { _ =>
    Head[D](numEmbeddings, headSize, blockSize, dropoutProb)
  }*))
  val projection = register(nn.Linear(numHeads * headSize, numEmbeddings))
  val dropout = register(Dropout(dropoutProb))
  override def apply(input: Tensor[D]): Tensor[D] =
      val headOutputs = heads.map { head =>
        head(input)
      } // 6 x (64, 256, 64) [Float32]
      val headOutputsConcat = torch.cat(headOutputs, dim = -1) // (64, 256, 384) [Float32]
      val projectedOutput = projection(headOutputsConcat) // (64, 256, 384) [Float32]
      dropout(projectedOutput) // (64, 256, 384) [Float32]
}

final case class FeedForward[D <: FloatNN: Default](numEmbeddings: Int, dropoutProb: Float)
    extends TensorModule[D] {
  // A simple linear layer followed by a non-linearity

  val net = register(nn.Sequential(
    nn.Linear(numEmbeddings, numEmbeddings * 4),
    nn.ReLU(),
    nn.Linear(numEmbeddings * 4, numEmbeddings),
    Dropout(dropoutProb)
  ))
  override def apply(input: Tensor[D]): Tensor[D] =
    net(input)

}

final case class Block[D <: FloatNN: Default](numEmbeddings: Int, numHeads: Int, blockSize: Int, dropoutProb: Float)
    extends TensorModule[D] {
  // Transformer block: communication followed by computation
  val headSize = numEmbeddings / numHeads // 384 / 6 = 64
  val attention = register(MultiHeadAttention(numHeads, numEmbeddings, headSize, blockSize, dropoutProb))
  val feedForward = register(FeedForward(numEmbeddings, dropoutProb))
  val layerNorm1 = register(nn.LayerNorm(Seq(numEmbeddings)))
  val layerNorm2 = register(nn.LayerNorm(Seq(numEmbeddings)))

  override def apply(input: Tensor[D]): Tensor[D] =
      // (64, 256, 384) [Float32]
      val a = input + attention(layerNorm1(input)) // (64, 256, 384) [Float32]
      val b = a + feedForward(layerNorm2(a)) // (64, 256, 384) [Float32]
      b

}

final case class Dropout[D <: FloatNN: Default](probability: Float) extends TensorModule[D] {
  override def apply(x: Tensor[D]): Tensor[D] =
    nn.functional.dropout(x, probability)
}
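
For readers following along, the masking and normalization inside Head can be sketched on plain vectors for a single batch element. This is only an illustration of the technique (lower-triangular mask, then row-wise softmax); `softmax` and `causalAttentionWeights` are hypothetical helper names for this example, not storch functions.

```scala
// Row-wise softmax; exp(-Infinity) evaluates to 0.0, so masked entries get zero weight.
def softmax(row: Vector[Double]): Vector[Double] =
  val m = row.max // subtract the max for numerical stability
  val exps = row.map(v => math.exp(v - m))
  val total = exps.sum
  exps.map(_ / total)

// scores is a (T, T) matrix of raw attention scores for one batch element.
def causalAttentionWeights(scores: Vector[Vector[Double]]): Vector[Vector[Double]] =
  scores.zipWithIndex.map { (row, t) =>
    // Position t may only attend to positions 0..t (lower triangle, diagonal included).
    val masked = row.zipWithIndex.map { (s, i) =>
      if i <= t then s else Double.NegativeInfinity
    }
    softmax(masked)
  }
```

With all-zero scores, row `t` of the result spreads its weight evenly over positions `0..t` and assigns zero to every later position.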

I'm happy to assist you in any way to get this to work. I was able to get some inference going without any runtime errors, but haven't had time to train the model on Shakespeare's writings yet.

I will also be available to continue work on the pending PR to get it merged, in case I can help in any way @sbrunk

davoclavo avatar Aug 17 '23 19:08 davoclavo

Oh I forgot, there are also some changes needed for pico GPT that I haven't created a PR for, but I have fixed in my local project. I aim to get these changes submitted soon, but here they are in case you need them earlier:

Tensor#maskedFill

def maskedFill[S <: ScalaType](mask: Tensor[Bool], value: S): Tensor[D] = Tensor(
 native.masked_fill(mask.native, toScalar(value))
)

Tensor#sqrt

def sqrt = Tensor(native.sqrt())

torch.tril

  def tril[D <: DType](input: Tensor[D], diagonal: Int = 0): Tensor[D] =
    Tensor(torchNative.tril(input.native, diagonal.toLong))

Fixing tensor.split (see #39)

  def split[D <: DType](
      input: Tensor[D],
      splitSizeOrSections: Int | Seq[Int],
      dim: Int = 0
  ): Seq[Tensor[D]] = {
    val result =
      splitSizeOrSections match {
        case i: Int      => torchNative.split(input.native, i.toLong, dim.toLong)
        case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
      }
    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
  }

davoclavo avatar Aug 17 '23 19:08 davoclavo

I will also be available to continue work on the pending PR to get it merged, in case I can help in any way @sbrunk

@davoclavo feel free to take over #36 again if you have capacity. I've merged main into it with some improvements of the native bindings but since Scala Days is only 4 weeks away I'd like to focus on getting my Storch talk ready first. Happy to help/review etc. but I'm not sure I'll be able to actually work on it before the talk.

sbrunk avatar Aug 17 '23 20:08 sbrunk

@sbrunk sounds good, I'll try to polish the last remaining bits.

Best of luck on the Scala Days talk! Hopefully it will be streamed/recorded, I'd love to watch it :D

davoclavo avatar Aug 17 '23 20:08 davoclavo

Best of luck on the Scala Days talk! Hopefully it will be streamed/recorded, I'd love to watch it :D

Thanks! I'm sure it will be recorded and put on youtube some time after the conference as the videos from the Seattle edition from June are already online. I'll keep you posted :)

sbrunk avatar Aug 17 '23 20:08 sbrunk

@davoclavo Thanks for the assist. Please note that at this time I am working on the very simple "video" version. My aim here is to learn about GPT.

I will look at your code and incorporate all I can to make merging easier.

hmf avatar Aug 18 '23 06:08 hmf

Questions regarding softmax. I was coding the cross_entropy examples to make sure the typing is correct. In the second example we need the softmax function in the link below. Looking at the code I see we have:

  def softmax[In <: DType, Out <: DType](input: Tensor[In], dim: Long)(
      dtype: Out = input.dtype
  ): Tensor[Out] =
    val nativeDType =
      if dtype == input.dtype then ScalarTypeOptional() else ScalarTypeOptional(dtype.toScalarType)
    Tensor(torchNative.softmax(input.native, dim, nativeDType))

This means that we have to explicitly provide the last (usually empty) parameter list, like so:

  val target1 = F.softmax( input=torch.randn(Seq(3, 5)), dim=1L)()

If we don't, we get the error:

[error] 358 |  val loss1 = F.crossEntropy(input1, target1)
[error]     |                                     ^^^^^^^
[error]     |Found:    (gpt.BiGram.target1 : torch.DType => torch.Tensor[torch.DType])
[error]     |Required: torch.Tensor[O]
[error]     |
[error]     |where:    O is a type variable with constraint <: torch.NumericRealNN

I have made that last parameter an implicit. I did the same for logSoftmax. If we do this, we avoid having to provide that last parameter. It seems that only the softmax call was used. Ran the test, had no problem. OK with this change, or am I missing something?

The original Python example code uses a Tensor.softmax(dim=1) call. This method does not exist in storch. The Python documentation states that it is an "Alias for torch.nn.functional.softmax()." Should we add this? If so, do we add it as a standard method or use Scala 3 extension methods?

TIA

hmf avatar Aug 18 '23 07:08 hmf

I have made that last parameter an implicit. I did the same for logSoftmax. If we do this, we avoid having to provide that last parameter. It seems that only the softmax call was used. Ran the test, had no problem. OK with this change, or am I missing something?

That's fine but could you give the following variant a try? It's a solution we already use in other places and avoids both implicits and multiple parameter lists (at the expense of a slightly more verbose type signature).

import Derive.derive

// ...

  def softmax[In <: DType, Out <: FloatNN | Derive](
      input: Tensor[In],
      dim: Long,
      dtype: Out = derive
  ): Tensor[DTypeOrDeriveFromTensor[In, Out]] =
    val derivedDType = dtype match
      case _: Derive => input.dtype
      case d: DType  => d
    val nativeDType =
      if dtype == input.dtype then ScalarTypeOptional()
      else ScalarTypeOptional(derivedDType.toScalarType)
    Tensor(torchNative.softmax(input.native, dim, nativeDType))

The original Python example code uses a Tensor.softmax(dim=1) call. This method does not exist in storch. The Python documentation states that it is an "Alias for torch.nn.functional.softmax()." Should we add this? If so, do we add it as a standard method or use Scala 3 extension methods?

Yes, you can add it as a regular method in Tensor delegating to the implementation in nn.functional

sbrunk avatar Aug 18 '23 10:08 sbrunk

That's fine but could you give the following variant a try? It's a solution we already use in other places and avoids both implicits and multiple parameter lists (at the expense of a slightly more verbose type signature).

Done (also for logSoftmax). Compiled and all tests pass.

Yes, you can add it as a regular method in Tensor delegating to the implementation in nn.functional

Done:

  def shape: Seq[Int] = size

  def softmax[Out <: FloatNN | Derive](
      dim: Long,
      dtype: Out = derive
  ): Tensor[DTypeOrDeriveFromTensor[D, Out]] = F.softmax(input = this, dim = dim, dtype = dtype)

  def square = Tensor(native.square())

hmf avatar Aug 18 '23 11:08 hmf

While trying to replicate the Colaboratory notebook to check the code is working, I tried to do the following:

  // We want x[b,t] = mean_{i<=t} x[b,i]
  val xbow = torch.zeros(Seq(b0, t0, c0))
  for b <- 0 until b0
  do
    for t <- 0 until t0
    do
      val xprev = x(b,º`:`t+1) // (t,C)
      xbow(b,t) = torch.mean(xprev, 0)  

The Tensor class has no assignment operator. I also did not find a method for this in the JavaCPP code. How should one go about assigning a value?
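
For reference, the result this loop is meant to produce can be sketched without tensors or in-place assignment, on plain nested vectors for a single batch element. The helper name `causalMean` is hypothetical and only used for this illustration.

```scala
// x has shape (T, C): T time steps and C channels, for a single batch element.
// Returns xbow where xbow(t) is the channel-wise mean of x(0), ..., x(t).
def causalMean(x: Vector[Vector[Double]]): Vector[Vector[Double]] =
  val channels = x.headOption.map(_.size).getOrElse(0)
  // Running channel-wise sums: prefixSums(t) = x(0) + ... + x(t)
  val prefixSums = x
    .scanLeft(Vector.fill(channels)(0.0)) { (acc, row) =>
      acc.zip(row).map((a, b) => a + b)
    }
    .tail // drop the all-zero seed
  prefixSums.zipWithIndex.map { (sum, t) => sum.map(_ / (t + 1)) }
```

For example, the rows (1, 2), (3, 4), (5, 9) yield the running means (1, 2), (2, 3), (3, 5).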

TIA

hmf avatar Aug 18 '23 15:08 hmf

The Tensor class has no assignment operator. I also did not find a method for this in the JavaCPP code. How should one go about assigning a value?

The C++ API has a method for assigning values (with indices): See https://pytorch.org/cppdocs/notes/tensor_indexing.html#setter It's just not that easy to find, because it's named index_put_. It's also mapped via JavaCPP, but was missing in Storch.

https://github.com/sbrunk/storch/pull/53 should add support for it. Could you give it a try?

sbrunk avatar Aug 18 '23 22:08 sbrunk

Found some compiler weirdness with the changes above. These do not compile:

      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0)
      xbow(Seq(b,t)) = torch.mean(xprev, dim=0)  

The error is:

method mean in trait ReductionOps: (input: torch.Tensor[?], dtype: torch.Float32): torch.Tensor[torch.Float32] does not have a parameter dim

and (for the last one):

Found:    (0 : Int)
Required: torch.Float32

But these do:

      xbow(b,t) += torch.mean(xprev, dim=0)  
      val c = torch.mean(xprev, dim=0) 
      xbow(Seq(b,t)) = c
      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true, float32)
      xbow(Seq(b,t)) = torch.mean(input=xprev, dim=0, true)

Maybe some tweaking of the 1st definition may get it working, but seems like a Scala issue.

hmf avatar Aug 19 '23 10:08 hmf

It looks like the compiler gets confused by the overloaded variants of mean for whatever reason. I've seen this in other places with different generic overloads.

I realized that the default dim argument with an empty seq defaults to the behavior of the overloaded variants, making them redundant so I've removed them now in #53. Could you give it another try with the changes?

sbrunk avatar Aug 19 '23 17:08 sbrunk

@sbrunk Changes work fine. Thanks.

hmf avatar Aug 20 '23 08:08 hmf

I need the use of Dropout. In Python this seems to return a constructor of sorts (did not check), which can then be applied to a Tensor.

I see that we have a torch.nn.Dropout that is private to the torch package. So the more obvious solution of having a public Dropout class and its companion object will require changes. I have the following questions:

  1. Is the suggested change above ok?
  2. If so, can I go ahead and change this?
  3. If not, what is the storch way?

EDIT 1:

@davoclavo I realized you have already defined Dropout. I searched your repo but did not find it. Where did you define it? TIA

hmf avatar Sep 08 '23 15:09 hmf

I would like to use register_buffer. According to the Python API doc, we must pass in a name.

Looking at the org.bytedeco.pytorch.Module we have:

  public Tensor register_buffer(BytePointer name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
  private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString BytePointer name, @ByVal Tensor tensor);
  public Tensor register_buffer(String name, Tensor tensor) { return asModule()._register_buffer(name, tensor); }
  private native @ByRef @Name("register_buffer") Tensor _register_buffer(@StdString String name, @ByVal Tensor tensor);

So in torch.nn.modules.Module something like this should work:

  def registerB[D <: DType](n: String, t: Tensor[D]): Tensor[D] =
    nativeModule.register_buffer(n, t.native)
    t

However, as an example:

  def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
      name: sourcecode.Name
  ): Tensor[D] =
    nativeModule.register_parameter(name.value, t.native, requiresGrad)
    t

the name is implicitly defined. Is there any way I can keep the implicit but still allow manually setting that name?

On a related note, shouldn't these functions return a Tensor(t)? We are assuming the same tensor is returned, but this is not guaranteed.

EDIT 1: we also have the problem of duplicate overload methods due to the use of defaults. What is the way to solve this here? Can I change the names?

EDIT 2: In the meantime I will use:


  def buffer[D <: DType](t: Tensor[D], n: String="")(using
      name: sourcecode.Name
  ): Tensor[D] =
    val name_ = if n.trim().isEmpty() then name.value else n.trim()
    Tensor( nativeModule.register_buffer(name_, t.native) ) // use the resolved name, not the raw argument

TIA

hmf avatar Sep 08 '23 15:09 hmf

I need the use of Dropout. In Python this seems to return a constructor of sorts (did not check), which can then be applied to a Tensor.

I see that we have a torch.nn.Dropout that is private to the torch package. So the more obvious solution of having a public Dropout class and its companion object will require changes. I have the following questions:

1. Is the suggested change above ok?

2. If so, can I go ahead and change this?

3. If not, what is the `storch` way?

I think what you found is the Dropout trait in torch.nn.functional right? The trait is private because it's members are exposed through the package object, so you can call it like this:

torch.nn.functional.dropout(input=torch.rand(Seq(3,3)))
// res2: Tensor[Float32] = tensor dtype=float32, shape=[3, 3], device=CPU 
// [[0,4759, 1,4497, 1,7002],
//  [1,2299, 0,0000, 1,1805],
//  [0,0000, 0,0000, 0,0000]]

It corresponds to torch.nn.functional.dropout in Python.

Seems like we're still missing the module variant of Dropout, which corresponds to the Python module you linked to. If you'd like to add that, that would be great! We should put it under torch.nn.modules somewhere, like the other modules.

sbrunk avatar Sep 08 '23 16:09 sbrunk

So in torch.nn.modules.Module something like this should work:

  def registerB[D <: DType](n: String, t: Tensor[D]): Tensor[D] =
    nativeModule.register_buffer(n, t.native)
    t

However, as an example:

  def register[D <: DType](t: Tensor[D], requiresGrad: Boolean = true)(using
      name: sourcecode.Name
  ): Tensor[D] =
    nativeModule.register_parameter(name.value, t.native, requiresGrad)
    t

the name is implicitly defined. Is there any way I can keep the implicit but still allow manually setting that name?

We could add an explicit optional name parameter, i.e. defaulting to an empty string, or using an Option. If the caller provides a real name, we take that, otherwise, we fall back to the implicit. Ah I see you've just done that below in the buffer impl :)
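
A minimal, dependency-free sketch of that fallback is shown below. `SourceName` is a hypothetical stand-in for `sourcecode.Name`, and `BufferRegistry` stands in for the native module; both are assumptions made so the example runs without storch or the sourcecode library.

```scala
// Hypothetical stand-in for sourcecode.Name, so the sketch needs no dependency.
final case class SourceName(value: String)

// Hypothetical registry standing in for the native module.
final class BufferRegistry:
  private val buffers =
    scala.collection.mutable.LinkedHashMap.empty[String, Vector[Double]]

  // An explicit non-blank name wins; otherwise fall back to the implicit one.
  def registerBuffer(t: Vector[Double], n: String = "")(using
      name: SourceName
  ): Vector[Double] =
    val resolved = if n.trim.isEmpty then name.value else n.trim
    buffers(resolved) = t
    buffers(resolved) // return what the registry holds, not the argument

  def bufferNames: Seq[String] = buffers.keys.toSeq
```

With a `given SourceName("tril")` in scope, `registerBuffer(Vector(1.0))` registers under "tril", while `registerBuffer(Vector(2.0), "mask")` registers under "mask".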

On a related note, shouldn't these functions return a Tensor(t)? We are assuming the same tensor is returned, but this is not guaranteed.

You're right, it's better to use the tensor returned by the native register method.

EDIT 1: we also have the problem of duplicate overload methods due to the use of defaults. What is the way to solve this here? Can I change the names?

Yes please go ahead. Perhaps we can keep register for modules, because it is used quite often, but use registerParameter, registerBuffer for the others.

EDIT 2: In the meantime I will use:

  def buffer[D <: DType](t: Tensor[D], n: String="")(using
      name: sourcecode.Name
  ): Tensor[D] =
    val name_ = if n.trim().isEmpty() then name.value else n.trim()
    Tensor( nativeModule.register_buffer(name_, t.native) ) // use the resolved name, not the raw argument

👍

sbrunk avatar Sep 08 '23 16:09 sbrunk

@davoclavo I realized you have already defined Dropout. I searched your repo but did not find it. Where did you define it? TIA

Hi @hmf ! Apologies for the confusion, I have not committed my changes yet, as I have a bunch of other stuff that needs to be cleaned up. I just shared them in my previous comment to partially share the progress in case it was useful to you :)

You should be able to either drop in that code I shared in your script/example, or add it as a new module to storch.

I'll keep my ear open in case you need any further help, and hopefully find some time soon to help out to contribute these modules to storch.

davoclavo avatar Sep 08 '23 16:09 davoclavo

While trying to implement and debug the multi-head attention mechanism, I have what seems to be unexpected behavior. For a model with the multi-head "only", the code:

    val nuParams = m.parameters.map(_.numel).sum
    println(s"${nuParams} parameters")

Reports:

Multi-head attention
4481 parameters

Now to this model I add the following layer:

    val ffwd = register( FeedFoward(nEmbed) )

where nEmbed = 32. If I count the number of parameters of this layer I get 1056 (nEmbed*nEmbed + nEmbed), which is correct. But the model still reports:

Multi-head attention + FFWD
4481 parameters

Shouldn't that be 4481 + 1056?

TIA

hmf avatar Sep 22 '23 15:09 hmf

@hmf I have a hunch (not tested). Could you try to wrap your Sequential in your feed forward module inside a register as well like so:

https://github.com/sbrunk/storch/blob/5e1fdf2a7b2d985a58ee7a6f8405cd8d443426b4/examples/src/main/scala/gpt/BiGram.scala#L1316-L1326

- val net = nn.Sequential(
+ val net = register(nn.Sequential(

Right now it's registering the layers inside Sequential as submodules of net, but not net itself as a submodule of FeedForward. In Python this is done implicitly. Perhaps we need a macro at some point to achieve something similar in Storch as well.
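
The effect can be reproduced with a toy module tree. `CountingModule`, `ToyLinear`, `Unregistered`, and `Registered` below are hypothetical classes written for this illustration (not storch's Module); they only show that a submodule contributes to the parameter count when its parent registers it.

```scala
// Toy module tree, for illustration only (not storch's Module).
class CountingModule:
  private val children = scala.collection.mutable.ArrayBuffer.empty[CountingModule]
  protected def ownParameterCount: Int = 0

  // Only registered children are visible when parameters are collected.
  protected def register[M <: CountingModule](child: M): M =
    children += child
    child

  def parameterCount: Int = ownParameterCount + children.map(_.parameterCount).sum

// A linear layer with bias has in * out weights plus out biases.
final class ToyLinear(in: Int, out: Int) extends CountingModule:
  override protected def ownParameterCount: Int = in * out + out

final class Unregistered(nEmbed: Int) extends CountingModule:
  val net = ToyLinear(nEmbed, nEmbed) // created but never registered

final class Registered(nEmbed: Int) extends CountingModule:
  val net = register(ToyLinear(nEmbed, nEmbed)) // visible to parameterCount
```

With nEmbed = 32, `Unregistered(32).parameterCount` is 0 while `Registered(32).parameterCount` is 32 * 32 + 32 = 1056, which matches the missing amount in the report above.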

sbrunk avatar Sep 22 '23 21:09 sbrunk

@sbrunk I have confirmed that I need to register the inner modules. As for the macro, maybe a single function that traverses the sub-modules and registers them would do. But we also have parameter and buffer registering, so that would also have to be dealt with.

Thanks.

hmf avatar Sep 23 '23 09:09 hmf

I would like to give an update on this endeavor. I have gone through most of the video and am now at the start of the "Block" implementation. I have tried to stick to the video so that I can compare my results. Unfortunately my results show much higher loss (single head and multi head of 3).

Here are some results:

Single head

  • Andrej Karpathy gets 2.2858 @ 4500 iterations
  • Here we get 3.350137 @ 4500 iterations
  • lr = 1.e-5 (with Karpathy 1.4e-5, loss explodes)

Triple Head

  • Andrej Karpathy gets 2.2412 @ 4500 iterations
  • Here we get 3.6443036 @ 4500 iterations
  • lr = 1.e-5 (with Karpathy 1.4e-5, loss explodes)

I have run about 9 experiments on CPU. Even though convergence is slow, the good news is that it seems to be stable. See below.

Single Head

lr = 1e-5
Output:
step 0: train loss 4.315746, val loss 4.3061743
step 500: train loss 4.2083063, val loss 4.2047343
step 1000: train loss 4.109281, val loss 4.1095076
step 1500: train loss 4.024676, val loss 4.021858
step 2000: train loss 3.9401476, val loss 3.9419503
step 2500: train loss 3.861138, val loss 3.868681
step 3000: train loss 3.7746782, val loss 3.7817297
step 3500: train loss 3.6901476, val loss 3.7049506
step 4000: train loss 3.599073, val loss 3.617259
step 4500: train loss 3.5131109, val loss 3.5384142
step 5000: train loss 3.452971, val loss 3.4619794
step 5500: train loss 3.399948, val loss 3.4254942
step 6000: train loss 3.3541067, val loss 3.3918
step 6500: train loss 3.3242495, val loss 3.3732038
step 7000: train loss 3.3144944, val loss 3.3490424
step 7500: train loss 3.2901514, val loss 3.2941566
step 8000: train loss 3.2899778, val loss 3.308439
step 8500: train loss 3.2639534, val loss 3.2906058
step 9000: train loss 3.2651227, val loss 3.2723944
step 9500: train loss 3.2395923, val loss 3.2861238
step 10000: train loss 3.2434728, val loss 3.257814
step 10500: train loss 3.2285821, val loss 3.23281
step 11000: train loss 3.2198544, val loss 3.2416165
step 11500: train loss 3.2021954, val loss 3.2313745
step 12000: train loss 3.195072, val loss 3.2142315
step 12500: train loss 3.1960852, val loss 3.2163675
step 13000: train loss 3.1769931, val loss 3.2013638
step 13500: train loss 3.17453, val loss 3.2119668
step 14000: train loss 3.1472147, val loss 3.1825323
step 14500: train loss 3.1611233, val loss 3.192211
step 15000: train loss 3.1517265, val loss 3.1621974
step 15500: train loss 3.1394618, val loss 3.1598687
step 16000: train loss 3.1233463, val loss 3.145328
step 16500: train loss 3.1227674, val loss 3.1421418
step 17000: train loss 3.1164768, val loss 3.1276824
step 17500: train loss 3.1011841, val loss 3.0985348
step 18000: train loss 3.0856524, val loss 3.11533
step 18500: train loss 3.0842745, val loss 3.0987678
step 19000: train loss 3.049956, val loss 3.1043591
step 19500: train loss 3.0564034, val loss 3.0689766
step 20000: train loss 3.0590668, val loss 3.0758286
step 20500: train loss 3.0560205, val loss 3.0690722
step 21000: train loss 3.0467145, val loss 3.0635276
step 21500: train loss 3.0318224, val loss 3.0459983
step 22000: train loss 3.025454, val loss 3.0337
step 22500: train loss 3.0058165, val loss 3.0480902
step 23000: train loss 3.0240664, val loss 3.0332391
step 23500: train loss 2.9987218, val loss 3.023562
step 24000: train loss 2.985587, val loss 3.0277314
step 24500: train loss 2.9775257, val loss 3.002483
step 24999: train loss 2.9854958, val loss 3.0055265
step 24999: train loss 2.9771202, val loss 3.0027666

Triple Head

learningRate = 1.0E-5 
maxIterations = 75000
Output:
step 0: train loss 4.1618342, val loss 4.16153
step 500: train loss 4.1205373, val loss 4.1242867
step 1000: train loss 4.0790596, val loss 4.081698
step 1500: train loss 4.03232, val loss 4.0372114
step 2000: train loss 3.9790084, val loss 3.9862146
step 2500: train loss 3.9226956, val loss 3.9263957
step 3000: train loss 3.8504639, val loss 3.8638783
step 3500: train loss 3.7733784, val loss 3.786392
step 4000: train loss 3.6981156, val loss 3.720096
step 4500: train loss 3.628634, val loss 3.6443036
step 5000: train loss 3.5587113, val loss 3.5619648
step 5500: train loss 3.4964852, val loss 3.4965785
step 6000: train loss 3.421188, val loss 3.4415543
step 6500: train loss 3.362888, val loss 3.391549
step 7000: train loss 3.3282533, val loss 3.3622048
step 7500: train loss 3.320427, val loss 3.3560946
step 8000: train loss 3.292354, val loss 3.2881603
step 8500: train loss 3.2728596, val loss 3.2815585
step 9000: train loss 3.2583148, val loss 3.2723749
step 9500: train loss 3.2506166, val loss 3.2684808
step 10000: train loss 3.2148948, val loss 3.2601957
step 10500: train loss 3.1988037, val loss 3.2456586
step 11000: train loss 3.206799, val loss 3.2168744
step 11500: train loss 3.1882236, val loss 3.2074182
step 12000: train loss 3.1804316, val loss 3.213266
step 12500: train loss 3.1568613, val loss 3.1786158
step 13000: train loss 3.1662986, val loss 3.1859655
step 13500: train loss 3.1503942, val loss 3.1711109
step 14000: train loss 3.147156, val loss 3.166469
step 14500: train loss 3.1371622, val loss 3.1470597
step 15000: train loss 3.1360898, val loss 3.1625636
step 15500: train loss 3.1335275, val loss 3.1326685
step 16000: train loss 3.1126864, val loss 3.1321425
step 16500: train loss 3.1063373, val loss 3.107653
step 17000: train loss 3.0943058, val loss 3.1191053
step 17500: train loss 3.0952134, val loss 3.1210895
step 18000: train loss 3.1009033, val loss 3.081947
step 18500: train loss 3.0783198, val loss 3.1051643
step 19000: train loss 3.0771048, val loss 3.0912302
step 19500: train loss 3.0516539, val loss 3.0886228
step 20000: train loss 3.0385761, val loss 3.0701919
step 20500: train loss 3.042524, val loss 3.0808887
step 21000: train loss 3.0532212, val loss 3.0581033
step 21500: train loss 3.0394647, val loss 3.0615013
step 22000: train loss 3.021087, val loss 3.0511756
step 22500: train loss 3.0327508, val loss 3.0316634
step 23000: train loss 3.0150063, val loss 3.044455
step 23500: train loss 3.0176592, val loss 3.0279248
step 24000: train loss 3.0032563, val loss 3.0306866
step 24500: train loss 2.9985764, val loss 3.031719
step 25000: train loss 2.9964828, val loss 3.0107496
step 25500: train loss 2.9989612, val loss 3.0088224
step 26000: train loss 2.9867206, val loss 3.0070848
step 26500: train loss 2.9651825, val loss 3.009421
step 27000: train loss 2.978981, val loss 2.9872468
step 27500: train loss 2.972667, val loss 2.9928696
step 28000: train loss 2.9587805, val loss 2.9770303
step 28500: train loss 2.9506211, val loss 2.9797046
step 29000: train loss 2.9521976, val loss 2.9750147
step 29500: train loss 2.9423668, val loss 2.9667535
step 30000: train loss 2.9549394, val loss 2.9439688
step 30500: train loss 2.9268918, val loss 2.9612598
step 31000: train loss 2.91975, val loss 2.94916
step 31500: train loss 2.9251237, val loss 2.934956
step 32000: train loss 2.9079664, val loss 2.9431274
step 32500: train loss 2.910727, val loss 2.9221253
step 33000: train loss 2.91429, val loss 2.919466
step 33500: train loss 2.9074776, val loss 2.9280725
step 34000: train loss 2.896589, val loss 2.9004114
step 34500: train loss 2.898249, val loss 2.9251142
step 35000: train loss 2.8961527, val loss 2.9172351
step 35500: train loss 2.8839464, val loss 2.9006162
step 36000: train loss 2.8779233, val loss 2.9100876
step 36500: train loss 2.879361, val loss 2.9014406
step 37000: train loss 2.8839698, val loss 2.8981316
step 37500: train loss 2.853179, val loss 2.8779168
step 38000: train loss 2.8649895, val loss 2.894176
step 38500: train loss 2.8693879, val loss 2.8744462
step 39000: train loss 2.8525827, val loss 2.8651721
step 39500: train loss 2.858041, val loss 2.8500593
step 40000: train loss 2.8418512, val loss 2.863662
step 40500: train loss 2.8385842, val loss 2.8543704
step 41000: train loss 2.8311198, val loss 2.8574524
step 41500: train loss 2.825884, val loss 2.8499897
step 42000: train loss 2.8429782, val loss 2.8441114
step 42500: train loss 2.8070157, val loss 2.8388376
step 43000: train loss 2.8123505, val loss 2.842098
step 43500: train loss 2.810964, val loss 2.8345373
step 44000: train loss 2.8263602, val loss 2.830025
step 44500: train loss 2.811398, val loss 2.834848
step 45000: train loss 2.802633, val loss 2.810559
step 45500: train loss 2.8123126, val loss 2.8265247
step 46000: train loss 2.7979581, val loss 2.8048408
step 46500: train loss 2.7967849, val loss 2.8334157
step 47000: train loss 2.7803953, val loss 2.7922354
step 47500: train loss 2.7942781, val loss 2.825274
step 48000: train loss 2.7804523, val loss 2.792919
step 48500: train loss 2.785722, val loss 2.8042364
step 49000: train loss 2.7795138, val loss 2.7809396
step 49500: train loss 2.7776642, val loss 2.7782316
step 50000: train loss 2.769403, val loss 2.7787275
step 50500: train loss 2.7557025, val loss 2.7765558
step 51000: train loss 2.759183, val loss 2.7775955
step 51500: train loss 2.7498598, val loss 2.7687922
step 52000: train loss 2.764737, val loss 2.7726612
step 52500: train loss 2.7710688, val loss 2.7590082
step 53000: train loss 2.7473223, val loss 2.760512
step 53500: train loss 2.7373915, val loss 2.7564347
step 54000: train loss 2.7325678, val loss 2.7411654
step 54500: train loss 2.7540653, val loss 2.752793
step 55000: train loss 2.736955, val loss 2.751245
step 55500: train loss 2.7224433, val loss 2.7364216
step 56000: train loss 2.7233686, val loss 2.74944
step 56500: train loss 2.7202756, val loss 2.7465448
step 57000: train loss 2.7280054, val loss 2.7374096
step 57500: train loss 2.7064633, val loss 2.7330124
step 58000: train loss 2.6934423, val loss 2.7236161
step 58500: train loss 2.6968424, val loss 2.72582
step 59000: train loss 2.6981068, val loss 2.7159605
step 59500: train loss 2.695939, val loss 2.724237
step 60000: train loss 2.6998184, val loss 2.7238555
step 60500: train loss 2.6900072, val loss 2.7078435
step 61000: train loss 2.6998444, val loss 2.7143097
step 61500: train loss 2.6824317, val loss 2.699878
step 62000: train loss 2.678613, val loss 2.6927574
step 62500: train loss 2.695001, val loss 2.7028537
step 63000: train loss 2.6931143, val loss 2.6938097
step 63500: train loss 2.6818473, val loss 2.6830072
step 64000: train loss 2.6860394, val loss 2.6763582
step 64500: train loss 2.6754217, val loss 2.6692927
step 65000: train loss 2.652602, val loss 2.6785924
step 65500: train loss 2.6655686, val loss 2.6759882
step 66000: train loss 2.6485276, val loss 2.6638012
step 66500: train loss 2.6445954, val loss 2.6824584
step 67000: train loss 2.6588178, val loss 2.6743371
step 67500: train loss 2.665208, val loss 2.6798043
step 68000: train loss 2.6643429, val loss 2.6748931
step 68500: train loss 2.6562061, val loss 2.6429644
step 69000: train loss 2.6405647, val loss 2.648562
step 69500: train loss 2.6491652, val loss 2.6551437
step 70000: train loss 2.641609, val loss 2.6503496
step 70500: train loss 2.6256104, val loss 2.6489353
step 71000: train loss 2.6348572, val loss 2.6602316
step 71500: train loss 2.6440005, val loss 2.6452422
step 72000: train loss 2.625387, val loss 2.655331
step 72500: train loss 2.6233087, val loss 2.6433735
step 73000: train loss 2.623311, val loss 2.6347494
step 73500: train loss 2.609082, val loss 2.6489167
step 74000: train loss 2.6275, val loss 2.6279202
step 74500: train loss 2.624021, val loss 2.643931
step 74999: train loss 2.6234972, val loss 2.628585
step 75000: train loss 2.6114013, val loss 2.623228

hmf avatar Sep 28 '23 16:09 hmf