JAX框架:高性能AI模型训练的新选择
引言
在人工智能模型开发过程中,选择合适的机器学习框架至关重要。历史上,众多库都曾竞相争夺开发者青睐(还记得Caffe和Theano吗?)。近年来,TensorFlow凭借其对高效图计算的重视似乎占据领先地位,而PyTorch则凭借其Python友好的接口获得了广泛认可。然而,近年来,一个新兴框架迅速崛起,不容忽视——JAX。JAX专注于提升AI模型训练和推理性能,同时保持良好的用户体验,正在挑战现有框架的主导地位。
JAX背景 - XLA编译
JAX的强大之处在于其利用了XLA编译技术。JAX展现的卓越性能归功于XLA提供的硬件特定优化。许多JAX核心功能,如即时编译(JIT)和函数式编程范式,都源于XLA。虽然TensorFlow和PyTorch也支持XLA,但JAX从设计之初就全面拥抱XLA,使JIT编译、自动微分、向量化、并行化等特性与XLA底层实现紧密集成。
XLA JIT编译器全面分析计算图,将连续张量操作合并为单一内核,消除冗余组件,并生成针对硬件加速器优化的机器代码。这减少了计算操作量,降低了主机与加速器间的通信开销,提高了内存利用率和加速器效率。
XLA的另一关键特性是其可扩展基础设施,支持更多AI加速器。XLA是OpenXLA项目的一部分,由多个ML领域参与者共同开发。
然而,依赖XLA也存在局限性,特别是具有动态张量形状的模型在XLA中可能无法达到最佳性能,且需要注意图断裂和重新编译问题,这可能影响代码调试。
JAX实际应用
本节展示如何使用JAX在单个GPU上训练AI模型,并与PyTorch进行性能比较。我们将使用HuggingFace的Transformers库,该库为Transformer架构模型提供了PyTorch和JAX实现。
下面是模型定义的代码示例:
import torch
import jax, flax, optax
import jax.numpy as jnp
def create_model(implementation='jax'):
from transformers import ViTConfig
if implementation == 'jax':
from transformers import FlaxViTForImageClassification as ModelClass
else:
from transformers import ViTForImageClassification as ModelClass
config = ViTConfig(
num_labels = 1000,
_attn_implementation = 'eager' # 禁用flash attention
)
return ModelClass(config)
我们选择不使用"flash-attention"功能,因为该优化在本文撰写时仅适用于PyTorch模型。
为关注运行时性能,我们在随机生成的数据集上训练模型。利用JAX对PyTorch数据加载器的支持:
def prepare_dataloader(batch_size, framework='jax'):
from torch.utils.data import Dataset, DataLoader, default_collate
# 创建随机图像和标签数据集
class RandomDataset(Dataset):
def __len__(self):
return 1000000
def __getitem__(self, index):
if framework == 'jax': # 使用nhwc格式
random_image = torch.randn([224, 224, 3], dtype=torch.float32)
else: # 使用nchw格式
random_image = torch.randn([3, 224, 224], dtype=torch.float32)
label = torch.tensor([index % 1000], dtype=torch.int64)
return random_image, label
dataset = RandomDataset()
if framework == 'jax': # 将torch张量转换为numpy数组
def numpy_collate(batch):
from jax.tree_util import tree_map
import jax.numpy as jnp
return tree_map(jnp.asarray, default_collate(batch))
collate_function = numpy_collate
else:
collate_function = default_collate
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=batch_size,
collate_fn=collate_function)
return dataloader
接下来定义PyTorch和JAX训练循环。JAX训练循环基于Flax的TrainState对象:
@jax.jit
def execute_jax_step(training_state, batch):
with jax.default_matmul_precision('tensorfloat32'):
def forward(params):
outputs = training_state.apply_fn({'params': params}, batch[0])
loss = optax.softmax_cross_entropy(
logits=outputs.logits, labels=batch[1]).mean()
return loss
gradient_function = jax.grad(forward)
gradients = gradient_function(training_state.params)
training_state = training_state.apply_gradients(grads=gradients)
return training_state
def execute_torch_step(batch, model, optimizer, loss_fn, device):
inputs = batch[0].to(device=device, non_blocking=True)
labels = batch[1].squeeze(-1).to(device=device, non_blocking=True)
outputs = model(inputs)
loss = loss_fn(outputs.logits, labels)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
现在整合所有组件。以下脚本包含PyTorch图JIT编译选项的控件:
def run_training(batch_size, framework, compile_model):
print(f"框架: {framework} \n"
f"批次大小: {batch_size} \n"
f"编译模型: {compile_model}")
# 初始化模型和数据加载器
is_jax = framework == 'jax'
is_torch_xla = framework == 'torch_xla'
model = create_model('jax' if is_jax else 'pytorch')
data_loader = prepare_dataloader(batch_size, 'jax' if is_jax else 'pytorch')
if is_jax:
# 初始化JAX设置
from flax.training import train_state
parameters = model.module.init(jax.random.key(0),
jnp.ones([1, 224, 224, 3]))['params']
optimizer = optax.sgd(learning_rate=1e-3)
state = train_state.TrainState.create(apply_fn=model.module.apply,
params=parameters, tx=optimizer)
else:
if is_torch_xla:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=False)
device = xm.xla_device()
backend = 'openxla'
# 包装数据加载器
data_loader = pl.MpDeviceLoader(data_loader, device)
else:
device = torch.device('cuda')
backend = 'inductor'
model = model.to(device)
if compile_model:
model = torch.compile(model, backend=backend)
model.train()
optimizer = torch.optim.SGD(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()
import time
start_time = time.perf_counter()
total_time = 0
step_count = 0
for step, data in enumerate(data_loader):
if is_jax:
state = execute_jax_step(state, data)
else:
execute_torch_step(data, model, optimizer, loss_fn, device)
# 记录步骤时间
step_duration = time.perf_counter() - start_time
if step > 10: # 跳过初始步骤
total_time += step_duration
step_count += 1
start_time = time.perf_counter()
if step > 50:
break
print(f'平均步骤时间: {total_time / step_count}')
if __name__ == '__main__':
import argparse
torch.set_float32_matmul_precision('high')
parser = argparse.ArgumentParser(description='模型训练脚本.')
parser.add_argument('--batch-size', type=int, default=32,
help='训练批次大小 (默认: 32)')
parser.add_argument('--framework', choices=['pytorch', 'jax', 'torch_xla'],
default='jax',
help='选择训练框架')
parser.add_argument('--compile-model', action='store_true', default=False,
help='是否对模型应用torch.compile')
args = parser.parse_args()
run_training(**vars(args))
性能基准测试
进行性能基准测试时需谨慎,因为多个因素可能影响结果,如浮点精度、矩阵乘法精度、数据加载方式以及是否使用flash/fused注意力机制。例如,PyTorch默认矩阵乘法精度为float32,而JAX使用tensorfloat32,直接比较可能缺乏参考价值。这些精度设置可通过相应API调整,如jax.default_matmul_precision和torch.set_float32_matmul_precision。我们在脚本中已尝试识别并排除这些潜在问题,但无法保证完全成功。
测试结果
我们在Google Cloud的两台虚拟机上执行训练脚本:g2-standard-16(配备NVIDIA L4 GPU)和a2-highgpu-1g(配备NVIDIA A100 GPU)。使用深度学习专用镜像(common-cu121-v20240514-ubuntu-2204-py310),预装PyTorch(2.3.0)、PyTorch/XLA(2.3.0)、JAX(0.4.28)、Flax(0.8.4)、Optax(0.2.2)和HuggingFace Transformers(4.41.1)。
下表汇总了多次实验的运行时间。需要注意的是,模型架构和环境差异可能导致性能比较结果显著变化,代码的细微调整也可能影响结果。
尽管JAX在L4 GPU上表现优于其他选项,但在A100 GPU上与PyTorch/XLA性能相当。这并不意外,因为它们共享XLA后端。理论上,JAX生成的任何XLA图都应能被PyTorch/XLA实现。在这两种平台上,torch.compile表现均不理想,考虑到我们使用全精度浮点数计算,这在一定程度上可以预见。
为什么选择JAX?
- 性能优化
JAX训练的主要吸引力在于JIT编译可能带来的性能提升。然而,随着PyTorch新增JIT功能(PyTorch/XLA)和torch.compile选项,JAX的这一优势可能受到质疑。考虑到PyTorch庞大的开发者社区和原生支持而JAX/FLAX尚未涵盖的特性(如自动混合精度、高级注意力机制层),有人可能认为无需投入时间学习JAX。除可能的性能提升外,还有其他动力因素:
- XLA友好性
与PyTorch后来通过PyTorch/XLA实现的"函数化"不同,JAX从设计之初就内嵌XLA支持。这意味着在PyTorch/XLA中可能显得复杂的操作,在JAX中可以更简洁优雅地实现。例如,在训练过程中混合使用JIT和非JIT函数,在JAX中直接可行,而在PyTorch/XLA中可能需要技巧。
理论上,PyTorch/XLA和TensorFlow都能生成与JAX相同的XLA图,实现同等性能。但实际生成的图质量取决于框架如何转换为XLA代码,更高效的转换带来更好的性能。由于JAX原生支持XLA,它可能具有竞争优势。
JAX对XLA的友好性使其对专用AI加速器开发者尤其有吸引力,如Google Cloud TPU、Intel Gaudi和AWS Trainium芯片,这些通常被称为"XLA设备"。特别是在TPU上训练的团队可能会发现JAX的支持生态系统比PyTorch/XLA更先进。
- 高级特性
近年来,JAX发布了多项高级功能,远早于同行。例如,SPMD是一种先进的设备并行技术,提供最先进的模型分片机会,几年前在JAX中引入,最近才被移植到PyTorch。另一个例子是Pallas,能够为XLA设备构建自定义内核。
开源模型
随着JAX日益普及,越来越多的开源AI模型以JAX发布。经典例子包括Google的开源MaxText(大语言模型)和AlphaFold v2(蛋白质结构预测)模型。要充分利用这些模型,需要学习JAX,或将模型移植到其他语言。
结论
本文深入探讨了正在崛起的JAX机器学习框架。我们阐述了其基于XLA编译器的特点,并通过示例展示了应用。虽然JAX常因快速执行速度而备受关注,但PyTorch的JIT功能(包括torch.compile和PyTorch/XLA)同样具备性能优化潜力。每种选择的性能表现很大程度上取决于模型细节和运行环境。
值得注意的是,每个机器学习框架都可能拥有独特特性(如本文撰写时,JAX的SPMD自动分片和PyTorch的SDPA注意力机制),这些特性可能在性能比较中起关键作用。因此,选择最佳框架的决定因素可能是你的模型能在多大程度上利用这些特性。