xla
5327033b - [Pallas] Refactor the gmm kernel (#7099)

Commit
1 year ago
[Pallas] Refactor the gmm kernel (#7099) Summary: This is an effort to refactor the code from #6940 and aims to remove useless code in that part. It reduces the amount of code from ~400 lines to ~50 lines. However, a bummer is the original gmm kernel doesn't work at all... It assumes groups_sizes is a cpu tensor. That means we need to materialize this input in order to use this gmm kernel, and that will introduce graph breaks in the computation. I will need yet another follow up to make this code actually functional... Good news is the test cases seem functional, yay... Test Plan: python test/test_megablox.py
Author
Parents
Loading