JAX加载大文件默认直冲GPU显存,小心爆显存
JAX默认把数据加载到GPU显存,而不是CPU内存。当加载训练数据集这类大文件时,数据直接进显存。
解决方法是用JAX提供的上下文管理器jax.default_device临时切换默认设备。比如:
`
with jax.default_device(jax.devices('cpu')[0]):
full_dataset = load_file(...)
`
这段代码把当前上下文的默认设备设为第一个CPU设备,文件会加载到CPU内存。
要使用CPU设备,先通过jax.devices('cpu')获取CPU设备对象。jax.devices()默认只返回当前后端的设备列表——通常是GPU。JAX还有一个jax_default_device配置选项,可以通过jax.config.update或环境变量设置默认设备。上下文管理器的好处是临时切换,不影响全局设置。
数据加载到CPU后,用jax.device_put就能传回GPU。这是一种常见的操作模式:CPU负责加载,训练循环逐批移入GPU。
JAX没有直接列出所有可用后端的方法,官方推荐的做法是尝试加载不同后端的设备,捕获RuntimeError异常来判断哪些可用。另外,Safetensors的load_file函数有一个backend参数,但它控制的是文件读取方式,跟设备选择无关。
📚 相关主题
工程