AI Pulse

从零手写AI训练循环,效率高还能中途恢复

从零手写AI训练循环,效率高还能中途恢复

从零写训练循环,炸了三次

一位技术博主决定不依赖任何现成框架,只用JAX构建一个完整的LLM训练循环。他选的组件是Flax NNX和Optax。

开写之后,问题接二连三:先是JIT编译报错,接着20GB显存爆了,再后来发现Orbax检查点库的API文档跟实际代码对不上——文档里写save,实际模块里没有这个属性。每个错误都得翻文档、读源码、试错,直到跑通第一版。

跑起来后,效率惊人

这条训练循环在单张GPU上实现了约109,253 tokens/秒的吞吐量。训练的内容是一个简单模型,输入输出相同,叫A-to-A模型。花了92,160,000个token后,训练损失降到0.000,接近完美收敛。

能跑这么快,靠的是几项标准技术:梯度累积(用Optax的MultiSteps包装器,让小批次模拟大批次效果)、学习率预热加余弦衰减(起步稳、后期快速收敛)、以及全局梯度裁剪(防止梯度爆炸)。

能中断后恢复,省时省钱

作者还实现了检查点恢复功能,用Orbax库保存和加载优化器状态。训练中途断电或出bug,直接从上一步继续,不用从头来。这个功能对节约计算成本很关键。

下一步:训练完整LLM

这篇博客只验证了训练循环的正确性。作者计划在下一篇文章中构建并训练一个完整的LLM。目前不清楚这个完整模型的性能,也没有与GPT-2 small的对比数据。JAX训练循环跟PyTorch版本在速度和易用性上的具体差异,同样没有给出。Orbax API文档不一致的问题会不会在后续版本修掉,也没有说明。

但这些都不是重点——重点是一条路已经铺好了。

阅读原文
📚 相关主题 大语言模型JAXAI工程

📬 订阅 AI Pulse

每天三次更新,不错过重要信号

▲ 回到顶部