KernelFunctions.jl icon indicating copy to clipboard operation
KernelFunctions.jl copied to clipboard

Something like "MethodError: PeriodicTransform{Vector{Float32}} with SubArray in Zygote gradient of TransformedKernel (D_feats=1)"

Open yuchenxiao95 opened this issue 4 months ago • 0 comments

julia> versioninfo() Julia Version 1.11.5 Commit 760b2e5b73 (2025-04-14 06:53 UTC) Build Info: Official https://julialang.org/ release Platform Info: OS: Windows (x86_64-w64-mingw32) CPU: 24 × Intel(R) Xeon(R) Silver 4410Y WORD_SIZE: 64 LLVM: libLLVM-16.0.6 (ORCJIT, sapphirerapids) Threads: 24 default, 0 interactive, 12 GC (on 24 virtual cores) Environment: JULIA_NUM_THREADS = 24 JULIA_EDITOR = code JULIA_VSCODE_REPL = 1

using KernelFunctions, Optimisers, Zygote, LinearAlgebra, Statistics using CUDA # To enable GPU arrays and operations using PDMats # For handling positive definite matrices if needed using Functors # Needed for traversing kernel parameters (though its role is minimized now) using Base.Iterators: only # Keep if you want to use it for other purposes, but not for θ_current

--- Kernel and Parameter Management ---

---------------- Composite Kernel Suggestion ----------------

All parameters are now explicitly Float32 scalars from the start.

function suggest_composite_kernel(N::Int; period_guess::Real=12.0f0) # Ensure period_guess is Float32 l_trend = Float32(0.1 * N) l_seasonal_period = Float32(0.7 * period_guess) l_local_scale = Float32(0.3 * period_guess)

# Explicitly ensure scalar Float32 for all transform parameters
# This is a defensive measure against type widening issues
k_trend = TransformedKernel(Matern52Kernel(), ScaleTransform(Float32(1.0f0 / l_trend)))
k_seasonal_base = RationalQuadraticKernel()
k_seasonal = k_seasonal_base ∘ PeriodicTransform(Float32(l_seasonal_period))
k_local_periodic_base = PeriodicKernel()
k_local_periodic = k_local_periodic_base ∘ ScaleTransform(Float32(1.0f0 / l_local_scale))
k_linear = LinearKernel()

return k_trend + k_linear + k_seasonal + k_local_periodic

end

---------------- Kernel Matrix Construction ----------------

This function forces the kernel matrix computation to happen on the CPU

if the input data is a CuArray, then moves the result back to GPU.

function build_kernel_matrix(kernel, X::AbstractMatrix{Float32}, log_noise::Float32) local K_raw if X isa CuArray{Float32} # FORCE KERNEL MATRIX COMPUTATION ON CPU AS A WORKAROUND # This will be inefficient due to data transfer, but bypasses the Distances.jl/CUDA.jl issue println("Warning: Forcing kernel matrix computation on CPU due to Distances.jl/CUDA.jl integration issue.") println(" Performance will be severely impacted for this step.")

    # Convert CuArray X back to CPU Array
    X_cpu = Array(X) 
    # Compute kernel matrix on CPU
    K_raw_cpu = kernelmatrix(kernel, ColVecs(X_cpu), ColVecs(X_cpu))
    # Move the computed kernel matrix back to GPU
    K_raw = CuArray(K_raw_cpu) 
else
    # Original path for CPU arrays or if X is not a CuArray
    K_raw = kernelmatrix(kernel, ColVecs(X), ColVecs(X))
end

noise_var = exp(2 * log_noise)
# Ensure noise_diag is also on GPU if K_raw is on GPU (which it will be if X was CuArray)
noise_diag = Diagonal(fill(noise_var, size(K_raw, 1))) 

return K_raw + noise_diag

end

---------------- Log Marginal Likelihood ----------------

function log_marginal_likelihood(K::AbstractMatrix{Float32}, y::AbstractVector{Float32}) n = length(y) try F = cholesky(Hermitian(K)) α = F \ y return -0.5f0 * (dot(y, α) + 2f0 * sum(log, diag(F)) + n * log(2f0 * π)) catch e if isa(e, PosDefException) || isa(e, ArgumentError) return -Inf32 else rethrow(e) end end end

---------------- GPU-Compatible Training Loop ----------------

function train_gp_gpu(X::AbstractMatrix{Float32}, y::AbstractVector{Float32}; max_epochs::Int=100, lr::Float32=0.01f0, use_gpu::Bool=CUDA.has_cuda())

# Ensure X is in the correct format for KernelFunctions.jl (features x observations).
if size(X, 1) != length(y)
    @assert size(X, 2) == length(y) "X must have dimensions (features x observations) or (observations x features)."
    X_reshaped = permutedims(X)
else
    X_reshaped = X
end

# Move data to GPU (CuArray) if use_gpu is true, otherwise keep on CPU.
X_device = use_gpu ? CuArray(X_reshaped) : X_reshaped
y_device = use_gpu ? CuArray(y) : y

initial_N = size(X_reshaped, 2)

# Kernel remains on CPU.
kernel_cpu = suggest_composite_kernel(initial_N, period_guess=12.0f0) 

# Assign the CPU kernel directly. No explicit GPU transfer for the kernel itself.
kernel_device = kernel_cpu 
if use_gpu
    println("Warning: Kernel object is on CPU. Using CuArray for data inputs. Performance might be impacted.")
    println("This bypasses the `ScaleTransform` MethodError for now.")
end

# Initialize the trainable parameter: log_noise.
log_noise = log(0.1f0 * std(Vector(y_device))) 
θ = Float32[log_noise] # Our only trainable parameter for this example
θ_device = use_gpu ? CuArray(θ) : θ # Move parameters to GPU

# Setup the Adam optimizer.
opt = Optimisers.Adam(lr)
state = Optimisers.setup(opt, θ_device) # Optimizer state is on the device

# Initialize best parameters and log marginal likelihood (LML).
best_θ_device = copy(θ_device)
best_lml = -Inf32

# Training loop.
for epoch in 1:max_epochs
    # Compute gradients of the negative log marginal likelihood with respect to θ_device.
    grads = Zygote.gradient(θ_device) do θ_current
        # Use `sum()` for GPU-compatible scalar extraction within Zygote
        current_log_noise = sum(θ_current) 
        K = build_kernel_matrix(kernel_device, X_device, current_log_noise)
        return -log_marginal_likelihood(K, y_device) # Minimize negative LML
    end

    # Update parameters using the optimizer if gradients are valid.
    if !isnothing(grads[1])
        state, θ_device = Optimisers.update(state, θ_device, grads[1])
        
        # Recalculate LML to track the true value after update.
        current_lml = -log_marginal_likelihood(
            build_kernel_matrix(kernel_device, X_device, sum(θ_device)), 
            y_device
        )

        # Update best parameters if current LML is better.
        if current_lml > best_lml
            best_lml = current_lml
            best_θ_device = copy(θ_device)
        end
    end

    # Print progress every 10 epochs.
    if epoch % 10 == 0
        println("[Epoch $epoch] LML = $(round(best_lml, digits=2))")
    end
end

# Convert the best learned log_noise parameter back to a CPU array for return.
final_log_noise_cpu = Array(best_θ_device)[1] # This is fine as it's after Array conversion

# Return the final kernel structure (from CPU) and the best learned noise.
return (kernel=kernel_cpu, log_noise=final_log_noise_cpu) # Return kernel_cpu

end

--- Example Usage ---

D_feats = 1 # Number of features (e.g., time) N_points = 2000 # Number of data points (can be larger for GPU benefits)

Create synthetic time data (e.g., from 0 to 100)

X_train_time = Float32.(rand(N_points, D_feats) * 100.0f0)

Create synthetic observations with a sine wave, linear trend, and noise

y_train = Float32.(sin.(X_train_time[:,1] * 0.5f0) .+ (X_train_time[:,1] * 0.01f0) .+ 0.5f0 .* randn(Float32, N_points))

println("--- Starting GPU-compatible Gaussian Process Training ---") println("Dataset size: $(N_points) observations, $(D_feats) feature.") println("Using GPU: $(CUDA.has_cuda() ? "Yes" : "No (CUDA not detected or enabled)")")

Call the training function.

trained_params = train_gp_gpu(X_train_time, y_train, max_epochs=200, lr=0.005f0, use_gpu=true)

println("\n--- Training Complete ---") println("Best trained log_noise: $(trained_params.log_noise)") println("Inferred noise variance: $(exp(2 * trained_params.log_noise))") println("Final kernel structure: $(trained_params.kernel)")

ERROR: MethodError: no method matching (::PeriodicTransform{Vector{Float32}})(::SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{…}, Int64}, true}) The object of type PeriodicTransform{Vector{Float32}} exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Closest candidates are: (::PeriodicTransform)(::Real) @ KernelFunctions C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\transform\periodic_transform.jl:28

Stacktrace: [1] macro expansion @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [inlined] [2] _pullback(ctx::Zygote.Context{…}, f::PeriodicTransform{…}, args::SubArray{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:81 [3] (::Zygote.var"#676#680"{Zygote.Context{…}, PeriodicTransform{…}})(args::SubArray{Float32, 1, Matrix{…}, Tuple{…}, true}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:188 [4] iterate @ .\generator.jl:48 [inlined] [5] _collect @ .\array.jl:811 [inlined] [6] collect_similar @ .\array.jl:720 [inlined] [7] map @ .\abstractarray.jl:3371 [inlined] [8] ∇map @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:188 [inlined] [9] adjoint @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\array.jl:214 [inlined] [10] _pullback @ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined] [11] _map @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\transform\transform.jl:21 [inlined] [12] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._map), ::PeriodicTransform{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [13] kernelmatrix @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\transformedkernel.jl:117 [inlined] [14] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::TransformedKernel{…}, ::ColVecs{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [15] _apply(::Function, ::Vararg{Any}) @ Core .\boot.jl:946 [16] adjoint @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined] [17] _pullback @ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined] [18] _sum @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined] [19] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [20] _apply(::Function, ::Vararg{Any}) @ Core .\boot.jl:946 [21] adjoint @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined] [22] _pullback @ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined] [23] _sum @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined] [24] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [25] _apply(::Function, ::Vararg{Any}) @ Core .\boot.jl:946 [26] adjoint @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\lib\lib.jl:199 [inlined] [27] _pullback @ C:\Users\yuchen.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:67 [inlined] [28] _sum @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:46 [inlined] [29] _pullback(::Zygote.Context{…}, ::typeof(KernelFunctions._sum), ::typeof(kernelmatrix), ::Tuple{…}, ::ColVecs{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [30] kernelmatrix @ C:\Users\yuchen.julia\packages\KernelFunctions\A0P7n\src\kernels\kernelsum.jl:57 [inlined] [31] _pullback(::Zygote.Context{…}, ::typeof(kernelmatrix), ::KernelSum{…}, ::ColVecs{…}, ::ColVecs{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [32] build_kernel_matrix @ .\Untitled-1:41 [inlined] [33] _pullback(::Zygote.Context{…}, ::typeof(build_kernel_matrix), ::KernelSum{…}, ::CuArray{…}, ::Float32) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [34] #13 @ .\Untitled-1:120 [inlined] [35] _pullback(ctx::Zygote.Context{…}, f::var"#13#14"{…}, args::CuArray{…}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface2.jl:0 [36] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.DeviceMemory}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:96 [37] pullback @ C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:94 [inlined] [38] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory}) @ Zygote C:\Users\yuchen.julia\packages\Zygote\wfLOG\src\compiler\interface.jl:153 [39] train_gp_gpu(X::Matrix{Float32}, y::Vector{Float32}; max_epochs::Int64, lr::Float32, use_gpu::Bool) @ Main .\Untitled-1:117 [40] top-level scope @ Untitled-1:168 Some type information was truncated. Use show(err) to see complete types.

yuchenxiao95 avatar Jun 04 '25 03:06 yuchenxiao95