federated
Using Functional Model to build an iterative process
Describe the bug
I am trying to use a Functional Model in an averaging process but I can't make it work.
I have also tried to use tff.learning.models.model_from_functional and tff.learning.algorithms.build_unweighted_fed_avg but still failed.
For more details you can take a look at this notebook: https://github.com/teo-milea/PySyft/blob/pysytff/notebooks/PySyTFF/functional_bug.ipynb
Environment (please complete the following information):
- OS Platform and Distribution: Ubuntu 20.04
- Python package versions: TensorFlow Federated = 0.27.0, TensorFlow = 2.9.1
- Python version: 3.9.7
- Bazel version (if building from source): -
- CUDA/cuDNN version: -
- What TensorFlow Federated execution stack are you using? -
Expected behavior
I would like to create an iterative process with a model built from a functional model (either the functional model itself or a transformation of it) in any way possible.
Additional context
Thanks for the details, they're really helpful.
What you're running into is the fact that things like build_federated_averaging_process require a tff.learning.Model to work. We have a wrapper to do this conversion for you. All you'll need to make your colab work (which I verified at TFF's latest version) is something like the following:
...
functional_model = tff.learning.functional_model_from_keras(...)
def tff_model_fn() -> tff.learning.Model:
  return tff.learning.models.model_from_functional(functional_model)
This can be plugged into algorithms that TFF provides.
Side note: tff.learning.build_federated_averaging_process has been removed in TFF's newest version, and is replaced by tff.learning.algorithms.build_weighted_fed_avg.
For clarification: You would do something like
tff.learning.algorithms.build_weighted_fed_avg(
    tff_model_fn,
    ...)
after my code snippet above.
Hi, thanks for your help, this indeed solves my problem. I tried to use tff.learning.functional_model_from_keras but instead of wrapping the call itself, I wrapped the result with a lambda. Could you clarify why this works while the lambda with the result fails? Is the function called multiple times, requiring a different instance of tff.learning.Model each time? As an experiment, I replaced your function with
tff.learning.algorithms.build_weighted_fed_avg(
    lambda: tff.learning.models.model_from_functional(functional_model),
    ...)
and it indeed worked.
The tff.learning.algorithms API generally only accepts functions that produce a tff.learning.Model, while functional_model_from_keras produces a FunctionalModel (which does not inherit from tff.learning.Model).
The utility of the functional model for now is that it enables a potentially easier way to express model logic, and that it can avoid relying on tf.Variable (and therefore can be serialized/deserialized/what have you). This API is still evolving though, so contributions are welcome.
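To illustrate the point about avoiding variables, here is a minimal sketch in plain Python. It is not the TFF FunctionalModel API; the names init_weights, predict, and sgd_step are hypothetical and exist only for this example. The idea it shows is that when weights are plain data passed explicitly to pure functions, they can be serialized and restored without any special handling.

```python
import json

# Hypothetical toy example: plain Python, not the TFF FunctionalModel API.


def init_weights():
    """Returns initial weights as plain data (no variable objects)."""
    return {"w": [0.5, -1.0], "b": 0.25}


def predict(weights, x):
    """Pure function: the output depends only on the arguments."""
    return sum(wi * xi for wi, xi in zip(weights["w"], x)) + weights["b"]


def sgd_step(weights, x, y, lr=0.1):
    """Returns new weights for a squared-error loss; the input is not mutated."""
    err = predict(weights, x) - y
    return {
        "w": [wi - lr * 2 * err * xi for wi, xi in zip(weights["w"], x)],
        "b": weights["b"] - lr * 2 * err,
    }


weights = init_weights()
restored = json.loads(json.dumps(weights))  # round-trip through a string
updated = sgd_step(restored, x=[1.0, 2.0], y=0.0)
```

Because nothing here holds hidden state, the restored weights behave exactly like the originals, and the training step can run on either copy.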
Thank you for your answer! One thing that still bugs me is why this doesn't work:
while this works:
This is a good question, and I think an instance where a more descriptive error could be thrown. Basically, in the first approach you've already created the tff.learning.Model (including associated tf.Variables). However, TFF essentially serializes model update logic (similar to graph mode in TF), and can't use variables created outside the scope of the logic. This is why the second approach works; it provides a function to create the model, while avoiding creating tf.Variables outside of certain scopes.
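The scoping rule can be sketched with a toy analogy in plain Python. This is not how TFF is implemented, and every name here (tracing_scope, ToyVariable, ToyModel, build_toy_process) is hypothetical. It only shows why a lambda that returns an already-built model fails, while a function that builds the model when called succeeds.

```python
import contextlib

# Toy stand-ins for illustration only; none of these names are TFF APIs.
_active_scopes = []


@contextlib.contextmanager
def tracing_scope(name):
    """Marks a region in which variables may be created."""
    _active_scopes.append(name)
    try:
        yield
    finally:
        _active_scopes.pop()


class ToyVariable:
    def __init__(self, value):
        self.value = value
        # Remember which scope (if any) was active at creation time.
        self.scope = _active_scopes[-1] if _active_scopes else None


class ToyModel:
    def __init__(self):
        self.weight = ToyVariable(0.0)


def build_toy_process(model_fn):
    """Calls model_fn inside its own scope and rejects variables made elsewhere."""
    if not callable(model_fn):
        raise TypeError("model_fn must be a no-arg callable that builds a model")
    with tracing_scope("client_update"):
        model = model_fn()
        if model.weight.scope != "client_update":
            raise ValueError("model variables were created outside the tracing scope")
    return model


# Works: the model (and its variable) is constructed inside the scope.
ok_model = build_toy_process(lambda: ToyModel())

# Fails: the model was constructed eagerly, before the scope existed.
eager_model = ToyModel()
try:
    build_toy_process(lambda: eager_model)
    eager_result = "accepted"
except ValueError as err:
    eager_result = f"rejected: {err}"
```

In both cases the argument is a callable, so the difference is only where the variables get created, which is the distinction described above.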
I'll keep this open until we've improved the error message here.
Please, how can I resolve this problem?
That API symbol doesn't exist. Please use tff.learning.algorithms.build_weighted_fed_avg instead. We encourage you to check out the tutorials and documentation, which are all up to date with the latest TFF version.
Hello, thanks for your help, this solved my original problem, but now I have this problem:
@ayouch33 Please file a separate GitHub bug for any bugs you believe are caused by TFF. However, we encourage you to do some investigation first to make sure that this is relevant to TFF. In the case above, it looks like the error is exactly as the traceback says: you haven't defined model_fn globally anywhere (your first screenshot has some suspicious whitespace shifting that I suspect is problematic).
Thanks a lot for your assistance. I really appreciate it.