
We fine-tuned Llama 405B on AMD GPUs

495 points by felarof, 8 months ago
Hey HN, we recently fine-tuned the Llama 3.1 405B model on 8x AMD MI300X GPUs using JAX instead of PyTorch. JAX's advanced sharding APIs allowed us to achieve great performance. Check out our blog post to learn about the cool sharding tricks we used. We've also open-sourced the code: https://github.com/felafax/felafax

We're a small startup building AI infra for fine-tuning and serving LLMs on non-NVIDIA hardware (TPUs, AMD, Trainium).

Problem: Many companies are trying to get PyTorch working on AMD GPUs, but we believe this is a treacherous path. PyTorch is deeply intertwined with the NVIDIA ecosystem in a lot of ways (e.g., `torch.cuda`, or scaled_dot_product_attention, which is an NVIDIA CUDA kernel exposed as a PyTorch function). So, to get PyTorch code running on non-NVIDIA hardware, a lot of "de-NVIDIAfying" needs to be done.

Solution: We believe JAX is a better fit for non-NVIDIA hardware. In JAX, ML model code compiles to hardware-independent HLO graphs, which the XLA compiler then optimizes before applying hardware-specific optimizations. This clean separation allowed us to run the same Llama 3 JAX code on both Google TPUs and AMD GPUs with no changes.

Our strategy as a company is to invest upfront in porting models to JAX, then leverage its framework and XLA kernels to extract maximum performance from non-NVIDIA backends. This is why we first ported Llama 3.1 from PyTorch to JAX; now the same JAX model works great on TPUs and runs perfectly on AMD GPUs.

We'd love to hear your thoughts on our vision and repo!
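A minimal sketch of what these jax.sharding APIs look like in practice (illustrative only; the mesh shape, axis names, and tensor sizes are assumptions rather than the felafax repo's actual code, and it assumes 8 local accelerators):

    # Sketch: shard a matmul across 8 devices with jax.sharding (hypothetical layout).
    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Arrange the 8 devices on a 2D mesh: one axis for fully sharded data
    # parallelism ("fsdp"), one for tensor parallelism ("tp").
    mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=("fsdp", "tp"))

    # Place a weight matrix with rows split across "fsdp" and columns across "tp".
    w = jax.device_put(jnp.zeros((8192, 8192)), NamedSharding(mesh, P("fsdp", "tp")))
    x = jax.device_put(jnp.ones((16, 8192)), NamedSharding(mesh, P(None, "fsdp")))

    # jit lowers the function to a hardware-independent HLO graph; XLA then
    # inserts the cross-device collectives and emits backend-specific code
    # (TPU, ROCm, CUDA) without any changes to the model code itself.
    @jax.jit
    def forward(x, w):
        return x @ w

    y = forward(x, w)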

13 comments

chillee, 8 months ago
To be clear, this performance is quite bad (presumably because you didn't manage to get compilation working).

You're getting 35 tokens/s for a 405B model, which comes out to about 85 teraflops. 8 MI300X GPUs come out to 10.4 Petaflops, so you're getting about 0.8% MFU (which is about 40-50x worse than decent training performance of 30-40% MFU).

For AMD's sake, I hope that it's your software stack that's limiting perf.
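For anyone following the arithmetic, here is the back-of-the-envelope version of that MFU estimate (the ~6 FLOPs per parameter per token training cost and the per-GPU BF16 peak are standard rules of thumb, not figures from the post):

    # Back-of-the-envelope MFU check using rule-of-thumb numbers.
    params = 405e9                   # Llama 3.1 405B parameters
    tokens_per_s = 35                # throughput reported in the thread
    flops_per_token = 6 * params     # ~6 FLOPs/param/token for training (fwd + bwd)

    achieved = tokens_per_s * flops_per_token   # ~8.5e13 FLOP/s = ~85 TFLOP/s
    peak = 8 * 1.3e15                           # 8x MI300X at ~1.3 PFLOP/s dense BF16

    print(f"MFU ~ {achieved / peak:.1%}")       # ~0.8%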
3abiton, 8 months ago
Firstly, great work! I dabbled with AMD GPUs and ROCm support a year ago, and it was obvious AMD was still a long way from catching up with Nvidia. While opting for JAX is an interesting approach, what were the challenges for you in deviating from PyTorch (the standard library for ML)?
latchkey, 8 months ago
Nice work! I was just playing with the inference side of things with 405B myself this weekend [0].

I'm not convinced that 'torch.cuda' is really that bad, since the AMD version of PyTorch just translates that for you. It's more of a naming problem than anything. Fact is, it's just as easy to grab the rocm:pytorch container as it is the rocm:jax container.

I don't see very many numbers posted. What MFU did you get?

[0] https://x.com/HotAisle/status/1837580046732874026
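For what it's worth, the translation latchkey describes can be checked directly: ROCm builds of PyTorch keep the `torch.cuda` namespace for compatibility and map it to HIP. A small sketch (assumes a ROCm PyTorch install on an AMD GPU machine):

    import torch

    print(torch.cuda.is_available())   # True on a working ROCm setup
    print(torch.version.hip)           # set on ROCm builds, None on CUDA builds

    x = torch.randn(1024, 1024, device="cuda")  # "cuda" addresses the AMD GPU here
    y = x @ x                                   # dispatched to a HIP/rocBLAS kernel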
steeve, 8 months ago
We (ZML) measured MI300X at 30% faster than H100. These are great chips!
brutus1213, 8 months ago
Does any cloud provider have an 8x AMD MI300 host that one can rent? I use AWS for a lot of my professional work, and was hoping to try out an AMD GPU.
yeahwhatever10, 8 months ago
Where is the performance data?
Stem0037, 8 months ago
If possible, it would be interesting to explore ways to overcome the memory constraints and run a JIT-compiled version. This could potentially lead to further performance improvements.
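Two standard JAX levers for the memory constraints mentioned above are activation rematerialization (jax.checkpoint) and buffer donation under jit. A hedged sketch with stand-in shapes and functions (not the felafax training step):

    import functools
    import jax
    import jax.numpy as jnp

    @jax.checkpoint   # recompute activations in the backward pass to save memory
    def block(w, x):
        return jnp.tanh(x @ w)   # stand-in for a transformer block

    def loss_fn(w, x):
        return jnp.mean(block(w, x) ** 2)

    # donate_argnums=(0,) lets XLA reuse the old weights' buffers for the
    # updated weights instead of keeping both copies live at once.
    @functools.partial(jax.jit, donate_argnums=(0,))
    def train_step(w, x):
        loss, grads = jax.value_and_grad(loss_fn)(w, x)
        return w - 1e-3 * grads, loss

    w = jnp.zeros((512, 512))
    x = jnp.ones((8, 512))
    w, loss = train_step(w, x)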
yieldcrv, 8 months ago
Is AMD any closer to extracting value from this with large orders of their GPUs causing a shortage?

I’m getting the impression of “no”
system2, 8 months ago
Why is Obsidian (a note-taking app) doing this?
varispeed, 8 months ago
How do you buy such a GPU, or is it still reserved only for the rich so they can get ahead of the game before the plebs get their unwashed hands on these cards?
manojlds, 8 months ago
Thought this was a post from Obsidian at first. Why haven't they done the GitHub.com vs GitHub.io thing yet?
abalaji, 8 months ago
@dang: could we get the URL to include the username, since this isn't about Obsidian itself but rather a user-generated blog?