当前位置:首页 > 技术 > 正文内容

JAX框架:高性能AI模型训练的新选择

访客 技术 2026年6月12日 1

引言

在人工智能模型开发过程中,选择合适的机器学习框架至关重要。历史上,众多库都曾竞相争夺开发者青睐(还记得Caffe和Theano吗?)。近年来,TensorFlow凭借其对高效图计算的重视似乎占据领先地位,而PyTorch则凭借其Python友好的接口获得了广泛认可。然而,近年来,一个新兴框架迅速崛起,不容忽视——JAX。JAX专注于提升AI模型训练和推理性能,同时保持良好的用户体验,正在挑战现有框架的主导地位。

JAX背景 - XLA编译

JAX的强大之处在于其利用了XLA编译技术。JAX展现的卓越性能归功于XLA提供的硬件特定优化。许多JAX核心功能,如即时编译(JIT)和函数式编程范式,都源于XLA。虽然TensorFlow和PyTorch也支持XLA,但JAX从设计之初就全面拥抱XLA,使JIT编译、自动微分、向量化、并行化等特性与XLA底层实现紧密集成。

XLA JIT编译器全面分析计算图,将连续张量操作合并为单一内核,消除冗余组件,并生成针对硬件加速器优化的机器代码。这减少了计算操作量,降低了主机与加速器间的通信开销,提高了内存利用率和加速器效率。

XLA的另一关键特性是其可扩展基础设施,支持更多AI加速器。XLA是OpenXLA项目的一部分,由多个ML领域参与者共同开发。

然而,依赖XLA也存在局限性,特别是具有动态张量形状的模型在XLA中可能无法达到最佳性能,且需要注意图断裂和重新编译问题,这可能影响代码调试。

JAX实际应用

本节展示如何使用JAX在单个GPU上训练AI模型,并与PyTorch进行性能比较。我们将使用HuggingFace的Transformers库,该库为Transformer架构模型提供了PyTorch和JAX实现。

下面是模型定义的代码示例:

import torch
import jax, flax, optax
import jax.numpy as jnp

def create_model(implementation='jax'):
    from transformers import ViTConfig

    if implementation == 'jax':
        from transformers import FlaxViTForImageClassification as ModelClass
    else:
        from transformers import ViTForImageClassification as ModelClass

    config = ViTConfig(
        num_labels = 1000,
        _attn_implementation = 'eager'  # 禁用flash attention
    )
    
    return ModelClass(config)


我们选择不使用"flash-attention"功能,因为该优化在本文撰写时仅适用于PyTorch模型。

为关注运行时性能,我们在随机生成的数据集上训练模型。利用JAX对PyTorch数据加载器的支持:

def prepare_dataloader(batch_size, framework='jax'):
    from torch.utils.data import Dataset, DataLoader, default_collate

    # 创建随机图像和标签数据集
    class RandomDataset(Dataset):
        def __len__(self):
            return 1000000

        def __getitem__(self, index):
            if framework == 'jax': # 使用nhwc格式
                random_image = torch.randn([224, 224, 3], dtype=torch.float32)
            else: # 使用nchw格式
                random_image = torch.randn([3, 224, 224], dtype=torch.float32)
            label = torch.tensor([index % 1000], dtype=torch.int64)
            return random_image, label

    dataset = RandomDataset()
    
    if framework == 'jax':  # 将torch张量转换为numpy数组
        def numpy_collate(batch):
            from jax.tree_util import tree_map
            import jax.numpy as jnp
            return tree_map(jnp.asarray, default_collate(batch))
        collate_function = numpy_collate
    else:
        collate_function = default_collate
 
    dataset = RandomDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size,
                    collate_fn=collate_function)
    return dataloader


接下来定义PyTorch和JAX训练循环。JAX训练循环基于Flax的TrainState对象:

@jax.jit
def execute_jax_step(training_state, batch):
    with jax.default_matmul_precision('tensorfloat32'):
        def forward(params):
            outputs = training_state.apply_fn({'params': params}, batch[0])
            loss = optax.softmax_cross_entropy(
                logits=outputs.logits, labels=batch[1]).mean()
            return loss

        gradient_function = jax.grad(forward)
        gradients = gradient_function(training_state.params)
        training_state = training_state.apply_gradients(grads=gradients)
        return training_state

def execute_torch_step(batch, model, optimizer, loss_fn, device):
    inputs = batch[0].to(device=device, non_blocking=True)
    labels = batch[1].squeeze(-1).to(device=device, non_blocking=True)
    outputs = model(inputs)
    loss = loss_fn(outputs.logits, labels)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


现在整合所有组件。以下脚本包含PyTorch图JIT编译选项的控件:

def run_training(batch_size, framework, compile_model):
    print(f"框架: {framework} \n"
          f"批次大小: {batch_size} \n"
          f"编译模型: {compile_model}")

    # 初始化模型和数据加载器
    is_jax = framework == 'jax'
    is_torch_xla = framework == 'torch_xla'
    model = create_model('jax' if is_jax else 'pytorch')
    data_loader = prepare_dataloader(batch_size, 'jax' if is_jax else 'pytorch')

    if is_jax:
        # 初始化JAX设置
        from flax.training import train_state
        parameters = model.module.init(jax.random.key(0), 
                                   jnp.ones([1, 224, 224, 3]))['params']
        optimizer = optax.sgd(learning_rate=1e-3)
        state = train_state.TrainState.create(apply_fn=model.module.apply,
                                              params=parameters, tx=optimizer)
    else:
        if is_torch_xla:
            import torch_xla
            import torch_xla.core.xla_model as xm
            import torch_xla.distributed.parallel_loader as pl
            torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
                use_full_mat_mul_precision=False)
       
            device = xm.xla_device()
            backend = 'openxla'
        
            # 包装数据加载器
            data_loader = pl.MpDeviceLoader(data_loader, device)
        else:
            device = torch.device('cuda')
            backend = 'inductor'
    
        model = model.to(device)
        if compile_model:
            model = torch.compile(model, backend=backend)
        model.train()
        optimizer = torch.optim.SGD(model.parameters())
        loss_fn = torch.nn.CrossEntropyLoss()

    import time
    start_time = time.perf_counter()
    total_time = 0
    step_count = 0

    for step, data in enumerate(data_loader):
        if is_jax:
            state = execute_jax_step(state, data)
        else:
            execute_torch_step(data, model, optimizer, loss_fn, device)

        # 记录步骤时间
        step_duration = time.perf_counter() - start_time
        if step > 10:  # 跳过初始步骤
            total_time += step_duration
        step_count += 1
        start_time = time.perf_counter()
        if step > 50:
            break

    print(f'平均步骤时间: {total_time / step_count}')


if __name__ == '__main__':
    import argparse
    torch.set_float32_matmul_precision('high')
    
    parser = argparse.ArgumentParser(description='模型训练脚本.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='训练批次大小 (默认: 32)')
    parser.add_argument('--framework', choices=['pytorch', 'jax', 'torch_xla'],
                        default='jax',
                        help='选择训练框架')
    parser.add_argument('--compile-model', action='store_true', default=False,
                        help='是否对模型应用torch.compile')
    args = parser.parse_args()

    run_training(**vars(args))


性能基准测试

进行性能基准测试时需谨慎,因为多个因素可能影响结果,如浮点精度、矩阵乘法精度、数据加载方式以及是否使用flash/fused注意力机制。例如,PyTorch默认矩阵乘法精度为float32,而JAX使用tensorfloat32,直接比较可能缺乏参考价值。这些精度设置可通过相应API调整,如jax.default_matmul_precision和torch.set_float32_matmul_precision。我们在脚本中已尝试识别并排除这些潜在问题,但无法保证完全成功。

测试结果

我们在Google Cloud的两台虚拟机上执行训练脚本:g2-standard-16(配备NVIDIA L4 GPU)和a2-highgpu-1g(配备NVIDIA A100 GPU)。使用深度学习专用镜像(common-cu121-v20240514-ubuntu-2204-py310),预装PyTorch(2.3.0)、PyTorch/XLA(2.3.0)、JAX(0.4.28)、Flax(0.8.4)、Optax(0.2.2)和HuggingFace Transformers(4.41.1)。

下表汇总了多次实验的运行时间。需要注意的是,模型架构和环境差异可能导致性能比较结果显著变化,代码的细微调整也可能影响结果。

尽管JAX在L4 GPU上表现优于其他选项,但在A100 GPU上与PyTorch/XLA性能相当。这并不意外,因为它们共享XLA后端。理论上,JAX生成的任何XLA图都应能被PyTorch/XLA实现。在这两种平台上,torch.compile表现均不理想,考虑到我们使用全精度浮点数计算,这在一定程度上可以预见。

为什么选择JAX?

  • 性能优化

JAX训练的主要吸引力在于JIT编译可能带来的性能提升。然而,随着PyTorch新增JIT功能(PyTorch/XLA)和torch.compile选项,JAX的这一优势可能受到质疑。考虑到PyTorch庞大的开发者社区和原生支持而JAX/FLAX尚未涵盖的特性(如自动混合精度、高级注意力机制层),有人可能认为无需投入时间学习JAX。除可能的性能提升外,还有其他动力因素:

  • XLA友好性

与PyTorch后来通过PyTorch/XLA实现的"函数化"不同,JAX从设计之初就内嵌XLA支持。这意味着在PyTorch/XLA中可能显得复杂的操作,在JAX中可以更简洁优雅地实现。例如,在训练过程中混合使用JIT和非JIT函数,在JAX中直接可行,而在PyTorch/XLA中可能需要技巧。

理论上,PyTorch/XLA和TensorFlow都能生成与JAX相同的XLA图,实现同等性能。但实际生成的图质量取决于框架如何转换为XLA代码,更高效的转换带来更好的性能。由于JAX原生支持XLA,它可能具有竞争优势。

JAX对XLA的友好性使其对专用AI加速器开发者尤其有吸引力,如Google Cloud TPU、Intel Gaudi和AWS Trainium芯片,这些通常被称为"XLA设备"。特别是在TPU上训练的团队可能会发现JAX的支持生态系统比PyTorch/XLA更先进。

  • 高级特性

近年来,JAX发布了多项高级功能,远早于同行。例如,SPMD是一种先进的设备并行技术,提供最先进的模型分片机会,几年前在JAX中引入,最近才被移植到PyTorch。另一个例子是Pallas,能够为XLA设备构建自定义内核。

开源模型

随着JAX日益普及,越来越多的开源AI模型以JAX发布。经典例子包括Google的开源MaxText(大语言模型)和AlphaFold v2(蛋白质结构预测)模型。要充分利用这些模型,需要学习JAX,或将模型移植到其他语言。

结论

本文深入探讨了正在崛起的JAX机器学习框架。我们阐述了其基于XLA编译器的特点,并通过示例展示了应用。虽然JAX常因快速执行速度而备受关注,但PyTorch的JIT功能(包括torch.compile和PyTorch/XLA)同样具备性能优化潜力。每种选择的性能表现很大程度上取决于模型细节和运行环境。

值得注意的是,每个机器学习框架都可能拥有独特特性(如本文撰写时,JAX的SPMD自动分片和PyTorch的SDPA注意力机制),这些特性可能在性能比较中起关键作用。因此,选择最佳框架的决定因素可能是你的模型能在多大程度上利用这些特性。

标签: JAX

相关文章

Linux crontab 详解

1) crontab 是什么cron 是 Linux 的定时任务守护进程;crontab 是用来编辑/查看“按时间周期执行命令”的表(cron table)。常见两类:用户 crontab:每个用户一份(crontab -e 编辑)系统级 crontab / cron.d:可指定执行用户(/etc/crontab、/etc/cron.d/*)2) crontab 时间...

富文本里可以允许的 HTML 属性

一、所有标签默认允许的安全属性(极少)class        (可选)id           (通常建议禁用)title️ 注意:id 容易被滥用做锚点注入,很多系统直接禁用class 允许的话最好只允许固定前缀(如 editor-*)二、a 标签允许属性<a href="" t...

Dom\HTML_NO_DEFAULT_NS 的副作用:自动加闭合标签

在使用Dom\HTMLDocument时,Dom\HTML_NO_DEFAULT_NS 将禁止在解析过程中设置元素的命名空间, 此设置是为了与DOMDocument向后兼容而存在的。当使用它时,已知的一个副作用就是:自动加闭合标签例如 </img> 为什么会这样?当你使用:Dom\HTML_NO_DEFAULT_NS文档会变成 无命名空间模式,此时内部更接近 XML...

Laravel 事件和监听器创建

在 Laravel 中,使用 Artisan 命令创建 Events(事件) 和 Listeners(监听器) 是非常高效的。你可以通过以下几种方式来实现:1. 手动创建单个 Event如果你只想创建一个事件类,可以使用 make:event 命令:Bashphp artisan make:event UserRegistered执行后,文件将生成在 app/Even...

自定义域名解析神器 dnsmasq

什么是 dnsmasq?dnsmasq 是一个轻量级、功能强大的网络服务工具,专为小型和中等规模网络设计。它是一个综合的网络基础设施解决方案[1]。dnsmasq 能做什么?功能说明应用场景DNS 转发与缓存将 DNS 查询转发到上游服务器(ISP、Google DNS 等),并在本地缓存结果加快 DNS 查询速度,减少外部 DNS 流量本地 DNS解析本地网络设备的主机名,无需编辑&n...

linux screen 用法详情 (nohup 的替代方案)

一、screen 是什么?能干嘛?screen 是一个终端复用器,可以:在一个 SSH 会话中开多个“虚拟终端”SSH 断线后,程序仍然在后台运行随时重新连接到原来的会话特别适合:nohup 的替代方案跑脚本 / 爬虫 / 训练模型运维、远程开发二、安装 screen# CentOS / Rocky / Almayum install -y screen# Debian / Ubuntuapt i...

发表评论

访客

◎欢迎参与讨论,请在这里发表您的看法和观点。