楼主: CDA网校
62 0

联邦学习(第一部分):数据本地化训练模型的基础 [推广有奖]

管理员

已卖:189份资源

泰斗

4%

还不是VIP/贵宾

-

威望
3
论坛币
120347 个
通用积分
11064.3488
学术水平
278 点
热心指数
286 点
信用等级
253 点
经验
228990 点
帖子
6981
精华
19
在线时间
4386 小时
注册时间
2019-9-13
最后登录
2026-1-20

初级热心勋章

楼主
CDA网校 学生认证  发表于 昨天 15:39 |AI写论文

+2 论坛币
k人 参与回答

经管之家送您一份

应届毕业生专属福利!

求职就业群
赵安豆老师微信:zhaoandou666

经管之家联合CDA

送您一个全额奖学金名额~ !

感谢您参与论坛问题回答

经管之家送您两个论坛币!

+2 论坛币

我第一次接触联邦学习(FL)的概念,是通过谷歌2019年发布的一则漫画。这幅漫画构思巧妙,清晰地阐释了产品如何在不将用户数据上传至云端的前提下实现优化。最近,我希望更深入地探究这一领域的技术细节。训练数据已成为至关重要的资源——构建优质模型离不开它,但大量数据因分散、非结构化或被禁锢在数据孤岛中而未被利用。

在探索过程中,我发现Flower框架是入门联邦学习最简洁、对新手最友好的工具。它开源免费、文档清晰,且拥有活跃热心的社区支持,这也是我重新燃起对该领域兴趣的原因之一。

本文是系列文章的第一部分,将深入探讨联邦学习:包括其定义、实现方式、面临的开放性问题,以及为何在隐私敏感场景中不可或缺。后续文章中,我将结合Flower框架展开实操教学,探讨联邦学习中的隐私保护问题,并分析这些理念在更复杂场景中的延伸应用。

何时不适合使用集中式机器学习

我们知道,人工智能模型依赖海量数据,但多数高价值数据具有敏感性、分布式且难以获取的特点——比如医院、手机、汽车、传感器及其他边缘设备中存储的数据。隐私顾虑、本地法规限制、存储资源不足和网络条件约束,使得将这些数据迁移至中心节点变得异常困难,甚至完全不可行。最终,大量宝贵数据被闲置。在医疗领域,这一问题尤为突出:医院每年产生数十PB的数据,但研究估计,其中高达97%的数据从未被利用。

传统机器学习的前提是,所有训练数据可集中存储于单一位置(通常是中心服务器或数据中心)。这种模式在数据可自由迁移时可行,但在数据涉及隐私保护或受监管限制时便会失效。实际上,集中式训练还依赖稳定的网络连接、充足的带宽和低延迟,这些条件在分布式或边缘环境中难以保障。

面对这种情况,通常有两种选择:一是完全放弃使用这些数据,让有价值的信息继续被禁锢在数据孤岛中;

二是让每个本地主体基于自身数据训练模型,仅共享模型学到的知识,原始数据始终留在本地。第二种选择正是联邦学习的核心思想——它允许模型在不迁移数据的前提下,从分布式数据中学习。谷歌安卓系统的Gboard输入法就是典型案例,其下一词预测、智能补全等功能,便是在数亿台设备上通过联邦学习实现的。

联邦学习:让模型“走向”数据

联邦学习可理解为一种协作式机器学习架构,无需将数据集中存储,即可完成模型训练。在深入其底层原理前,我们先通过几个真实案例,看看这种方法在医疗、安全等高危敏感场景中的价值。

医疗健康领域

在医疗领域,联邦学习通过Curial AI系统实现了早期新冠筛查。该系统基于多家英国国民保健署(NHS)医院的常规生命体征和血液检测数据训练而成。由于患者数据无法跨医院共享,模型训练在各医院本地完成,仅交换模型更新参数。最终得到的全局模型,其泛化能力优于单家医院训练的模型,尤其在未参与训练的医院场景中表现更出色。

医学影像领域

配图说明:《自然》期刊发表的视网膜基础模型研究,展示了如何基于敏感眼疾影像数据训练大规模医学影像模型 | 开放获取
配图说明:《自然》期刊发表的视网膜基础模型研究,展示了如何基于敏感眼疾影像数据训练大规模医学影像模型 | 开放获取

联邦学习也在医学影像领域得到应用。伦敦大学学院和穆尔菲尔德眼科医院的研究人员,正利用该技术在敏感眼疾扫描数据上微调大型视觉基础模型,这些数据无法被集中存储。

国防领域

除医疗外,联邦学习还被应用于国防、航空等安全敏感领域。这些场景中,模型训练依赖分布式的生理数据和运行数据,而这些数据必须留在本地。

联邦学习的主要类型

从宏观上看,联邦学习可根据参与主体(客户端)的类型和数据拆分方式,分为以下几类:

1. 跨设备与跨机构联邦学习

跨设备联邦学习的参与客户端数量庞大(可达数百万),通常是个人设备(如手机),每个客户端仅存储少量本地数据,且网络连接不稳定。在任意一轮训练中,仅有少数设备参与。谷歌Gboard输入法就是这类联邦学习的典型应用。

跨机构联邦学习的客户端数量较少,通常是医院、银行等组织。每个客户端拥有大规模数据集,且具备稳定的计算资源和网络连接。现实中的企业级应用和医疗场景,大多采用跨机构联邦学习架构。

2. 水平与垂直联邦学习

水平联邦学习描述的是数据在客户端间的“横向”拆分:所有客户端的特征空间一致,但各自拥有不同的样本数据。例如,多家医院记录的医疗指标相同,但对应不同患者。这是联邦学习最常见的形式。

垂直联邦学习适用于客户端拥有相同样本实体,但特征空间不同的场景。例如,医院和保险公司可能拥有同一批用户的数据,但数据属性不同(医院存储诊疗记录,保险公司存储投保信息)。这类场景的训练需要安全协同机制(因特征空间不一致),应用范围不如水平联邦学习广泛。

上述分类并非互斥。实际系统通常结合两种维度描述,例如“跨机构-水平联邦学习”架构。

联邦学习的工作原理

联邦学习遵循一套简单的循环流程,由中心服务器协调,多个持有本地数据的客户端执行,流程如下图所示。

配图说明:联邦学习循环流程可视化
配图说明:联邦学习循环流程可视化

联邦学习的训练过程由多轮“联邦学习回合”构成。每一轮中,服务器随机选择少量客户端,发送当前全局模型权重,等待客户端返回更新结果;每个客户端利用随机梯度下降(SGD)在本地训练模型(通常在本地数据批次上训练若干轮),仅将更新后的权重返回服务器。整体流程可概括为以下五个步骤:

1. 初始化阶段

服务器(作为协调者)初始化一个全局模型。该模型可随机初始化,也可基于预训练模型启动。

2. 模型分发阶段

每一轮训练中,服务器通过随机抽样或预设策略选择一组客户端参与训练,并向其发送当前全局模型权重。这些客户端可以是手机、物联网设备,也可以是独立医院。

3. 本地训练阶段

每个被选中的客户端,利用自身本地数据训练模型。数据始终不离开客户端,所有计算均在设备本地或机构内部(如医院、银行)完成。

4. 模型更新传输阶段

本地训练完成后,客户端仅将更新后的模型参数(权重或梯度)发送回服务器,全程不共享原始数据。

5. 聚合阶段

服务器聚合所有客户端的模型更新,生成新的全局模型。联邦平均(Fed Avg)是最常用的聚合策略,也可采用其他方法。更新后的全局模型将再次分发至客户端,重复上述流程,直至模型收敛。

联邦学习是一个迭代过程,每完成一次上述循环称为一轮训练。训练一个联邦模型通常需要多轮迭代(有时达数百轮),具体取决于模型规模、数据分布和问题复杂度等因素。

联邦平均的数学原理

上述流程可通过更严谨的数学公式表示。下图展示了谷歌开创性论文中提出的原始联邦平均(Fed Avg)算法,该算法后来成为联邦学习的核心参考标准,证明了联邦学习的实践可行性,也是当前多数联邦学习系统的设计基础。

配图说明:原始联邦平均算法,展示服务器-客户端训练循环及本地模型的加权聚合 | 来源:《从去中心化数据中高效学习深度网络》
配图说明:原始联邦平均算法,展示服务器-客户端训练循环及本地模型的加权聚合 | 来源:《从去中心化数据中高效学习深度网络》

联邦平均的核心是聚合步骤:服务器通过对客户端本地训练模型进行加权平均,更新全局模型。数学表达式如下:

公式说明:联邦平均算法的数学表示
公式说明:联邦平均算法的数学表示

该公式清晰体现了各客户端对全局模型的贡献:本地数据量越多的客户端,权重越大,对全局模型的影响越显著;数据量少的客户端,贡献则按比例降低。正是这一简洁的设计思路,让联邦平均成为联邦学习的默认基准算法。

基于NumPy的简易实现

我们通过一个极简案例演示联邦平均的实现:假设已选中5个客户端,且每个客户端均完成本地训练,返回更新后的模型权重及所用样本数量。服务器基于这些数据计算加权和,生成下一轮训练的全局模型。该案例直接对应联邦平均公式,暂不涉及训练过程和客户端细节。


import numpy as np

# 客户端本地训练后的模型权重(w_{t+1}^k)
client_weights = [
    np.array([1.00.80.5]),     # 客户端1
    np.array([1.20.90.6]),     # 客户端2
    np.array([0.90.70.4]),     # 客户端3
    np.array([1.10.850.55]),   # 客户端4
    np.array([1.31.00.65])     # 客户端5
]

# 每个客户端的样本数量(n_k)
client_sizes = [501501003004000]

# m_t = 选中客户端(S_t)的总样本数
m_t = sum(client_sizes) # 计算:50+150+100+300+4000 = 4600

# 初始化全局模型 w_{t+1}
w_t_plus_1 = np.zeros_like(client_weights[0])

# 联邦平均聚合过程:
# w_{t+1} = Σ(k∈S_t)(n_k / m_t) * w_{t+1}^k
# 即:(50/4600)*客户端1权重 + (150/4600)*客户端2权重 + ...

for w_k, n_k in zip(client_weights, client_sizes):
    w_t_plus_1 += (n_k / m_t) * w_k

print("聚合后的全局模型 w_{t+1}:", w_t_plus_1)

输出结果:聚合后的全局模型 w_{t+1}: [1.27173913 0.97826087 0.63478261]

聚合过程解析

为更直观理解,我们可拆解两个客户端的聚合计算过程,验证结果合理性。

配图说明:双客户端聚合计算过程示意
配图说明:双客户端聚合计算过程示意

联邦学习面临的挑战

联邦学习也存在自身特有的挑战。核心问题之一是客户端数据通常呈非独立同分布(non-IID)特征——不同客户端的数据集分布差异较大,这会减缓训练速度,降低全局模型的稳定性。例如,联邦学习中的多家医院,服务的患者群体不同,数据分布模式也会存在差异。

联邦系统的参与主体可从少数组织扩展至数百万台设备,随着规模扩大,客户端参与管理、故障处理(如客户端中途退出)和聚合协调的难度会显著增加。

此外,联邦学习虽能保持原始数据本地化,但无法单独彻底解决隐私问题:若缺乏保护机制,模型更新参数仍可能泄露隐私信息,因此通常需要额外的隐私增强技术。最后,通信瓶颈也是关键问题——边缘网络可能速度慢、稳定性差,频繁传输模型更新会产生高昂成本。

总结与后续计划

本文概述了联邦学习的核心原理,并通过NumPy实现了简易案例。但实际应用中,无需手动编写核心逻辑——Flower等框架可提供简洁灵活的工具,帮助构建联邦学习系统。下一篇文章中,我们将借助Flower框架简化开发流程,聚焦模型设计和数据处理,而非联邦学习的底层机制。同时,我们还将探讨联邦大语言模型(联邦LLMs),这类场景中,模型规模、通信成本和隐私约束的重要性会进一步凸显。

推荐学习书籍 《CDA一级教材》适合CDA一级考生备考,也适合业务及数据分析岗位的从业者提升自我。完整电子版已上线CDA网校,累计已有10万+在读~ !

免费加入阅读:https://edu.cda.cn/goods/show/3151?targetId=5147&preview=0

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

关键词:本地化 weights Preview Client flower

您需要登录后才可以回帖 登录 | 我要注册

本版微信群
扫码
拉您进交流群
GMT+8, 2026-1-21 04:28