AlphaZero.jl icon indicating copy to clipboard operation
AlphaZero.jl copied to clipboard

Redesign when?

Open fabricerosay opened this issue 1 year ago • 21 comments

I played a little bit with the dev branch of the redesign by the GSoC author. It is a bit slower than AlphaGPU, I think (pure performance-wise), but with a few adjustments to Gumbel it works impressively well: this is with 32 rollouts and 4 actions, 20000 envs on a 3080; after around 6 minutes it catches up with AlphaZero.jl and then gets slightly better on tests. pascal_pons_benchmark_error_rates

fabricerosay avatar Oct 21 '24 17:10 fabricerosay

Yes, @AndrewSpano did some really good work on this. I still haven't had time to finish the redesign but I am not losing hope! You mentioned adjustments to Gumbel: if you changed anything, would you mind submitting a PR?

jonathan-laurent avatar Oct 21 '24 18:10 jonathan-laurent

Essentially you have to use the improved Gumbel policy as the training target (not the one-hot encoding of the sequential-halving move), use that policy to select moves during self-play (again, not the move you get from search), and add diversity at the starting position (when resetting an env, it randomly chooses among all the positions reachable in 0, 1, or 2 plies). My code is messy and the naming is poor; here are essentially the 3 functions to add/change in BatchedMCTS

"""
    gumbel_policy(tree, mcts_config, current_steps, rng::AbstractRNG)

Returns an array of size (num_envs,) containing, for each environment, an action
sampled from the improved root policy π′ = softmax(logits + σ(completedQ))
computed by `get_ipolicy`. Sampling from the improved policy — rather than taking
the single move selected by sequential halving — is what should drive move
selection during self-play. This function should be used after
`gumbel_explore()` has been run.

NOTE(review): `current_steps` is currently unused; it could drive a temperature
annealing schedule (e.g. collapsing τ after a fixed move number) — confirm
whether it should be removed or wired up.
"""
function gumbel_policy(tree, mcts_config, current_steps, rng::AbstractRNG)
    num_actions = Val(n_actions(tree))
    c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    # Pre-draw one uniform sample per environment on the host, then move the
    # batch to the device so the kernel below can do categorical sampling.
    probs = DeviceArray(mcts_config.device)(rand(rng, Float32, batch_size(tree)))
    actions = zeros(Int16, mcts_config.device, batch_size(tree))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        # Temperature is fixed at 1.0f0 (no annealing).
        policy = get_ipolicy(tree, c_scale, c_visit, bid, num_actions, 1.0f0)
        actions[bid] = categorical_sample(policy, probs[bid])
    end

    return actions
end


"""
    get_ipolicy(tree, c_scale, c_visit, bid, num_actions::Val{A}, τ=1.0f0)

Compute the improved root policy for environment `bid`:
π′ = softmax(logits/τ + σ(completedQ)). Invalid actions are assigned a logit of
-Inf32 so they receive zero probability after the softmax. The temperature `τ`
is applied to the network's logit prior only, not to the transformed Q-values.

NOTE(review): `transformed_qvalues` is presumably the σ(completedQ) transform of
the Gumbel MuZero paper (scaled by `c_scale`/`c_visit`) — confirm against its
definition in BatchedMCTS.
"""
function get_ipolicy(tree,c_scale,c_visit, bid, num_actions::Val{A},τ=1.0f0) where {A}
    # Root node is index 1; mask invalid actions before adding the Q-value term.
    logits = SVector{A}(imap(aid -> tree.valid_actions[aid, 1, bid] ? tree.logit_prior[aid, 1, bid]/τ : -Inf32, 1:A))+
    transformed_qvalues(c_scale, c_visit, tree, 1, bid, num_actions)
    return softmax(logits)
end
"""
    get_root_ipolicy(tree, mcts_config, c_scale=0.1f0, c_visit=50)

Returns an array of size (num_actions, num_envs) containing the improved policy
π′ = softmax(logits + σ(completedQ)) at the root node of every environment.
This distribution is what should be used as the training target (instead of the
one-hot encoding of the sequential-halving move). Should be used after
`gumbel_explore()` or `explore()` has been run.

NOTE(review): the defaults `c_scale=0.1f0, c_visit=50` shadow
`mcts_config.value_scale` / `mcts_config.max_visit_init`; confirm this override
is intentional before merging.
"""
function get_root_ipolicy(tree, mcts_config, c_scale=0.1f0, c_visit=50)
    num_actions = Val(n_actions(tree))
    ipolicy = zeros(Float32, mcts_config.device, (n_actions(tree), batch_size(tree)))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        # Fill one column per environment with that env's improved root policy.
        ipolicy[:, bid] .= get_ipolicy(tree, c_scale, c_visit, bid, num_actions)
    end
    return ipolicy
end

And in train:

if config.use_gumbel_mcts
            t = @elapsed tree, gumbel = MCTS.gumbel_explore(mcts_config, envs, mcts_rng)
            times.explore_times[step] = t

            t = @elapsed begin
                actions = MCTS.gumbel_policy(tree, mcts_config,  steps_counter)
                policy= get_root_ipolicy(tree, mcts_config)|>cpu
                policy=[SVector{7}(policy[:,k]) for k in 1:size(policy)[2]]
            end
            times.selection_times[step] = t

How I choose random openings is through resetting the env, but it is very ugly and not generic: what should be done instead is to randomly choose a number of steps and then play that many random moves from the beginning (for Connect 4 I use 2 steps max).

fabricerosay avatar Oct 21 '24 18:10 fabricerosay

Thanks, this is very useful!

jonathan-laurent avatar Oct 22 '24 06:10 jonathan-laurent