文章摘要 | 手机、可穿戴设备等终端设备每天产生海量数据,但这些数据往往涉及敏感隐私而不能直接公开并使用.为解决隐私保护下的机器学习问题,联邦学习应运而生,旨在通过构建协同训练机制,在不共享客户端数据条件下,训练高性能全局模型.然而,在实际应用中,现有联邦学习机制面临两大不足:(1)全局模型需考虑多个客户端的数据,但各客户端往往仅包含部分类别数据且类别间数据量严重不均衡,使得全局模型难以训练;(2)各客户端之间的数据分布往往存在较大差异,导致各客户端模型往往差异较大,使得传统通过模型参数加权平均以获得全局模型的方法难以奏效.为降低客户端类别不均衡和数据分布差异的影响,本文提出一种基于数据生成的类别均衡联邦学习(Class-BalancedFederatedLearning,CBFL)方法.CBFL旨在通过数据生成技术,针对各客户端构造符合全局模型学习的类别均衡数据集.为此,CBFL设计了一个包含类别均衡采样器和数据生成器的类别分布均衡器.其中,类别均衡采样器对客户端数据量不足的类别以较高概率进行采样.然后,数据生成器则根据所采样的类别生成相应的虚拟数据以均衡客户端数据的类别分布并用于后续的模型训练.为验证所提出方法的有效性,本文在四个标准数据集上进行了大量实验.实验表明,本文方法可大幅提升联邦学习性能:如在CIFAR-100数据集上,CBFL训练的ResNet20模型与现有方法相比,分类准确率提高了5.82%. |