Repository: ocannl
Extend the `%cd` syntax to support comments (i.e. a `Block_comment` with string interpolation)
For example, here the outer `Assignments.Block_comment` wrapper breaks the flow of the `%cd` block:
(* Current syntax, as quoted in the issue: one SGD update step for parameter [p].
   Optional arguments and their defaults: [lr] 0.001, [momentum] 0.0,
   [weight_decay] 0.0, [nesterov] false.
   Raises [Tensor.Session_error] when [is_param p] is false.
   Returns the [Assignments.Block_comment] value built at the end. *)
let sgd_one ?(lr = 0.001) ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
if not @@ is_param p then raise @@ Tensor.Session_error ("Train.sgd_one: not a parameter", Some p);
(* Tensor holding the step direction, used in the final update of p. *)
let pg = NTDSL.term ~label:(p.value.label ^ " sgd delta") () in
(* Momentum buffer; only assigned when momentum > 0. *)
let b = NTDSL.term ~label:(p.value.label ^ " sgd momentum") () in
(* The comment text is a tuple component outside the %cd block, which forces this
   wrapper and extra nesting: this is the break in flow the issue wants to remove. *)
Assignments.Block_comment
( desc_label_suffix p.value.label ^ " param sgd step",
[%cd
(* Step presumably set to the gradient plus weight_decay times p (reading =: as assignment). *)
pg =: p.grad + (!.weight_decay *. p);
(* Optional momentum. With nesterov the step presumably becomes step + momentum times b
   (reading =+ as accumulate); otherwise the step is replaced by b. TODO confirm operator semantics. *)
if Float.(momentum > 0.0) then (
b =: (!.momentum *. b) + pg;
if nesterov then pg =+ !.momentum *. b else pg =: b);
(* Presumably updates p in place: p minus lr times the step (reading =- as subtract-assign). *)
p =- !.lr *. pg] )
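It would be nicer to write something like the following (exact syntax to be decided), with the comment written as a statement inside the `%cd` block; presumably `# p "..."` would interpolate the label of `p` in front of the given text: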
(* Proposed syntax from the issue (not valid today; the exact form is still to be decided).
   This is the same update as the current-syntax sgd_one example, but the comment is written
   as the first statement inside the %cd block instead of an outer
   [Assignments.Block_comment] wrapper. *)
let sgd_one ?(lr = 0.001) ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
if not @@ is_param p then raise @@ Tensor.Session_error ("Train.sgd_one: not a parameter", Some p);
let pg = NTDSL.term ~label:(p.value.label ^ " sgd delta") () in
let b = NTDSL.term ~label:(p.value.label ^ " sgd momentum") () in
[%cd
(* Proposed comment statement: presumably builds the block comment from the label of p
   followed by the given text, like the desc_label_suffix form in the current syntax. TODO confirm. *)
# p " param sgd step";
pg =: p.grad + (!.weight_decay *. p);
if Float.(momentum > 0.0) then (
b =: (!.momentum *. b) + pg;
if nesterov then pg =+ !.momentum *. b else pg =: b);
p =- !.lr *. pg]