
Check dimensions of input tensors before executing plans

Open • sachin-101 opened this issue 4 years ago • 5 comments

Description

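import * as tf from '@tensorflow/tfjs';

// Execute the plan and get updated model params back.
// Assumes job, chunkSize, lr and modelParams are already defined
// and that this runs inside an async function.
let [loss, acc, ...updatedModelParams] = await job.plans['training_plan'].execute(
    job.worker,
    tf.zeros([chunkSize, 50 * 120]), // input batch: chunkSize rows of 50 * 120 values
    tf.zeros([chunkSize, 2]),        // label batch: chunkSize rows of 2 values
    chunkSize,
    lr,
    ...modelParams
);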

If the dimensions of the input tensors in the snippet above are wrong, syft.js throws "Error: We cannot find function mul in TensorFlow.js, performing a manual lookup." This message is misleading: it suggests that a TensorFlow.js function is missing instead of telling the user that the input shapes are wrong.

Instead, syft.js could check the dimensions of the input tensors before executing the plan and throw an error telling the user which input has the wrong dimensions.
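A minimal sketch of such a check. `checkInputShapes` and `formatShape` are hypothetical helpers, not part of the syft.js API; the sketch assumes the caller passes the expected shapes and the tensor arguments as parallel arrays, and it only reads each input's `shape` array (which `tf.Tensor` exposes). A dimension written as -1 or null is treated as "any size", one of the conventions suggested later in this thread for the batch dimension.

```javascript
// Hypothetical helper, not part of the syft.js API.
// Renders a shape for error messages; -1 or null (wildcard) prints as "?".
const formatShape = (shape) =>
  `[${shape.map((d) => (d === -1 || d === null ? '?' : d)).join(', ')}]`;

// Throws an Error if any input's shape does not match its expected shape.
// expectedShapes[i] describes inputs[i]; -1 or null matches any size.
function checkInputShapes(expectedShapes, inputs) {
  if (expectedShapes.length !== inputs.length) {
    throw new Error(
      `Plan expects ${expectedShapes.length} tensor input(s) but received ${inputs.length}.`
    );
  }
  expectedShapes.forEach((expected, i) => {
    const actual = inputs[i].shape;
    // Rank must match, and every fixed dimension must match exactly.
    const matches =
      expected.length === actual.length &&
      expected.every((dim, d) => dim === -1 || dim === null || dim === actual[d]);
    if (!matches) {
      throw new Error(
        `Input ${i} has shape ${formatShape(actual)} but the plan expects ${formatShape(expected)}.`
      );
    }
  });
}
```

For the two tensors passed to `execute` above, expected shapes of `[[-1, 50 * 120], [-1, 2]]` would accept any `chunkSize`, while an input batch with 600 columns instead of 6000 would fail with "Input 0 has shape [32, 600] but the plan expects [?, 6000]." (for a batch of 32), which names the offending input directly.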

Are you interested in working on this improvement yourself?

  • Yes, I am.

sachin-101 • Jun 19 '20 07:06

Good find @sachin-101! Would you like to work on this and make a PR? If so, please let me know.

cereallarceny • Jun 19 '20 10:06

Yes, @cereallarceny.

sachin-101 • Jun 22 '20 04:06

Can we split this into two issues?

  1. [fix] Propagate the TFJS error when one occurs during plan execution. Right now it is overwritten by a translation error because we don't distinguish between TFJS errors and Threepio errors.
  2. [improvement] Check input shapes. Shapes can be obtained from Plan.input_placeholders; each Placeholder should have an expected_shape property. I'm not sure what to do with the batch-size dimension. Ideally, PySyft should record it as -1, None, or similar; this may already be possible via @func2plan(args_shape=...), but that needs checking.
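For item 1, a minimal sketch of how an op runner could keep the two kinds of failure apart. `TranslationError` and `runOp` are hypothetical names, not syft.js code, and `backend` stands in for the TensorFlow.js namespace: only a failed lookup is reported as a translation error, while an error thrown by the op itself (such as a shape mismatch inside `mul`) reaches the caller unchanged.

```javascript
// Hypothetical error type for "this op has no TensorFlow.js equivalent".
class TranslationError extends Error {
  constructor(opName) {
    super(`We cannot find function ${opName} in TensorFlow.js.`);
    this.name = 'TranslationError';
  }
}

// Hypothetical op runner. The lookup and the call are separate steps,
// so a missing function and a failing function raise different errors.
function runOp(backend, opName, args) {
  const fn = backend[opName];
  if (typeof fn !== 'function') {
    // Only a failed lookup is reported as a translation problem.
    throw new TranslationError(opName);
  }
  // Any error thrown by the op itself propagates to the caller as-is.
  return fn.apply(backend, args);
}
```

Keeping the lookup and the call apart is what lets the caller tell which failure occurred; a single try/catch that wraps both and always reports a lookup failure would reproduce the misleading message described in this issue.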

vvmnnnkv • Jun 22 '20 05:06

@vvmnnnkv Go ahead. Shall I modify this issue to contain only the second part, while you open the first one in threepio?

sachin-101 • Jun 22 '20 11:06

Done: #171. Note that it's not a threepio problem.

vvmnnnkv • Jun 23 '20 05:06