梯度正常损失不降?哈希值秒查Flax NNX参数是否更新
问题:损失固定在10.82
一位开发者在Flax NNX中训练一个从零构建的LLM,模型有7700万个参数。损失函数输出稳定在10.82,换算成困惑度约为50011——接近GPT-2词表大小50257,说明模型输出的概率基本均匀分布,完全在随机猜测。
检查梯度值:输出层kernel的梯度在10⁻⁹到10⁻⁴量级,学习率设为0.0014。梯度看着合理——不小到完全没影响,也没变成NaN或零。但损失就是不降。
调试:参数哈希暴露真相
既然梯度正常,问题可能出在参数实际有没有被更新。7700万个参数,肉眼不可能看出逐个变化。作者用了一个简单方法:计算参数数组的哈希值。具体做法是 hash(np.asarray(some_array).tobytes())——参数只要变一个比特,哈希值就会完全不同。
首次打印哈希值时,作者发现它在多次迭代中完全不变。参数没有更新。
根因:用错了JIT装饰器
问题出在训练循环的JIT编译方式上。Flax官方示例中使用 @nnx.jit 装饰 train_step 函数,但作者最初用的是普通的 @jax.jit。这个差异很关键:NNX 的 optimizer.update(model, grads) 是原地更新参数,不是函数式返回新参数。@jax.jit 不会跟踪这种带副作用的更新,每次调用的入参和返回值都没变——JIT编译器直接用了缓存,参数纹丝不动。
把 @jax.jit 换成 @nnx.jit 后,参数哈希值每次迭代都改变,损失也开始缓慢下降。
训练任务与模型结构
这次训练目标比较特殊:模型学的是将输入序列映射到自身(比如输入“The fat cat sat on the mat”,目标也是“The fat cat sat on the mat”),不是通常的下一个词预测。模型结构也极其简化——只有token嵌入层和输出线性层直接相连,没有中间层。
作者提到调试时实际用了 print,这会破坏JIT编译。如果当时用 jax.debug.print,就能在保留JIT的同时输出信息。
开放问题
@nnx.jit 具体怎么实现状态传播、它与 @jax.jit 在底层有何不同,目前没有详细文档说明。参数哈希法虽然有效,但更大的模型可能影响性能——jax.debug.print 或许是更轻量的替代,同样能保留JIT编译。