tf-fed-demo
tf-fed-demo copied to clipboard
Help Wanted: What the significance of hyperparameter `CLIENT_RATIO_PER_ROUND`
作者大大,你好。你写的这个框架,对我理解 FL 帮助很大,简直太棒了!
然后我不理解,在训练过程中,为什么要随机选取参与者来训练:
-
为什么不把 100 个参与者全都放进来呢,然后减少训练的轮数,不用 360 那么多了
-
CLIENT_RATIO_PER_ROUND
为什么选定为 0.12 呢,是有什么深层的理由嘛
望回复,不胜感激。
@The-MinGo 你好,感谢你的关注,我也只是学生仔不是什么大佬hhh。我按照我的理解回答一下这两个问题:
- “随机选取参与者”是依照Google的文章里的一个假设:并非所有的参与者都会随时在线(available),例如手机作为参与者是有可能关机、断网的
-
0.12
没有什么深层理由,一般是作为一个可调超参,用于模拟不同现实场景(比如手机场景可能是0.12,工业集群场景可能是0.99)
@Zing22 十分感谢回复!
也就是说 FL 其实也是一种在线学习 (这是我没意识到的),然后参与者们的联网状态是不能保证的了,所以用了 CLIENT_RATIO_PER_ROUND
来模拟这种因素。所以能把这篇论文的标题也给我嘛,俺想看看具体的描述。
然后,俺这边是导师新开了一个 FL 的研究方向,我作为萌新啥也不知道的入坑了,是做的啥 FL 下的训练集信息泄漏的相关问题。然后作者君是做的 FL 的哪方面呢,能方便告知吗,我这边真的是筚路蓝缕一七三零了
@The-MinGo 具体在哪篇文章我也记不清了,可能是“Federated Optimization: Distributed Optimization Beyond the Datacenter”,也可能是在这篇博客里:http://ai.googleblog.com/2017/04/federated-learning-collaborative.html 我现在做的方向是联邦传输过程的压缩、训练模式,也还在头秃中...
@Zing22 感谢答复 博客里好像没有提到,估计在论文里了。 俺这有其他老师也做的这个 FL 里的传输成本的问题,还有对参与者进行聚类的。 啥机器学习的基础都还没有就给我上这个了,太南了,只有共勉了
@Zing22 你好,我又来了 (ಥ _ ಥ)
在文件 tf-fed-demo/src/Client.py
的第 49 行:
dataset.size // batch_size
已经保留整数位从而取整了,然后 math.ceil()
向上取整的意图就实现不了了。
好像参与者的数据集就不能完全参与训练了,这里好像只有 480 条数据被用到了,还有 20 条没覆盖到。
不知道是不是这么理解的
@The-MinGo 你应该是对的,是一个bug,过几天我改一改再上传。感谢!
@Zing22 你好,我又发现了几处我认为有问题的地方,想与作者君讨论一下,无意冒犯
- 在文件
tf-fed-demo/master/src/Dataset.py
中:
如前面提到的那个已经确定的 BUG,训练批次的大小定为 32,但是存在数据条数无法与之整除的情况,也就是存在某次分批训练的数据量不足 32,因此这句断言 assert len(perm0) == batch_size
会导致程序终止
- 在文件
/tf-fed-demo/master/src/Server.py
中:
这里使用 for ... zip()
循环的意图是累加模型参数,但是在这个循环中,采用的是值传递,对 cv
的操作并不会影响 client_vars_sum
,
就像 Java 中的 foreach 循环只能用于遍历而不能修改,测试后发现这里的 client_vars_sum
始终都等于初始化时的 current_client_vars
- 在文件
tf-fed-demo/master/src/Client.py
中也存在这样的情况:
不知道是否是这样,期待答复
@The-MinGo 你好,还是非常感谢你的关注。我分点回答一下这几个问题:
-
我目测不足
batch_size
的时候perm0
的值应该是18行给定的,长度应该会和batch_size
一致。不过我还没测试过不确定对不对。 -
client_vars_sum
里的元素是np.array
,是mutable object。而你举的例子是int
,是immutable object。我理解的最小样例是这样的:
- 同上,而且这个
load()
是写到显存里的,应该没有问题
再次感谢你的提问,我最近会找个时间检查一遍代码的。谢谢!
@Zing22 非常感谢答复!
关于第 2 条和 第 3 条,是我没理解 mutable object 和 immutable object 的指针传递及值传递,测试的时候也是简单的用 ==
来比较 client_vars_sum
与最初的 current_client_vars
是否相等,而由于是 immutable 对象,显然会相等的,所以做出了误判
╯︿╰
关于第 1 点,我现在理解你的意思了,尽管数据总数除之 batch_size 有余,但可以将上一个批次的部分数据取过来补足 (这部分数据参与了 2 次训练),使最后一个批次的大小也为 batch_size
你都是对的,我闹了个大乌龙 (ಥ _ ಥ)
@Zing22 你好,关于 CLIENT_RATIO_PER_ROUND
, 我还有一个问题想请教一下
已知这个变量存在是为了模拟参与者们的联网情况。
因为是单机模拟联邦学习过程,所有参与者共享同一个计算图,不同的参与者可以视为不同的"训练批次“。每轮中按比例随机选取若干参与者参与训练,而不是全部参与训练的话,这会是导致每轮之间的预测准确度来回波动的原因之一吗?
虽然总体来说准确度是在逐渐上升的。
@The-MinGo 是这个原因,更深层的原因可能是每个设备上的数据分布不同,专业名词是non-IID,也是联邦研究里的一个核心问题。