syft.js
Check dimensions of input tensors before executing plans
Description
// Execute the plan and get updated model params back.
let [loss, acc, ...updatedModelParams] = await job.plans[
  'training_plan'
].execute(
  job.worker,
  tf.zeros([chunkSize, 50 * 120]), // input batch
  tf.zeros([chunkSize, 2]),        // target batch
  chunkSize,                       // batch size
  lr,                              // learning rate
  ...modelParams                   // current model parameters
);
If the input tensors in the snippet above have the wrong dimensions, syft.js throws the following error:
Error: We cannot find function mul in TensorFlow.js, performing a manual lookup.
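This error message is misleading: mul does exist in TensorFlow.js, and the message points at a missing function instead of the real cause, a shape mismatch.
Instead, we could check the dimensions of the input tensors before executing the plan and throw an error that tells the user which input has the wrong shape.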
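A minimal sketch of such a check, assuming the expected shapes are already known (they are hard-coded here for illustration). assertShapes is a hypothetical helper, not part of syft.js; it only relies on the shape array that every tf.Tensor exposes, so plain objects stand in for tensors below.

```javascript
// Hypothetical helper (not part of syft.js): throw a readable error if the
// tensor inputs do not match the expected shapes. Works with any object that
// has a `shape` array, which includes tf.Tensor.
function assertShapes(inputs, expectedShapes) {
  if (inputs.length !== expectedShapes.length) {
    throw new Error(
      `Expected ${expectedShapes.length} tensor inputs, got ${inputs.length}`
    );
  }
  inputs.forEach((tensor, i) => {
    const actual = tensor.shape;
    const expected = expectedShapes[i];
    const matches =
      actual.length === expected.length &&
      actual.every((dim, j) => dim === expected[j]);
    if (!matches) {
      throw new Error(`Input ${i} has shape [${actual}], expected [${expected}]`);
    }
  });
}

// Stand-ins for tf.zeros([chunkSize, 50 * 120]) and tf.zeros([chunkSize, 2]).
const chunkSize = 64;
const data = { shape: [chunkSize, 50 * 120] };
const target = { shape: [chunkSize, 2] };

// Matching shapes: no error.
assertShapes([data, target], [[chunkSize, 50 * 120], [chunkSize, 2]]);

// Hypothetical mismatch: pretend the plan was built for 784 features per sample.
try {
  assertShapes([data, target], [[chunkSize, 784], [chunkSize, 2]]);
} catch (e) {
  console.log(e.message); // Input 0 has shape [64,6000], expected [64,784]
}
```

In the snippet from the issue, a call like this would go right before job.plans['training_plan'].execute(...), so the user sees which input is wrong instead of a lookup error.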
Are you interested in working on this improvement yourself?
- Yes, I am.
Good find @sachin-101! Would you like to work on this and make a PR? If so, please let me know.
Yes @cereallarceny
Can we split this into two issues?
- [fix] Propagate the TFJS error when it occurs during plan execution. Right now it is overwritten by a translation error because we don't distinguish between TFJS errors and Threepio errors.
- [improvement] Check input shapes. Shapes can be obtained from Plan.input_placeholders; each Placeholder should have an expected_shape property. I'm not sure what to do with the batch size dimension. Ideally, PySyft should record it as -1, None, or similar; maybe that's already possible via @func2plan(args_shape=...), which needs checking.
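For the shape check, the expected shapes could come from the plan's input placeholders instead of being hard-coded. The sketch below assumes each placeholder object exposes an expected_shape array (as suggested above; not a confirmed syft.js field), that placeholders line up one to one with the tensor inputs, and that an unknown batch size is recorded as -1 or null. checkAgainstPlaceholders and isWildcard are hypothetical names.

```javascript
// Sketch: validate inputs against placeholder shapes. ASSUMPTIONS (not a
// confirmed syft.js API): each placeholder has an `expected_shape` array,
// placeholders line up one to one with the tensor inputs, and an unknown
// batch size is recorded as -1 or null.
const isWildcard = (dim) => dim === -1 || dim === null || dim === undefined;

function checkAgainstPlaceholders(inputs, placeholders) {
  const problems = [];
  placeholders.forEach((placeholder, i) => {
    const expected = placeholder.expected_shape;
    if (!expected) return; // no recorded shape: nothing to check
    const input = inputs[i];
    if (!input) {
      problems.push(`input ${i} is missing`);
      return;
    }
    const actual = input.shape;
    const ok =
      actual.length === expected.length &&
      expected.every((dim, j) => isWildcard(dim) || dim === actual[j]);
    if (!ok) {
      // Show wildcard dimensions as "?" in the message.
      const shown = expected.map((dim) => (isWildcard(dim) ? "?" : dim));
      problems.push(`input ${i}: got [${actual}], expected [${shown}]`);
    }
  });
  if (problems.length > 0) {
    throw new Error(`Invalid plan inputs: ${problems.join("; ")}`);
  }
}

// Hypothetical placeholders with the batch dimension recorded as -1.
const placeholders = [{ expected_shape: [-1, 6000] }, { expected_shape: [-1, 2] }];

// Any batch size is accepted for the wildcard dimension.
checkAgainstPlaceholders([{ shape: [32, 6000] }, { shape: [32, 2] }], placeholders);

try {
  checkAgainstPlaceholders([{ shape: [32, 6000] }, { shape: [32, 10] }], placeholders);
} catch (e) {
  console.log(e.message); // Invalid plan inputs: input 1: got [32,10], expected [?,2]
}
```

Collecting all mismatches before throwing lets a single error report every wrong input at once, which is friendlier than failing on the first one.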
@vvmnnnkv Go ahead. Shall I modify this issue to contain only the second part, while you open the first one in Threepio?
Done: #171. Note that it's not a Threepio problem.
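The error-propagation fix comes down to keeping "function not found" separate from "function failed". A rough sketch of that idea, where tfLike stands in for the tf namespace, translateOp for a fallback name lookup, and TranslationError for a dedicated lookup error (all hypothetical, not the actual syft.js internals):

```javascript
// Hypothetical error type for "op name could not be resolved" failures.
class TranslationError extends Error {
  constructor(message) {
    super(message);
    this.name = "TranslationError";
  }
}

// Resolve an op by name and run it. Lookup failures raise TranslationError;
// errors thrown by the op itself (e.g. a shape mismatch) propagate unchanged.
function runOp(tfLike, translateOp, name, args) {
  let resolved = name;
  if (typeof tfLike[resolved] !== "function") {
    resolved = translateOp(name); // fallback lookup, may return undefined
    if (resolved === undefined || typeof tfLike[resolved] !== "function") {
      throw new TranslationError(`Cannot find function ${name} in TensorFlow.js`);
    }
  }
  // No try/catch here on purpose: a TensorFlow.js error is not a lookup error.
  return tfLike[resolved](...args);
}

// Usage with a mock tf namespace (real code would pass the tf module).
const fakeTf = {
  mul: (a, b) => {
    if (a.length !== b.length) throw new Error("mock shape mismatch");
    return a.map((x, i) => x * b[i]);
  },
};
const aliases = new Map([["multiply", "mul"]]);
const translateOp = (name) => aliases.get(name);

console.log(runOp(fakeTf, translateOp, "multiply", [[1, 2], [3, 4]])); // [ 3, 8 ]
try {
  runOp(fakeTf, translateOp, "mul", [[1, 2], [3]]);
} catch (e) {
  console.log(e.name, e.message); // Error mock shape mismatch
}
```

Because the op call sits outside any try/catch, a shape error raised by the op reaches the caller as-is, instead of being replaced by a message like the "cannot find function mul" one quoted in the issue.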