TensorCast.jl

Error when calculating gradients with Zygote

Open • countradooku opened this issue 1 year ago • 5 comments

I get the following error when I try to calculate gradients:

ERROR: Scalar indexing is disallowed. Invocation of getindex resulted in scalar indexing of a GPU array. This is typically caused by calling an iterating implementation of a method. Such implementations do not execute on the GPU, but very slowly on the CPU, and therefore should be avoided.

If you want to allow scalar iteration, use allowscalar or @allowscalar to enable scalar iteration globally or for the operations in question. Stacktrace:

using TensorCast
using Flux, AMDGPU, OMEinsum

struct LayerNorm{T<:AbstractArray}
    g::T
end

function LayerNorm(dim::Int)
    g = ones(Float32, (1, 1, dim, 1))
    LayerNorm(g)
end

Flux.@layer :expand LayerNorm

function (ln::LayerNorm)(x)
    dims = 3
    eps::Float32 = 1e-5
    μ = mean(x; dims=dims)
    σ² = var(x; dims=dims)
    # layer norm divides by the standard deviation: (x - μ) / sqrt(σ² + eps)
    x_normalized = (x .- μ) ./ sqrt.(σ² .+ eps)
    return x_normalized .* ln.g
end


struct LinearAttention{T<:Chain,S<:Conv}
    scale::Float32
    heads::Int
    to_qkv::S
    to_out::T
end


function LinearAttention(dim::Int; heads::Int=4, dim_head::Int=32)
    scale::Float32 = dim_head^-0.5
    hidden_dim = dim_head * heads
    to_qkv = Conv((1, 1), dim => hidden_dim * 3, bias=false)
    to_out = Chain(
        Conv((1, 1), hidden_dim => dim),
        LayerNorm(dim)  # the custom LayerNorm defined above
    )
    return LinearAttention(scale, heads, to_qkv, to_out)
end

Flux.@layer :expand LinearAttention
using Statistics: mean, var
function (la::LinearAttention)(x::ROCArray)
    h, w, c, b = size(x)
    qkv = Flux.chunk(la.to_qkv(x), 3; dims=3)
    q, k, v = qkv
    q, k, v = q[:, :, :, :], k[:, :, :, :], v[:, :, :, :]
    @cast q[c, (x, y), h, b] |= q[x, y, (c, h), b] h in 1:la.heads
    @cast k[c, (x, y), h, b] |= k[x, y, (c, h), b] h in 1:la.heads
    @cast v[c, (x, y), h, b] |= v[x, y, (c, h), b] h in 1:la.heads
    println(typeof(q))

    q = softmax(q, dims=1)
    k = softmax(k, dims=2)

    q *= la.scale

    v /= (h * w)
    println("typeof of k", typeof(k))
    println("typeof v", typeof(v))


    context = ein"dnhb,enhb->dehb"(k, v)
    println(typeof(context))

    # println("context: ", size(context))
    out = ein"dehb,dnhb->enhb"(context, q)
    println(typeof(out))
    # println("out: ", size(out))

    @cast out[x, y, (c, h), b] |= out[c, (x, y), h, b] (h in 1:la.heads, x in 1:h, y in 1:w)
    println(typeof(out))
    return la.to_out(out)
end


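# layer is not defined in the original snippet; this line is an assumption
# (it also assumes Flux's GPU backend is set to AMDGPU, so gpu moves the weights there)
layer = LinearAttention(64) |> gpu

x = AMDGPU.randn(256, 256, 64, 8)
loss, grad = Flux.withgradient(layer) do l
    a = l(x)
    sum(a)
end # this doesn't work
layer(x) # this works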
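The two ein"..." contractions in LinearAttention are ordinary per-head, per-batch matrix products. As a CPU-only sketch with small hypothetical sizes, using plain loops and Base matrix multiplication in place of OMEinsum, context[d, e, h, b] = Σ_n k[d, n, h, b] v[e, n, h, b] and out[e, n, h, b] = Σ_d context[d, e, h, b] q[d, n, h, b] can be written as:

```julia
using LinearAlgebra

# Hypothetical small sizes: D channels per head (d and e both range over these,
# since q, k, v come from one chunk), N = x*y positions, H heads, B batch.
D, N, H, B = 4, 6, 2, 3
q = randn(Float32, D, N, H, B)
k = randn(Float32, D, N, H, B)
v = randn(Float32, D, N, H, B)

# context[d, e, h, b] = sum over n of k[d, n, h, b] * v[e, n, h, b]
# (the contraction written as ein"dnhb,enhb->dehb" in the issue)
context = similar(k, D, D, H, B)
for b in 1:B, h in 1:H
    context[:, :, h, b] = k[:, :, h, b] * transpose(v[:, :, h, b])
end

# out[e, n, h, b] = sum over d of context[d, e, h, b] * q[d, n, h, b]
# (the contraction written as ein"dehb,dnhb->enhb" in the issue)
out = similar(q, D, N, H, B)
for b in 1:B, h in 1:H
    out[:, :, h, b] = transpose(context[:, :, h, b]) * q[:, :, h, b]
end

println(size(context), " ", size(out))
```

Writing the contractions this way is only for reading; on the GPU the einsum calls are what the layer actually uses.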

The type of q is ROCArray during inference, but when calculating gradients it is Base.ReshapedArray{Float32, 4, TransmuteDims.TransmutedDimsArray{Float32, 5, (3, 1, 2, 4, 5), (2, 3, 1, 4, 5), ROCArray{Float32, 5, AMDGPU.Runtime.Mem.HIPBuffer}}, NTuple{4, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}.
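The wrapper type reported above is what a lazy split, permute and merge of dimensions looks like. This CPU-only sketch uses Base's PermutedDimsArray as a stand-in for TransmuteDims.TransmutedDimsArray (sizes are hypothetical) and rebuilds the reshape from the first @cast, q[x, y, (c, h), b] to q[c, (x, y), h, b]. The permutation (3, 1, 2, 4, 5) is the one visible in the reported type, and collect is one way to turn the wrapper into a plain array:

```julia
# Hypothetical small sizes; the issue uses a 256×256×64×8 input and 4 heads.
X, Y, C, H, B = 2, 3, 4, 2, 2
q = reshape(Float32.(1:(X * Y * C * H * B)), X, Y, C * H, B)

# Lazy: split dim 3 into (c, h), permute (x, y, c, h, b) -> (c, x, y, h, b),
# then merge (x, y). No data is copied; the result wraps q's memory.
lazy = reshape(PermutedDimsArray(reshape(q, X, Y, C, H, B), (3, 1, 2, 4, 5)), C, X * Y, H, B)

# Eager: copy the same values into a plain dense Array.
eager = collect(lazy)

println(typeof(lazy))   # a Base.ReshapedArray wrapping a PermutedDimsArray
println(typeof(eager))  # Array{Float32, 4}
```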

countradooku, Apr 26 '24 17:04

Does OMEinsum support AMD at all?

TensorCast should in principle support everything. By default it is as lazy as possible, and hence returns things like Base.ReshapedArray, and some GPU code has trouble working through these wrappers. Replacing := with |= is how you ask this package to be less lazy.
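A minimal CPU sketch of the difference, assuming TensorCast is installed and using small hypothetical sizes: the same reshape as in the issue, written once with := and once with |=. Both give the same values; the := result may be a lazy view of q, while |= materializes a new array. The exact printed types depend on the TensorCast version.

```julia
using TensorCast

# Hypothetical small sizes; the issue uses a 256×256×64×8 input and 4 heads.
X, Y, C, H, B = 2, 3, 4, 2, 2
q = reshape(Float32.(1:(X * Y * C * H * B)), X, Y, C * H, B)

# := lets TensorCast return a lazy view of q where it can.
@cast Qlazy[c, (x, y), h, b] := q[x, y, (c, h), b] h in 1:H

# |= asks for a less lazy result, with the values materialized in a new array.
@cast Qcopy[c, (x, y), h, b] |= q[x, y, (c, h), b] h in 1:H

println(typeof(Qlazy))
println(typeof(Qcopy))
```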

mcabbott, Apr 26 '24 18:04

Yes, I personally added support for it.

countradooku, Apr 26 '24 19:04

Yeah, I did that (as I was saying on Slack) to get plain arrays directly rather than a ReshapedArray.

countradooku, Apr 26 '24 19:04

Sorry, the code does have |=.

Can you figure out which operation is causing the problem? There are four @cast and two ein"..." calls; what is the minimal example that triggers this? Can you make gradient return a surprising type?

(I cannot run this locally)

mcabbott, Apr 26 '24 19:04

Actually, yes. It's this line:

v /= (h * w)
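In Julia, v /= (h * w) is shorthand for v = v / (h * w), which calls Base's array-by-scalar /, whereas v ./ (h * w) is an explicit broadcast. On the CPU the two give the same values, as this small sketch with a hypothetical stand-in for v shows. Whether switching to the broadcast form sidesteps the scalar-indexing error under Zygote on AMDGPU is an untested assumption, not something confirmed in this thread.

```julia
# Hypothetical stand-in for v after the @cast: (channels, x*y, heads, batch).
h, w = 4, 5
v = randn(Float32, 8, h * w, 2, 3)

# v /= (h * w) lowers to v = v / (h * w): Base's array-by-scalar division.
v_div = v / (h * w)

# Broadcast spelling of the same scaling (assumed workaround, untested on GPU).
v_bc = v ./ (h * w)
```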

countradooku, Apr 26 '24 20:04