jax
flax>=0.7.1
transformer_engine_rocm7==2.10.0+rocm7.13.0rc2
