Transformers.jl

BERT pretrain example not working

[Open] jackn11 opened this issue 2 years ago • 15 comments

I tried running the BERT pretraining example (https://github.com/chengchingwen/Transformers.jl/blob/master/example/BERT/_pretrain/pretrain.jl), but I got the following error from the train!() call at the bottom of the script.

[ Info: loading pretrain bert model: uncased_L-12_H-768_A-12.tfbson wordpiece
[ Info: loading pretrain bert model: uncased_L-12_H-768_A-12.tfbson tokenizer
[ Info: loading pretrain bert model: uncased_L-12_H-768_A-12.tfbson bert_model
[ Info: start training
[ Info: epoch: 1
ERROR: LoadError: GPU compilation of kernel #broadcast_kernel#15(CUDA.CuKernelContext, CUDA.CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}, Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}, Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}, Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}} which is not isbits.
      .x is of type Array{Float32, 4} which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\validation.jl:86
  [2] macro expansion
    @ C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\driver.jl:413 [inlined]
  [3] macro expansion
    @ C:\Users\jackn\.julia\packages\TimerOutputs\jgSVI\src\TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\driver.jl:412 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
    @ CUDA C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:354
  [7] #224
    @ C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:347 [inlined]
  [8] JuliaContext(f::CUDA.var"#224#225"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#15", Tuple{CUDA.CuKernelContext, CUDA.CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}, Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64}}}})
    @ GPUCompiler C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\driver.jl:74
  [9] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:346
 [10] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler C:\Users\jackn\.julia\packages\GPUCompiler\iaKrd\src\cache.jl:90
 [11] cufunction(f::GPUArrays.var"#broadcast_kernel#15", tt::Type{Tuple{CUDA.CuKernelContext, CUDA.CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{4}, NTuple{4, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{CUDA.CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}, Base.Broadcast.Extruded{Array{Float32, 4}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:299
 [12] cufunction
    @ C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:293 [inlined]
 [13] macro expansion
    @ C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\compiler\execution.jl:102 [inlined]
 [14] #launch_heuristic#248
    @ C:\Users\jackn\.julia\packages\CUDA\tTK8Y\src\gpuarrays.jl:17 [inlined]
 [15] _copyto!
    @ C:\Users\jackn\.julia\packages\GPUArrays\EVTem\src\host\broadcast.jl:73 [inlined]
 [16] copyto!
    @ C:\Users\jackn\.julia\packages\GPUArrays\EVTem\src\host\broadcast.jl:56 [inlined]
 [17] copy
    @ C:\Users\jackn\.julia\packages\GPUArrays\EVTem\src\host\broadcast.jl:47 [inlined]
 [18] materialize
    @ .\broadcast.jl:860 [inlined]
 [19] apply_mask(score::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, mask::Array{Float32, 3})
    @ Transformers.Basic C:\Users\jackn\.julia\packages\Transformers\K1F88\src\basic\mh_atten.jl:182
 [20] apply_mask
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\basic\mh_atten.jl:188 [inlined]
 [21] attention(query::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, key::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, value::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, mask::Array{Float32, 3}, future::Bool, dropout::Dropout{Float64, Colon, CUDA.RNG})
    @ Transformers.Basic C:\Users\jackn\.julia\packages\Transformers\K1F88\src\basic\mh_atten.jl:204
 [22] (::Transformers.Basic.MultiheadAttention{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dropout{Float64, Colon, CUDA.RNG}})(query::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, key::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, value::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}; mask::Array{Float32, 3})
    @ Transformers.Basic C:\Users\jackn\.julia\packages\Transformers\K1F88\src\basic\mh_atten.jl:102
 [23] (::Transformer{Transformers.Basic.MultiheadAttention{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dropout{Float64, Colon, CUDA.RNG}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Transformers.Basic.PwFFN{Dense{typeof(gelu), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Dropout{Float64, Colon, CUDA.RNG}})(x::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, mask::Array{Float32, 3})
    @ Transformers.Basic C:\Users\jackn\.julia\packages\Transformers\K1F88\src\basic\transformer.jl:69
 [24] macro expansion
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\stacks\stack.jl:0 [inlined]
 [25] (::Stack{Symbol("((x, m) => x':(x, m)) => 12"), NTuple{12, Transformer{Transformers.Basic.MultiheadAttention{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dropout{Float64, Colon, CUDA.RNG}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Transformers.Basic.PwFFN{Dense{typeof(gelu), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Dropout{Float64, Colon, CUDA.RNG}}}})(::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, ::Array{Float32, 3})
    @ Transformers.Stacks C:\Users\jackn\.julia\packages\Transformers\K1F88\src\stacks\stack.jl:19
 [26] (::Bert{Stack{Symbol("((x, m) => x':(x, m)) => 12"), NTuple{12, Transformer{Transformers.Basic.MultiheadAttention{Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dropout{Float64, Colon, CUDA.RNG}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Transformers.Basic.PwFFN{Dense{typeof(gelu), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, LayerNorm{typeof(identity), Flux.Scale{typeof(identity), CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Float32, 1}, Dropout{Float64, Colon, CUDA.RNG}}}}, Dropout{Float64, Colon, CUDA.RNG}})(x::CUDA.CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, mask::Array{Float32, 3}; all::Bool)
    @ Transformers.BidirectionalEncoder C:\Users\jackn\.julia\packages\Transformers\K1F88\src\bert\bert.jl:55
 [27] Bert
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\bert\bert.jl:50 [inlined]
 [28] loss(data::NamedTuple{(:tok, :segment), Tuple{Matrix{Int64}, Matrix{Int64}}}, ind::Vector{Tuple{Int64, Int64}}, masklabel::Flux.OneHotArray{UInt32, 30522, 1, 2, Vector{UInt32}}, nextlabel::Flux.OneHotArray{UInt32, 2, 1, 2, Vector{UInt32}}, mask::Array{Float32, 3})
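Going by frames [19] to [21] of the trace, apply_mask receives the attention scores as a CuArray but the mask as a plain Array{Float32, 3}, so the broadcast tries to hand CPU memory to a GPU kernel, which is what the "non-bitstype argument" KernelError is complaining about. The snippet below is a minimal sketch of that mismatch and of the usual remedy (putting the mask on the same device as the scores). It assumes Flux.jl's gpu function, which moves data to CUDA when a GPU is usable and otherwise returns it unchanged; it is not a patch to pretrain.jl, where the actual fix may need to happen wherever the mask is built.

```julia
# Minimal sketch of the device mismatch behind the KernelError above.
# Assumes Flux.jl is installed. `gpu` moves arrays to CUDA when a GPU is usable
# and returns them unchanged otherwise, so this also runs on a CPU-only machine.
using Flux

score = gpu(rand(Float32, 8, 8, 2))   # attention scores, on the GPU when one is available
mask  = ones(Float32, 8, 8, 2)        # mask built on the CPU, like mask::Array{Float32, 3} in the trace

# On a GPU, `score .+ mask` is the kind of broadcast that fails in apply_mask:
# a CuArray cannot be combined with a plain Array inside a GPU kernel.
masked = score .+ gpu(mask)           # move the mask to the same device first
```

Because gpu is a no-op without CUDA, the same code path works on both CPU and GPU runs, which is why moving every input through the same device function tends to be less error-prone than calling CuArray on individual arrays.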

jackn11, Jul 04 '22 17:07

Out of curiosity, do you want to pre-train your own BERT model?

chengchingwen, Jul 04 '22 18:07

No, I simply want to fine-tune an existing BERT model for a categorization task. I would appreciate guidance on how to do that, as I have not been able to figure it out with this package.

jackn11, Jul 04 '22 18:07

It should mostly look like this one: https://github.com/chengchingwen/Transformers.jl/blob/master/example/BERT/cola/train.jl

chengchingwen, Jul 04 '22 18:07

Or do you want to do token-level classification?

chengchingwen, Jul 04 '22 19:07

Thank you, I will take a look at that link! My current project is classifying sentences into a few given categories. I am not quite sure what you mean by token-level classification. If you mean classifying parts of sentences as tokens, for example finding nouns, verbs, etc., then I am not trying to implement token-level classification.

jackn11, Jul 04 '22 19:07

If you mean classifying parts of sentences as tokens, for example finding nouns, verbs, etc.,

Yes, that's what I meant.

My current project is classifying sentences into a few given categories.

Then you can just follow the link. CoLA is also a sentence classification dataset.

chengchingwen, Jul 04 '22 19:07

I got the following error when trying to run the CoLA example you recommended. It seems like it may be an issue with the example itself.

PS C:\Users\jackn\Documents\GitHub\GitHub2\Transformers.jl\example\BERT> julia --proj -i main.jl --gpu cola
[ Info: loading pretrain bert model: uncased_L-12_H-768_A-12.tfbson 
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.    
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.7.3 (2022-05-06)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release  
|__/                   |

julia> train!()
[ Info: start training: cola
[ Info: epoch: 1
ERROR: ArgumentError: 'CoLA.zip' exists. `force=true` is required to remove 'CoLA.zip' before moving.
Stacktrace:
  [1] checkfor_mv_cp_cptree(src::String, dst::String, txt::String; force::Bool)
    @ Base.Filesystem .\file.jl:325
  [2] #mv#17
    @ .\file.jl:411 [inlined]
  [3] mv
    @ .\file.jl:411 [inlined]
  [4] (::Transformers.Datasets.GLUE.var"#1#2")(fn::String)
    @ Transformers.Datasets.GLUE C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\glue\cola.jl:11
  [5] #16
    @ C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution_automatic.jl:122 [inlined]
  [6] cd(f::DataDeps.var"#16#17"{Transformers.Datasets.GLUE.var"#1#2", String}, dir::String)
    @ Base.Filesystem .\file.jl:99
  [7] run_post_fetch(post_fetch_method::Transformers.Datasets.GLUE.var"#1#2", fetched_path::String)
    @ DataDeps C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution_automatic.jl:119
  [8] download(datadep::DataDeps.DataDep{String, String, typeof(DataDeps.fetch_default), Transformers.Datasets.GLUE.var"#1#2"}, localdir::String; remotepath::String, i_accept_the_terms_of_use::Nothing, skip_checksum::Bool)
    @ DataDeps C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution_automatic.jl:84
  [9] download
    @ C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution_automatic.jl:70 [inlined]
 [10] handle_missing
    @ C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution_automatic.jl:10 [inlined]
 [11] _resolve
    @ C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution.jl:83 [inlined]
 [12] resolve(datadep::DataDeps.DataDep{String, String, typeof(DataDeps.fetch_default), Transformers.Datasets.GLUE.var"#1#2"}, inner_filepath::String, calling_filepath::String)
    @ DataDeps C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution.jl:29
 [13] resolve(datadep_name::String, inner_filepath::String, calling_filepath::String)
    @ DataDeps C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution.jl:54
 [14] resolve
    @ C:\Users\jackn\.julia\packages\DataDeps\jDkzU\src\resolution.jl:73 [inlined]
 [15] trainfile(#unused#::Transformers.Datasets.GLUE.CoLA)
    @ Transformers.Datasets.GLUE C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\glue\cola.jl:23
 [16] #datafile#2
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\dataset.jl:14 [inlined]
 [17] datafile
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\dataset.jl:14 [inlined]
 [18] dataset(::Type{Train}, ::Transformers.Datasets.GLUE.CoLA; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Transformers.Datasets C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\dataset.jl:9
 [19] dataset
    @ C:\Users\jackn\.julia\packages\Transformers\K1F88\src\datasets\dataset.jl:9 [inlined]
 [20] train!()
    @ Main C:\Users\jackn\Documents\GitHub\GitHub2\Transformers.jl\example\BERT\cola\train.jl:75
 [21] top-level scope
    @ REPL[1]:1

jackn11, Jul 05 '22 15:07

@jackn11 Wow, that is a bug in the dataset download. I had always been using local files, so I didn't notice it.

For a quick workaround:

  1. Download the dataset from https://dl.fbaipublicfiles.com/glue/data/CoLA.zip
  2. Unzip CoLA.zip
  3. Move the extracted directory to ~/.julia/datadeps/
  4. Rename the directory to GLUE-CoLA (so the full path is ~/.julia/datadeps/GLUE-CoLA/)
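For anyone who prefers to script the workaround, here is a minimal sketch of the four steps. It is a hypothetical helper, not part of Transformers.jl or DataDeps.jl, and it assumes that the archive unpacks to a top-level CoLA directory and that ~/.julia is the first entry of DEPOT_PATH (the default setup). It only uses Julia standard libraries: Downloads for the download and the bundled p7zip_jll binary for unzipping.

```julia
# Hypothetical helper that scripts the manual CoLA workaround; not part of Transformers.jl.
using Downloads   # stdlib: Downloads.download(url, path)
using p7zip_jll   # stdlib: the 7-Zip binary bundled with Julia, used here to unzip

const COLA_URL = "https://dl.fbaipublicfiles.com/glue/data/CoLA.zip"

# ~/.julia/datadeps/GLUE-CoLA on a default setup (first depot in DEPOT_PATH)
cola_datadep_dir(depot::AbstractString = first(DEPOT_PATH)) =
    joinpath(depot, "datadeps", "GLUE-CoLA")

function install_cola(; url::AbstractString = COLA_URL,
                        target::AbstractString = cola_datadep_dir())
    isdir(target) && return target          # already in place: skip the download
    parent = dirname(target)
    mkpath(parent)                          # make sure ~/.julia/datadeps exists
    # Work in a temporary directory next to the target so the final mv stays on one filesystem;
    # mktempdir removes the temporary directory (and the zip) when the block finishes.
    mktempdir(parent) do tmp
        zip = Downloads.download(url, joinpath(tmp, "CoLA.zip"))   # step 1: download
        run(`$(p7zip()) x $zip -o$tmp -y`)                          # step 2: unzip (assumed to create tmp/CoLA)
        mv(joinpath(tmp, "CoLA"), target)                           # steps 3 and 4: move and rename
    end
    return target
end
```

Calling install_cola() once before train!() should leave the files where DataDeps looks for the GLUE-CoLA datadep, so the broken post-fetch step is never reached; if the directory already exists, the function returns without downloading anything.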

chengchingwen, Jul 05 '22 15:07

Thank you for letting me know! Will do.

Also, comments in the example code would be a huge plus to make it easier to understand.

jackn11, Jul 05 '22 15:07

Did you ever resolve this issue? I'm having the same problem using the training example Peter linked.

Broever101, Aug 05 '22 17:08

The main issue with the pretrain example is that the code is really old and I haven't found time to update it. I didn't make it a high priority because people seldom need or want to pretrain themselves.

chengchingwen, Aug 05 '22 18:08

If I recall correctly, I was able to get the CoLA example working, but it took going through the code line by line, one piece at a time, to figure out where the errors were coming from. I may have some time in the coming days to update the example, @chengchingwen. Would that be helpful?

jackn11, Aug 05 '22 19:08

The main issue with the pretrain example is that the code is really old and I haven't found time to update it. I didn't make it a high priority because people seldom need or want to pretrain themselves.

Yeah, the pretrain example is old, but I'm having the issue with the fine-tune example: https://github.com/chengchingwen/Transformers.jl/tree/master/example/BERT

Fine-tuning is a bit more common than pre-training. The dataset API is pretty nice, but it needs some comments. Maybe I'll contribute. I got the 1.8M-tweet dataset working with your training script in a few lines. But again, getting it running on the GPU is a bit of a pain.

Broever101, Aug 05 '22 20:08

Feel free to open issues/PRs! Sometimes it's hard to make things clear from my side because I wrote that code and it all looks straightforward to me. Let me know if you need help or explanations.

chengchingwen, Aug 06 '22 11:08

That sounds great! I should have a PR up soon for the CoLA example!

jackn11, Aug 07 '22 01:08