Redesign when?
I played a little bit with the dev branch of the GSoC author's redesign. Pure performance-wise it seems a bit slower than AlphaGPU, but with a few adjustments to Gumbel it works impressively well: with 32 rollouts, 4 considered actions, and 20000 environments on a 3080, it catches up to AlphaZero.jl after around 6 minutes and then gets slightly better on the tests.
Yes, @AndrewSpano did some really good work on this. I still haven't had time to finish the redesign but I am not losing hope! You mentioned adjustments to Gumbel: if you changed anything, would you mind submitting a PR?
Essentially, you have to use the improved Gumbel policy as the training target (not the one-hot encoding of the move chosen by sequential halving), sample moves during self-play from that same policy (again, not the move you get from the search), and add diversity at the starting position (when resetting an env, randomly choose among all the positions reachable in 0, 1 or 2 plies). My code is messy and the naming is poor, but here are essentially the 3 functions to add/change in BatchedMCTS:
"""
gumbel_policy(tree, mcts_config, gumbel)
Returns an array of size (num_envs,) containing the resulting actions selected
by the sequential halving procedure with gumbel for each environment. This function should
be used after `gumbel_explore()` has been run.
"""
function gumbel_policy(tree, mcts_config, current_steps, rng::AbstractRNG)
    num_actions = Val(n_actions(tree))
    c_scale, c_visit = mcts_config.value_scale, mcts_config.max_visit_init
    # Pre-draw one uniform sample per environment for categorical sampling.
    probs = DeviceArray(mcts_config.device)(rand(rng, Float32, batch_size(tree)))
    actions = zeros(Int16, mcts_config.device, batch_size(tree))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        # Sample the move from the improved policy with temperature τ = 1.
        # (`current_steps` could be used to lower τ after some move number,
        # e.g. τ = 0.3 past move 30, to make late moves more deterministic.)
        policy = get_ipolicy(tree, c_scale, c_visit, bid, num_actions, 1.0f0)
        actions[bid] = categorical_sample(policy, probs[bid])
    end
    return actions
end
"""
get_root_children_visits(tree, mcts_config)
Returns an array of size (num_actions, num_envs) containing the number of visits for
each action at the root node for each environment. This function should be used after
`gumbel_explore()` or `explore()` has been run.
"""
function get_ipolicy(tree, c_scale, c_visit, bid, num_actions::Val{A}, τ=1.0f0) where {A}
    # Mask out invalid actions and apply the temperature to the prior logits,
    # then add the transformed (completed) Q-values.
    logits = SVector{A}(imap(aid -> tree.valid_actions[aid, 1, bid] ?
                                    tree.logit_prior[aid, 1, bid] / τ : -Inf32, 1:A)) +
             transformed_qvalues(c_scale, c_visit, tree, 1, bid, num_actions)
    return softmax(logits)
end
function get_root_ipolicy(tree, mcts_config, c_scale=0.1f0, c_visit=50)
    # Compute the improved policy π′ = softmax(logits + σ(completedQ)) for each env.
    # (Could instead use mcts_config.value_scale and mcts_config.max_visit_init.)
    num_actions = Val(n_actions(tree))
    ipolicy = zeros(Float32, mcts_config.device, (n_actions(tree), batch_size(tree)))
    Devices.foreach(1:batch_size(tree), mcts_config.device) do bid
        ipolicy[:, bid] .= get_ipolicy(tree, c_scale, c_visit, bid, num_actions)
    end
    return ipolicy
end
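For reference, the computation above boils down to the following self-contained sketch. The names (`improved_policy`, `sigma`) are mine, not the BatchedMCTS API; the σ transform follows the Gumbel MuZero form σ(q) = (c_visit + max_visits) · c_scale · q:

```julia
softmax(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# σ(q) = (c_visit + max_visits) * c_scale * q, as in Gumbel MuZero.
sigma(q, max_visits; c_scale=0.1f0, c_visit=50) = (c_visit + max_visits) * c_scale * q

function improved_policy(logits, qvalues, visits, valid; τ=1.0f0)
    masked = ifelse.(valid, logits ./ τ, -Inf32)  # invalid actions get probability 0
    return softmax(masked .+ sigma.(qvalues, maximum(visits)))
end

# Toy example with 4 actions, one of them invalid:
logits = Float32[1.0, 0.5, -0.2, 0.0]
q      = Float32[0.3, 0.1, 0.0, -0.4]
visits = [10, 5, 1, 0]
valid  = [true, true, true, false]
π′ = improved_policy(logits, q, visits, valid)
```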
And in train (note that `gumbel_policy` now also takes the rng; the hardcoded 7 is the number of Connect Four actions):
if config.use_gumbel_mcts
    t = @elapsed tree, gumbel = MCTS.gumbel_explore(mcts_config, envs, mcts_rng)
    times.explore_times[step] = t
    t = @elapsed begin
        actions = MCTS.gumbel_policy(tree, mcts_config, steps_counter, mcts_rng)
        # The improved policy becomes the training target.
        policy = get_root_ipolicy(tree, mcts_config) |> cpu
        policy = [SVector{7}(policy[:, k]) for k in 1:size(policy, 2)]
    end
    times.selection_times[step] = t
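The `policy` vectors are then used as the cross-entropy target for the network's policy head. A minimal sketch of that loss, in my own formulation (not the repo's exact code):

```julia
# Numerically stable log-softmax over raw policy logits.
logsoftmax(x) = x .- (maximum(x) + log(sum(exp.(x .- maximum(x)))))

# Cross-entropy between the improved-policy target π′ and the network logits.
policy_loss(target, net_logits) = -sum(target .* logsoftmax(net_logits))
```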
As for random openings, I do it by hacking the env reset, which is very ugly and not generic. What should be done instead is to randomly choose a number of steps and then play that many random moves from the initial position (for Connect 4 I use 2 steps max).
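A generic version could look like this sketch, where `reset!`, `legal_actions`, and `play!` are placeholders for whatever environment interface the trainer uses:

```julia
using Random

# Reset the env, then play 0 to max_plies uniformly random legal moves,
# so self-play games start from a diverse set of opening positions.
function random_opening!(env, rng::AbstractRNG; max_plies::Int=2)
    reset!(env)                                # assumed env API
    for _ in 1:rand(rng, 0:max_plies)
        play!(env, rand(rng, legal_actions(env)))  # assumed env API
    end
    return env
end
```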
Thanks, this is very useful!