openfl
Update TensorFlow Task Runner and related workspaces
Related Issue: #973
Summary: This PR updates the TensorFlow Task Runner to use Keras as the high-level API, in line with current best practices, and updates the existing TF workspaces accordingly. This enables the use of non-legacy optimizers (the legacy optimizers will be deprecated in future versions of TF/Keras).
Specifically, this PR:
- Creates a new `TensorFlowTaskRunner` class in `openfl.federated.task.runner_tf`, which borrows heavily from the `KerasTaskRunner` task. The major difference is in handling the optimizer weights, which was necessitated by the removal of the `.get_weights()` and `.weights()` attributes from the optimizer. The new `TensorFlowTaskRunner` extracts the weights from the `.variables()` attribute instead.
  - Also updated the train and validation task names to `train_validation` and `task_validation` to be consistent with the torch taskrunner
- Archived the old `TensorFlowTaskRunner` as `TensorFlowTaskRunner_v1` within `openfl.federated.task.runner_tf` and updated the `__init__` files to make it callable. The rationale is to avoid breaking changes for tutorials or upstream applications that still rely on the low-level TF taskrunner. This class can be removed entirely in a future release as needed.
  - Also updated the train and validation task names to `train_validation` and `task_validation` to be consistent with the torch taskrunner
- Created a new `tf_cnn_mnist` workspace and updated the `torch_cnn_histology` workspace to run on the new `TensorFlowTaskRunner` using the `src/dataloader.py` and `src/taskrunner.py` convention.
  - Updated to TensorFlow v2.15.1 (the latest TensorFlow release that does not use Keras v3.x by default)
- Minor updates to `tf_3dunet_brats` to use the new `TensorFlowTaskRunner` (I did not make changes to the `src` files because I did not have the Brats3D dataset to verify a large update)
- Minor updates to `tf_2dunet` to run on the archived `TensorFlowTaskRunner_v1`
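The optimizer-weight change described in the first item above (reading state from `.variables()` instead of `.get_weights()`) can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `runner_tf`: the helper names `get_optimizer_weights` and `set_optimizer_weights` are hypothetical, keying the snapshot by variable name is an assumption, and the only things relied on are that each variable exposes `.name`, `.numpy()` and `.assign()`, as `tf.Variable` does.

```python
from typing import Any, Dict, List


def _optimizer_variables(optimizer: Any) -> List[Any]:
    """Return the optimizer's state variables as a plain list.

    Assumption: non-legacy Keras optimizers expose their state through
    `variables`, which may behave as a list or as a method depending on
    the version, so both forms are accepted here.
    """
    variables = optimizer.variables
    if callable(variables):
        variables = variables()
    return list(variables)


def get_optimizer_weights(optimizer: Any) -> Dict[str, Any]:
    """Snapshot optimizer state as {variable name: NumPy array}.

    Hypothetical helper. Raises if two variables share a name, because
    name-keyed snapshots would otherwise silently drop state.
    """
    weights: Dict[str, Any] = {}
    for var in _optimizer_variables(optimizer):
        if var.name in weights:
            raise ValueError(f"Duplicate optimizer variable name: {var.name}")
        weights[var.name] = var.numpy()
    return weights


def set_optimizer_weights(optimizer: Any, weights: Dict[str, Any]) -> None:
    """Load a snapshot from get_optimizer_weights back into the optimizer.

    Hypothetical helper. Every current optimizer variable must have an
    entry in `weights`; values are written in place with `.assign()`.
    """
    for var in _optimizer_variables(optimizer):
        if var.name not in weights:
            raise KeyError(f"Missing weights for optimizer variable: {var.name}")
        var.assign(weights[var.name])
```

With TensorFlow 2.15 and Keras 2, `optimizer` would typically be something like `model.optimizer`. Slot variables such as Adam's moment estimates are usually created only once the optimizer has been built or has applied gradients, so a snapshot taken before the first training step may contain little or no state.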
Future work
- Consolidation step still needed:
  - Migrate all `tf_2d_unet` from `TensorFlowTaskRunner_v1` to the new `TensorFlowTaskRunner`
  - Migrate all Keras workspaces from `KerasTaskRunner` to the new `TensorFlowTaskRunner` and remove/archive `KerasTaskRunner`
- Look into updating `TensorFlowTaskRunner` to run on TF v2.16+ with Keras 3.x (this may need some large changes to weight handling that will likely not be backwards compatible)
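Until that work lands, the Keras 3 boundary mentioned above could be detected before choosing a weight-handling path. A minimal sketch, assuming a hypothetical helper `defaults_to_keras3` that receives the `tf.__version__` string; it only encodes the fact that TensorFlow 2.16 and later bundle Keras 3 as the default `tf.keras`, and it ignores the `TF_USE_LEGACY_KERAS` opt-out.

```python
import re


def defaults_to_keras3(tf_version: str) -> bool:
    """Return True if this TensorFlow release uses Keras 3 by default.

    Hypothetical helper: TF 2.16+ defaults to Keras 3, while 2.15.x and
    earlier default to Keras 2. Suffixes such as "rc0" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", tf_version)
    if match is None:
        raise ValueError(f"Unrecognized TensorFlow version: {tf_version!r}")
    # Compare as integers: as strings, "2.9" would sort after "2.16".
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 16)
```

For example, `defaults_to_keras3("2.15.1")` is `False` for the version pinned in this PR, while `defaults_to_keras3("2.16.1")` is `True`.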