jax
eab55a49 - [hijax] add HiType.{inc_rank,dec_rank} for vmap and scan.

Commit
15 days ago
[hijax] add HiType.{inc_rank,dec_rank} for vmap and scan. Scan needs both but interestingly vmap only needs dec_rank. We don't support vmap out_axes with hi types. To do that we would need to add a match_spec method on HiTyp. Co-authored-by: Yash Katariya <yashkatariya@google.com> Co-authored-by: Cristian Garcia <cgarciae@google.com> Co-authored-by: Robert Dyro <rdyro@google.com>
Author
Committer
Parents
Loading