
The 'NAG' submode in sgd_cm.m

Open hxyokokok opened this issue 4 years ago • 3 comments

Hi @hiroyuki-kasai ,

Thanks for publishing this project. It looks great, and I would like to do some research with this toolbox.

My trouble is with sgd_cm.m. It appears to contain two momentum schemes, the classic one ('CM') and Nesterov's ('NAG'), but in the current implementation they seem to differ only in the setting of the momentum coefficient. See lines 78, 80, and 82.

My impression is that NAG should 'look one step ahead' before the gradient calculation, but in the code the gradient is evaluated at the current point. This seems inconsistent with the original paper; see equations 3 and 4 in Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton, "On the importance of initialization and momentum in deep learning," ICML, 2013.
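
For reference, here is my reading of the update rules in that paper (theta is the parameter vector, mu the momentum coefficient, epsilon the learning rate):

CM (eqs. 1 and 2):
    v_{t+1} = mu * v_t - epsilon * grad f(theta_t)
    theta_{t+1} = theta_t + v_{t+1}

NAG (eqs. 3 and 4):
    v_{t+1} = mu * v_t - epsilon * grad f(theta_t + mu * v_t)
    theta_{t+1} = theta_t + v_{t+1}

The only difference is the point at which the gradient is evaluated, and that look-ahead is exactly what seems to be missing in the current sgd_cm.m.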

Thank you!

hxyokokok avatar May 08 '20 08:05 hxyokokok

Dear Xiaoyu He,

Thank you very much for your email. If possible, please share the corrected code with me, and let me check it out.

Best,

Hiro

hiroyuki-kasai avatar May 10 '20 00:05 hiroyuki-kasai

Dear Hiroyuki KASAI,

Thanks for your reply.

Please see the code below.

function [w, infos] = sgd_hxy(problem, in_opts)
    % set dimensions and samples
    d = problem.dim();
    n = problem.samples();  

    % set local options 
    local_opts.sub_mode = 'Nesterov';  
    local_opts.mu = 0.99;
    local_opts.epsilon = 1e-4;
    local_opts.mu_max = 0.99;
    
    % merge options
    opts = mergeOptions(get_default_options(d), local_opts);   
    opts = mergeOptions(opts, in_opts);  
    
    % counters
    iters = 0; % index of mini-batch processing
    epoch = 0; % index of epochs
    grad_calc_count = 0; % number of gradient evaluation

    w = opts.w_init; % initial variable
    v = zeros(size(w));
    % store first infos
    clear infos;    
    [infos, f_val, optgap] = store_infos(problem, w, opts, [], epoch, grad_calc_count, 0);
    
    % display infos
    if opts.verbose > 0
        fprintf('SGD: Epoch = %03d, cost = %.16e, optgap = %.4e\n', epoch, f_val, optgap);
    end    

    % set start time
    start_time = tic();

    % main loop
    while (optgap > opts.tol_optgap) && (epoch < opts.max_epoch)

        % re-permute in each epoch
        if opts.permute_on
            perm_idx = randperm(n);
        else
            perm_idx = 1:n;
        end

        for j = 1 : floor(n / opts.batch_size)

            % mini-batch
            indice_j = (j-1) * opts.batch_size + (1:opts.batch_size);
            indice_j = perm_idx(indice_j);

            grad_calc_count = grad_calc_count + opts.batch_size;        

            if strcmp(opts.sub_mode, 'none') % standard SGD
                grad =  problem.grad(w, indice_j); % evaluate at current step
                ss = opts.stepsizefun(iters, opts);
                v = - ss * grad;
            elseif strcmp(opts.sub_mode, 'classic')
                grad =  problem.grad(w, indice_j); % evaluate at current step
                v = opts.mu * v - opts.epsilon * grad;
            elseif strcmp(opts.sub_mode, 'Nesterov')
                % momentum schedule suggested by Sutskever et al. (2013)
                mu_ = min(1 - 2^(-1 - log2(floor(iters/250) + 1)), opts.mu_max);
                grad = problem.grad(w + mu_ * v, indice_j); % evaluate at the look-ahead point
                v = mu_ * v - opts.epsilon * grad; % and then correct
            else
                error('Unknown sub_mode: %s', opts.sub_mode);
            end


            % descent
            w = w + v;
            
            % % proximal operator
            % if ismethod(problem, 'prox')
            %     w = problem.prox(w, ss);
            % end  
            
            iters = iters + 1;
        end
        
        % measure elapsed time
        elapsed_time = toc(start_time);
        
        % count epochs
        epoch = epoch + 1;

        % store infos
        [infos, f_val, optgap] = store_infos(problem, w, opts, infos, epoch, grad_calc_count, elapsed_time);        

        % display infos
        if opts.verbose > 0
            fprintf('SGD: Epoch = %03d, cost = %.16e, optgap = %.4e\n', epoch, f_val, optgap);
        end

    end
    
    if optgap <= opts.tol_optgap
        fprintf('Optimality gap tolerance reached: tol_optgap = %g\n', opts.tol_optgap);
    elseif epoch == opts.max_epoch
        fprintf('Max epoch reached: max_epoch = %g\n', opts.max_epoch);
    end
    
end
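
It can be called in the same way as the other solvers in the library. Below is a minimal usage sketch; the logistic_regression problem construction is only an illustration (it is one of the problem definitions shipped with SGDLibrary, but please check its exact signature in your copy):

    % build a problem instance (illustrative only)
    % problem = logistic_regression(x_train, y_train, x_test, y_test, lambda);

    options.w_init = zeros(problem.dim(), 1); % start from the origin
    options.batch_size = 10;                  % mini-batch size
    options.max_epoch = 100;                  % epoch budget
    options.sub_mode = 'Nesterov';            % use the look-ahead variant
    options.verbose = 1;                      % print per-epoch progress
    [w, infos] = sgd_hxy(problem, options);

All the option fields above are the ones read inside the function; anything omitted falls back to get_default_options.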

hxyokokok avatar May 10 '20 08:05 hxyokokok

Dear Xiaoyu He,

I appreciate your support. Let me check your code. It may take some time, though, due to the preparation of my online lectures at my university during the current crisis.

Best regards,

Hiro

hiroyuki-kasai avatar May 12 '20 07:05 hiroyuki-kasai