PyTorch 1.8 版本发布,包含分布式训练更新和 AMD ROCm 支持

PyTorch 1.8 版本发布总结

PyTorch 是 Facebook 开源的深度学习框架,近期发布了 1.8 版本。该版本包含多项更新,包括 API 的改进、分布式训练的优化以及对 AMD GPU 加速器的 ROCm 平台支持。此外,领域特定库 TorchVision、TorchAudio 和 TorchText 也发布了新版本。

主要特性

  1. ROCm 平台支持:新版本提供了针对 ROCm 平台的二进制文件,以提升在 AMD GPU 系统上的性能。
  2. NumPy 兼容 API 更新:新增了与 NumPy 兼容的 API,包括快速傅里叶变换(FFT)和常用线性代数函数。
  3. 分布式训练优化

    • 管道并行:类似于 GPipe,将输入的小批量数据分割为多个微批次,通过 GPU 管道化处理,减少空闲时间。
    • 梯度压缩:引入了通信钩子,优化训练中的梯度通信步骤,包括梯度压缩和 PowerSGD 等预构建钩子。
  4. torch.fx 工具包:这是一个新的 beta 工具包,用于 Python 到 Python 的功能转换,灵感来自 Jax 和 TensorFlow。其主要组件包括符号追踪器、中间表示和 Python 代码生成器,允许开发者将 Module 子类转换为 Graph 表示,修改 Graph 并生成 Python 源代码。

领域特定库更新

  1. TorchVision:增加了对移动设备的支持,包括 Detectron2 的移动版本。
  2. TorchAudio:改进了 I/O 性能。
  3. TorchText:使其数据集 API 与 PyTorch DataLoader 工具兼容。

用户反馈

在 Hacker News 的讨论中,用户对比了 PyTorch 和 TensorFlow,指出尽管 TensorFlow 在 ROCm 支持上落后于 PyTorch,但其对 Google TPU 设备的支持更为优越。一位用户称赞 PyTorch 为:

[T]he most impressive piece of software engineering that I know of....There's just an incredible amount of complexity being hidden behind behind a very simple interface there...

发布信息

PyTorch 1.8 的发布说明和代码可在 GitHub 上获取。

总结

PyTorch 1.8 版本在性能优化、API 更新和分布式训练方面带来了显著改进,特别是对 AMD GPU 的 ROCm 平台支持,进一步扩展了其硬件兼容性。新增的 torch.fx 工具包为开发者提供了更灵活的功能转换机制,而领域特定库的更新则提升了其在移动设备和 I/O 性能方面的表现。整体来看,PyTorch 继续巩固其在深度学习框架领域的领先地位。

阅读 44
0 条评论