towardsdatascience.com/design-an-easy-to-use-deep-learning-framework-52d7c37e415f?source=collection_archive---------6-----------------------#2024-04-10
作为开源贡献者,我学到的三条软件设计原则
作者:Haifeng Jin
发布于 Towards Data Science · 约9分钟阅读 · 2024年4月10日
深度学习框架的发展极为迅速。若将当前主流使用的工具与八年前进行对比,整个技术生态已然焕然一新。Theano、Caffe2 和 MXNet 曾经占据主导地位,如今却逐渐淡出视野。而现今最受欢迎的框架如 TensorFlow 和 PyTorch,在当年才刚刚向公众开放。
在这一波浪潮中,Keras 作为一个高层接口库,能够兼容多种后端(包括 TensorFlow、PyTorch 和 JAX),成功地延续了下来。作为 Keras 的一名贡献者,我深刻体会到团队对用户体验的高度重视,并通过遵循一些简洁而有力的设计原则,持续优化用户的使用感受。
本文将总结我在参与 Keras 开发过程中所领悟到的三项核心软件设计原则。这些原则不仅适用于深度学习框架,也可能广泛应用于各类软件项目,帮助你在开源社区中建立影响力。
为何用户体验在开源软件中至关重要?
我们可以从 PyTorch 与 TensorFlow 的竞争格局中找到答案。这两个框架分别由 Meta 和 Google 主导开发,背后体现了两种不同的企业文化:Meta 擅长打造用户友好的产品,而 Google 更专注于工程技术上的极致优化。
因此,Google 推出的 TensorFlow 和 JAX 在性能层面表现卓越,尤其在稀疏张量处理和分布式训练方面具备明显优势。然而,尽管技术领先,TensorFlow 仍失去了大量市场份额给 PyTorch——原因正是后者更注重用户体验。
对于研究人员而言,构建模型只是第一步;更重要的是如何让工程师顺利接手并部署这些模型。如果迁移成本过高,工程师往往会选择围绕 PyTorch 构建新的工具链,从而进一步强化其生态系统。
TensorFlow 自身也曾因用户体验问题导致用户流失。例如,其 GPU 安装指南长期存在缺陷,直到 2022 年才得以完善。此外,TensorFlow 2.0 版本打破了向后兼容性,使得许多用户在升级过程中面临巨大成本,甚至造成数百万美元的损失。
由此可见,即便技术实力强大,最终决定用户选择的往往是整体体验。良好的用户体验不仅能留住开发者,还能促进生态繁荣。
主流框架都在用户体验上大力投入
无论是 TensorFlow、PyTorch 还是 JAX,它们都在提升用户体验方面投入了大量资源。一个显著的证据是:这些框架的代码库中,Python 所占比例相当高。
实际上,所有核心功能——如张量运算、自动微分、图编译和分布式计算——都是用 C++ 实现的。但为什么还要向上层暴露一套完整的 Python API?答案很简单:因为用户偏爱 Python。为了贴近用户习惯,各大团队不惜投入精力打磨 Python 层的接口设计。
投资用户体验具有高回报率
设想一下,要让你的框架比现有方案快上几个百分点,需要多少底层优化和工程投入?答案是:极其庞大。
相比之下,改善用户体验则更具性价比。只要遵循科学的设计流程和清晰的原则,就能显著提升可用性。在吸引开发者方面,优秀的用户体验与卓越的计算效率同样关键。因此,这方面的投入往往能带来更高的回报。
三大设计原则分享
接下来,我将结合自己在 Keras 贡献中的实际经验,分享三条我认为最为关键的软件设计原则,并辅以来自不同框架的正反案例说明。
原则一:设计端到端的工作流
当我们着手设计一个软件 API 时,常见的做法可能是这样:
class Model:
def __call__(self, input):
"""模型的前向传播方法。
参数:
input: 张量,模型的输入数据。
"""
pass
这种方式定义了类结构、方法签名和文档说明,看似完整。但我们仍然难以从中感知用户在真实场景下是如何使用这个接口的。
真正有效的方法,是从业务流程出发,模拟用户从开始到完成任务的完整路径。比如,一个典型的建模流程应包含数据加载、模型构建、训练、评估和保存等步骤。API 设计应当围绕这一全流程展开,而不是孤立地看待单个类或函数。
以 Keras 为例,它提供了一套高度抽象且连贯的接口:
model = Sequential([
Dense(64, activation='relu'),
Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10)
score = model.evaluate(x_test, y_test)
model.save('my_model.keras')
这段代码展示了清晰的线性流程,每一步都自然衔接,极大降低了学习和使用门槛。这种“端到端可运行”的设计思维,正是良好用户体验的核心体现。
相反,某些框架要求用户手动管理设备分配、图构建、会话启动等多个环节,导致初学者极易迷失在细节中。即使功能强大,也因流程断裂而劝退大量潜在用户。
因此,优秀的设计不是堆砌功能,而是引导用户顺畅地走完全程。
原则二:保持一致性与最小惊喜原则
用户在使用 API 时,期望行为具有一致性和可预测性。所谓“最小惊喜原则”(Principle of Least Surprise),即接口的行为应符合用户的直觉预期,避免出现反常识的设计。
举例来说,若一个方法名为 save(),用户自然预期它会持久化当前对象的状态;若该方法实际上只保存部分配置,则会造成困惑。
Keras 在这方面做得较好。无论你使用的是 TensorFlow、PyTorch 还是 JAX 后端,model.fit()、model.predict() 等核心方法的行为始终保持一致。切换后端不会改变高层接口语义,大大降低了迁移成本。
反观某些早期框架,同一操作在不同模块中有不同命名,例如 run()、execute()、forward() 混杂使用,增加了记忆负担。这种不一致性直接损害了可用性。
为此,我们在设计时应统一术语、参数顺序和返回值格式。哪怕底层实现差异巨大,对外暴露的接口也应尽可能保持一致。
原则三:渐进式复杂性(Progressive Complexity)
一个好的系统应当允许用户从简单入手,随着需求增长逐步深入高级功能,而不是一开始就面对复杂的配置项。
Keras 提供了“由简入繁”的路径:新手可以使用 Sequential 快速搭建线性模型;当需要分支结构时,可转向函数式 API;若需自定义训练逻辑,则可通过子类化 Model 类实现。整个过程平滑过渡,无需推倒重来。
这种设计被称为“渐进式复杂性”——基础用法简单直观,高级功能依然强大灵活。用户可以在不破坏已有代码的前提下,逐步演进模型架构和训练策略。
相比之下,某些框架要求用户从一开始就编写大量样板代码,哪怕只是训练一个简单的全连接网络。这种“重量级入门”模式极大地阻碍了快速实验和迭代。
通过支持多层次抽象,Keras 成功吸引了从初学者到专家的广泛用户群体。
结语
开源项目的成功不仅仅取决于技术先进性,更在于能否为用户提供流畅、一致且易于扩展的使用体验。Keras 的持续生命力正是源于对这三条原则的坚持:
- 设计端到端工作流:让用户能顺畅完成整个任务流程;
- 保持一致性与最小惊喜:降低认知负担,增强可预测性;
- 支持渐进式复杂性:兼顾易用性与灵活性,适应不同阶段的需求。
这些经验不仅适用于深度学习领域,也可推广至通用软件设计。如果你希望自己的开源项目被更多人采用,不妨从用户体验的角度重新审视你的 API 设计。
https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f0f2539047f9ad89bd753971d4f517d5.png
图片来源:Sheldon @ Unsplash
我们希望构建一个完整的用户使用软件的工作流。理想情况下,这应当呈现为一个清晰的教程,展示如何实际使用该工具。相比仅罗列类和方法,这种端到端流程能更深入地揭示用户体验中的潜在问题,帮助我们在设计早期阶段发现并优化交互细节。
让我们通过 KerasTuner 的实际开发过程来看一个具体例子。在实现过程中,我正是通过撰写完整使用流程,发现了若干关键的用户体验痛点。
在使用 KerasTuner 时,用户通常会借助以下结构来搜索最优模型:
RandomSearch
其中包含 metrics 和 objective 参数。默认情况下,objective 被设为验证损失(如 "val_loss"),用于寻找验证损失最小的模型配置。
class RandomSearch:
def __init__(self, ..., metrics, objective="val_loss", ...):
"""初始化函数。
Args:
metrics: Keras 指标列表。
objective: 字符串或自定义指标函数,表示要最小化的目标指标名称。
"""
pass
单从接口定义来看,一切似乎合理且直观。然而,当我们尝试编写一个完整的使用场景时,问题开始浮现。
假设用户想要引入一个名为 custom_metric 的自定义评估函数:
custom_metric
此时,objective 参数应如何设置?用户可能会困惑于应该传入什么样的字符串标识。
tuner = RandomSearch(
...,
metrics=[custom_metric],
objective="val_???",
)
按照原设计,用户需要手动拼接前缀与函数名,例如使用
"val_custom_metric"
并添加
"val_"
作为前缀,形成类似 "val_custom_metric" 的字符串。这种方式不仅不够直观,还增加了用户的记忆负担和出错概率。
更进一步,如果我们要完整实现这个 custom_metric 函数,就必须了解 Keras 自定义指标的具体编写方式——包括必须遵循特定的函数签名:
def custom_metric(y_true, y_pred):
squared_diff = ops.square(y_true - y_pred)
return ops.mean(squared_diff, axis=-1)
这意味着用户不仅要理解如何定义逻辑,还需掌握框架层面的技术细节,这对大多数使用者而言是不必要的学习成本。
发现问题后,我们重构了自定义指标的使用流程。新的设计允许用户直接在一个方法中计算并返回结果,无需关心命名规则或函数签名:
HyperModel.fit()
class MyHyperModel(HyperModel):
def fit(self, trial, model, validation_data):
x_val, y_true = validation_data
y_pred = model(x_val)
return custom_metric(y_true, y_pred)
tuner = RandomSearch(MyHyperModel(), max_trials=20)
这一改进显著提升了可用性:用户不再需要记忆字符串格式,也不必遵循复杂的接口规范,只需关注核心逻辑本身。整个体验更加自然、流畅。
这个案例再次印证了一个重要原则:设计应始终以用户体验为起点。工作流的构建反过来指导和优化底层实现。
原则 2:最小化认知负担
除非绝对必要,否则不应要求用户学习新概念或复杂规则。优秀的 API 设计应当基于用户已有的知识体系。
一个典型的成功案例是 Keras 的模型构建 API。大多数模型构建者已经熟悉“模型是由层堆叠而成”、“训练需要损失函数”、“可以用数据进行拟合和预测”等基本理念。因此,Keras 的使用方式几乎无需额外学习:
model = keras.Sequential([
layers.Dense(10, activation="relu"),
layers.Dense(num_classes, activation="softmax"),
])
model.compile(loss='categorical_crossentropy')
model.fit(...)
model.predict(...)
由于这些操作符合直觉且贴近常见模式,用户可以快速上手,而不需要掌握全新的抽象概念。这种低认知负荷的设计,正是良好用户体验的核心体现。
在模型构建方面,PyTorch 提供了一种非常直观的方式。其代码执行方式与标准 Python 一致,所有张量都是具备实际数值的真实对象。你可以直接根据张量的值来决定程序流程,并使用原生的 Python 控制语句实现条件分支。
例如:
class MyModel(nn.Module):
def forward(self, x):
if x.sum() > 0:
return self.path_a(x)
return self.path_b(x)
相比之下,在 Keras 结合 TensorFlow 或 JAX 后端时,虽然也能实现类似逻辑,但写法不同。必须通过特定函数来表达条件控制,如下所示:
class MyModel(keras.Model):
def call(self, inputs):
return ops.cond(
ops.sum(inputs) > 0,
lambda : self.path_a(inputs),
lambda : self.path_b(inputs),
)
if
ops.cond
这种方式要求用户学习新的操作范式,而非直接使用熟悉的 if-else 结构,这增加了理解成本。尽管它带来了训练效率上的显著提升作为补偿,但从用户体验角度来说并不理想。
这也揭示了 PyTorch 灵活性背后的代价。当你需要对模型进行内存和速度优化时,就必须手动引入一系列底层 API 和新概念,比如操作的 inplace 参数、并行处理接口以及显式的设备分配机制。这些都显著提高了学习门槛。
torch.relu(x, inplace=True)
x = torch._foreach_add(x, y)
torch._foreach_add_(x, y)
x = x.cuda()
keras.ops
tensorflow.numpy
jax.numpy
一些设计良好的例子是对 NumPy API 的复刻。与其让用户从头学习一套全新的接口(可能包含上百个函数),不如复用已被广泛掌握的现有模式。NumPy 拥有详尽的文档支持,Stack Overflow 上也有大量相关问答资源,极大降低了用户的认知负担。每个深度学习框架都需要提供一定的低级操作能力,而沿用主流 API 是更优选择。
然而,在用户体验中最不可取的做法之一是“误导”——让接口看起来像用户熟悉的形式,实则行为完全不同。以下将展示两个典型反例,分别来自 PyTorch 和 TensorFlow。
假设你有一个形状为
(100, 3, 32, 32)
的输入张量,想要将其填充至
(100, 3, 1+32+1, 2+32+2)
或
(100, 3, 34, 36)
的尺寸,那么在调用
F.pad()
函数时,pad 参数应如何设置?
import torch.nn.functional as F
# 将 32x32 图像填充为 (1+32+1)x(2+32+2)
# 即从 (100, 3, 32, 32) 变为 (100, 3, 34, 36)
out = F.pad(
torch.empty(100, 3, 32, 32),
pad=???,
)
直觉上,很多人会认为参数应该是
((0, 0), (0, 0), (1, 1), (2, 2))
,即每个子元组对应一个维度,两个数字分别表示该维度前后要添加的元素数量,这种思路源自对 NumPy 填充规则的理解。
但实际上,正确答案是 (2, 2, 1, 1)。这里没有嵌套元组结构,仅使用平铺元组,且维度顺序是逆序的:最后一个维度的填充信息排在最前面。
再来看 TensorFlow 的另一个问题示例:
value = True
@tf.function
def get_value():
return value
value = False
print(get_value())
若不使用
tf.function
装饰器,输出显然是 False。但加上装饰器后,输出却为 True。原因在于,TensorFlow 会将被装饰的函数编译为计算图,其中引用的 Python 变量会被固化为常量。后续对原变量的修改不会影响已编译的图中对应的值。
这种设计容易造成误解:表面上看是普通的 Python 代码,实际上运行逻辑已被静态化,违背了用户的直觉预期。
原则 3:互动优于文档
当用户可以通过运行示例代码并自行调试来理解系统行为时,很少有人愿意去阅读冗长的说明文档。因此,理想的软件设计应当使用户的工作流自然符合直觉逻辑。
一个成功的案例是 PyTorch 中方法命名的约定:所有带下划线的方法表示原地操作(inplace operation),而不带下划线的则是生成新对象的操作。这种设计具有良好的交互性,用户无需频繁查阅文档即可推测出方法的行为。
x = x.add(y) # 创建新张量
x.add_(y) # 在原张量上修改
x = x.mul(y) # 非原地乘法
x.mul_(y) # 原地乘法
当然,这也带来一定的认知负荷——用户需理解什么是原地操作,以及在何种场景下使用它们才是安全的。但总体而言,这种一致性增强了可探索性和可用性。
在设计深度学习框架时,良好的 API 设计至关重要。一个典型的例子是 Keras 中的层命名方式。这些名称遵循一致且清晰的命名规则,使用户能够快速理解并记住它们的功能,而无需频繁查阅文档。
例如,以下代码展示了 Keras 中几种常见的池化层:
from keras import layers
layers.MaxPooling2D()
layers.GlobalMaxPooling1D()
layers.GlobalAveragePooling3D()
towardsdatascience.com/design-an-easy-to-use-deep-learning-framework-52d7c37e415f?source=collection_archive---------6-----------------------#2024-04-10
从命名中可以直观看出:操作类型(如 MaxPooling 或 AveragePooling)、作用范围(Global 是否存在)以及数据维度(1D、2D、3D)。这种结构化的命名显著降低了用户的记忆成本。
除了命名规范,错误信息的设计也是用户体验的重要组成部分。由于用户不可能每次都写出完全正确的代码,因此程序应当具备合理的校验机制,并返回足够清晰的错误提示,帮助用户快速定位问题。
对比下面两个抛出异常的例子:
# 不够友好的错误提示:
raise ValueError("Tensor shape mismatch.")
# 更具指导性的错误提示:
raise ValueError(
"Tensor shape mismatch. "
"Expected: (batch, num_features). "
f"Received: {x.shape}"
)
显然,第二个示例不仅指出了错误类型,还明确列出了期望的张量形状和实际接收到的形状,极大提升了调试效率。
最理想的错误信息甚至能主动建议修复方案。比如 Python 在遇到近似拼写时会提供可能的正确选项:
import math
math.sqr(4)
# AttributeError: module 'math' has no attribute 'sqr'. Did you mean: 'sqrt'?
这样的提示直接引导用户修正拼写错误,减少了试错成本。
总结来说,在参与深度学习框架开发的过程中,我总结出三条关键的软件设计原则:
- 构建端到端的工作流程,以便更早发现用户体验中的潜在问题。
- 尽可能降低用户的认知负担——除非必要,不要要求用户学习额外的概念。
- 保持 API 的逻辑一致性,结合有意义的错误信息,让用户在与系统的交互中自然学会使用方法,而不是依赖反复查阅文档。
当然,优秀的软件设计远不止这些。若希望进一步提升 API 质量,推荐参考 Keras 官方的 API 设计指南,其中包含了更为系统和全面的最佳实践。


雷达卡


京公网安备 11010802022788号







