JAX
repo: n2cholas/awesome-jax
category: Computer Science
Awesome JAX
<img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="JAX Logo" align="right" height="100">
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
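As a minimal illustration of the NumPy-like API together with automatic differentiation and XLA compilation, the sketch below writes a mean-squared-error loss with `jax.numpy`, differentiates it with `jax.grad`, and compiles the result with `jax.jit`. The toy linear model and data are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model, written with the NumPy-like jax.numpy API
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad builds a new function computing d(loss)/dw with respect to the first argument;
# jax.jit traces that function once and compiles it with XLA for the available device
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

value = jax.jit(loss)(w, x, y)  # loss at w = 0 is mean([1, 4]) = 2.5
g = grad_loss(w, x, y)          # gradient at w = 0 is x.T @ (x @ w - y) = [-7, -10] here
```

The same pure Python function serves as the input to every transformation, which is the composability that many of the libraries below build on.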
Contents
- Libraries
- Models and Projects
- Videos
- Papers
- Tutorials and Blog Posts
- Books
- Community
<a name="libraries" />
Libraries
- Neural Network Libraries
    - Flax - Centered on flexibility and clarity. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
    - Flax NNX - An evolution of Flax by the same team. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
    - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind. <img src="https://img.shields.io/github/stars/deepmind/dm-haiku?style=social" align="center">
    - Objax - Has an object-oriented design similar to PyTorch. <img src="https://img.shields.io/github/stars/google/objax?style=social" align="center">
    - Elegy - A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax. <img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center">
    - Trax - "Batteries included" deep learning library focused on providing solutions for common workloads. <img src="https://img.shields.io/github/stars/google/trax?style=social" align="center">
    - Jraph - Lightweight graph neural network library. <img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center">
    - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width. <img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center">
    - HuggingFace Transformers - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax). <img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center">
    - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center">
    - Scenic - A JAX library for computer vision research and beyond. <img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center">
    - Penzai - Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model. <img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center">
    - Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX. <img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center">
    - EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. <img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center">
- NumPyro - Probabilistic programming based on the Pyro library. <img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center">
- Chex - Utilities to write and test reliable JAX code. <img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center">
- Optax - Gradient processing and optimization library. <img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center">
- RLax - Library for implementing reinforcement learning agents. <img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center">
- JAX, M.D. - Accelerated, differentiable molecular dynamics. <img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center">
- Coax - Turn RL papers into code, the easy way. <img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center">
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors. <img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center">
- cvxpylayers - Construct differentiable convex optimization layers. <img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center">
- TensorLy - Tensor learning made simple. <img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center">
- NetKet - Machine Learning toolbox for Quantum Physics. <img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center">
- Fortuna - AWS library for Uncertainty Quantification in Deep Learning. <img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center">
- BlackJAX - Library of samplers for JAX. <img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center">
- Dynamax - Probabilistic state space models. <img src="https://img.shields.io/github/stars/probml/dynamax?style=social" align="center">
<a name="new-libraries" />
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.
- Neural Network Libraries
    - FedJAX - Federated learning in JAX, built on Optax and Haiku. <img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center">
    - Equivariant MLP - Construct equivariant neural network layers. <img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center">
    - jax-resnet - Implementations and checkpoints for ResNet variants in Flax. <img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center">
    - jax-raft - JAX/Flax port of the RAFT optical flow estimator. <img src="https://img.shields.io/github/stars/alebeck/jax-raft?style=social" align="center">
    - Parallax - Immutable Torch Modules for JAX. <img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center">
- Nonlinear Optimization
    - Optimistix - Root finding, minimisation, fixed points, and least squares. <img src="https://img.shields.io/github/stars/patrick-kidger/optimistix?style=social" align="center">
    - JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. <img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center">
- jax-unirep - Library implementing the UniRep model for protein machine learning applications. <img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center">
- flowjax - Distributions and normalizing flows built as equinox modules. <img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center">
- flaxdiff - Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs). <img src="https://img.shields.io/github/stars/AshishKumar4/FlaxDiff?style=social" align="center">
- jax-flows - Normalizing flows in JAX. <img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center">
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX. <img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center">
- jax-cosmo - Differentiable cosmology library. <img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center">
- efax - Exponential Families in JAX. <img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center">
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs. <img src="https://img.shields.io/github/stars/PhilipVinc/mpi4jax?style=social" align="center">
- imax - Image augmentations and transformations. <img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center">
- FlaxVision - Flax version of TorchVision. <img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center">
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation. <img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center">
- jaxlie - Lie theory library for rigid body transformations and optimization. <img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center">
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments. <img src="https://img.shields.io/github/stars/google/brax?style=social" align="center">
- flaxmodels - Pretrained models for JAX/Flax. <img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center">
- CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing. <img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center">
- exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX. <img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center">
- PIX - PIX is an image processing library in JAX, for JAX. <img src="https://img.shields.io/github/stars/deepmind/dm_pix?style=social" align="center">
- bayex - Bayesian Optimization powered by JAX. <img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center">
- JaxDF - Framework for differentiable simulators with arbitrary discretizations. <img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center">
- tree-math - Convert functions that operate on arrays into functions that operate on PyTrees. <img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center">
- jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX. <img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-models?style=social" align="center">
- PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX. <img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center">
- EvoJAX - Hardware-Accelerated Neuroevolution <img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center">
- evosax - JAX-Based Evolution Strategies <img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center">
- SymJAX - Symbolic CPU/GPU/TPU programming. <img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center">
- mcx - Express & compile probabilistic programs for performant inference. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center">
- Einshape - DSL-based reshaping library for JAX and other frameworks. <img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center">
- ALX - Open-source library for distributed matrix factorization using Alternating Least Squares. More info in ALX: Large Scale Matrix Factorization on TPUs.
- Diffrax - Numerical differential equation solvers in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center">
- tinygp - The tiniest of Gaussian process libraries in JAX. <img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center">
- gymnax - Reinforcement Learning Environments with the well-known gym API. <img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center">
- Mctx - Monte Carlo tree search algorithms in native JAX. <img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center">
- KFAC-JAX - Second Order Optimization with Approximate Curvature for NNs. <img src="https://img.shields.io/github/stars/deepmind/kfac-jax?style=social" align="center">
- TF2JAX - Convert functions/graphs to JAX functions. <img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center">
- jwave - A library for differentiable acoustic simulations <img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center">
- GPJax - Gaussian processes in JAX.
- Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX. <img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center">
- Eqxvision - Equinox version of Torchvision. <img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center">
- JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper). <img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center">
- econpizza - Solve macroeconomic models with heterogeneous agents using JAX. <img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center">
- SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation). <img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center">
- jax-tqdm - Add a tqdm progress bar to JAX scans and loops. <img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center">
- safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors. <img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center">
- Kernex - Differentiable stencil decorators in JAX. <img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center">
- MaxText - A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs. <img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center">
- Pax - A JAX-based machine learning framework for training large scale models. <img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center">
- Praxis - The layer library for Pax with a goal to be usable by other JAX-based ML projects. <img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center">
- purejaxrl - Vectorisable, end-to-end RL algorithms in JAX. <img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center">
- Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
- SCICO - Scientific computational imaging in JAX. <img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center">
- Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware. <img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center">
- Brain Dynamics Programming Ecosystem
    - BrainPy - Brain Dynamics Programming in Python. <img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center">
    - brainunit - Physical units and unit-aware mathematical system in JAX. <img src="https://img.shields.io/github/stars/chaobrain/brainunit?style=social" align="center">
    - dendritex - Dendritic Modeling in JAX. <img src="https://img.shields.io/github/stars/chaobrain/dendritex?style=social" align="center">
    - brainstate - State-based Transformation System for Program Compilation and Augmentation. <img src="https://img.shields.io/github/stars/chaobrain/brainstate?style=social" align="center">
    - braintaichi - Leveraging Taichi Lang to customize brain dynamics operators. <img src="https://img.shields.io/github/stars/chaobrain/braintaichi?style=social" align="center">
- OTT-JAX - Optimal transport tools in JAX. <img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center">
- QDax - Quality Diversity optimization in JAX. <img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center">
- JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine. <img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center">
- Pgx - Vectorized board game environments for RL with an AlphaZero example. <img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center">
- EasyDeL - EasyDeL 🔮 is an open-source library that makes training faster and more optimized, with options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX. <img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center">
- XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning. <img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center">
- dynamiqs - High-performance and differentiable simulations of quantum systems with JAX. <img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center">
- foragax - Agent-based modelling framework in JAX. <img src="https://img.shields.io/github/stars/i-m-iron-man/Foragax?style=social" align="center">
- tmmax - Vectorized calculation of optical properties in thin-film structures using JAX. A Swiss Army knife tool for thin-film optics research. <img src="https://img.shields.io/github/stars/bahremsd/tmmax?style=social" align="center">
- Coreax - Algorithms for finding coresets to compress large datasets while retaining their statistical properties. <img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center">
- NAVIX - A reimplementation of MiniGrid, a Reinforcement Learning environment, in JAX <img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center">
- FDTDX - Finite-Difference Time-Domain Electromagnetic Simulations in JAX <img src="https://img.shields.io/github/stars/ymahlau/fdtdx?style=social" align="center">
- DiffeRT - Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem. <img src="https://img.shields.io/github/stars/jeertmans/DiffeRT?style=social" align="center">
- JAX-in-Cell - Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields <img src="https://img.shields.io/github/stars/uwplasma/JAX-in-Cell?style=social" align="center">
- kvax - A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism. <img src="https://img.shields.io/github/stars/nebius/kvax?style=social" align="center">
- astronomix - differentiable (magneto)hydrodynamics for astrophysics in JAX <img src="https://img.shields.io/github/stars/leo1200/astronomix?style=social" align="center">
- vivsim - Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method. <img src="https://img.shields.io/github/stars/haimingz/vivsim?style=social" align="center">
- MBIRJAX - High-performance tomographic reconstruction. <img src="https://img.shields.io/github/stars/cabouman/mbirjax?style=social" align="center">
- torchax - A library for JAX to interoperate with model code written in PyTorch. <img src="https://img.shields.io/github/stars/google/torchax?style=social" align="center">
<a name="models-and-projects" />
Models and Projects
JAX
- Fourier Feature Networks - Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
- kalman-jax - Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- jaxns - Nested sampling in JAX.
- Amortized Bayesian Optimization - Code related to Amortized Bayesian Optimization over Discrete Spaces.
- Accurate Quantized Training - Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
- BNN-HMC - Implementation for the paper What Are Bayesian Neural Network Posteriors Really Like?.
- JAX-DFT - One-dimensional density functional theory (DFT) in JAX, with implementation of Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics.
- Robust Loss - Reference code for the paper A General and Adaptive Robust Loss Function.
- Symbolic Functionals - Demonstration from Evolving symbolic density functionals.
- TriMap - Official JAX implementation of TriMap: Large-scale Dimensionality Reduction Using Triplets.
Flax
- awesome-jax-flax-llms - Collection of LLMs implemented in JAX & Flax
- DeepSeek-R1-Flax-1.5B-Distill - Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.
- Performer - Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
- JaxNeRF - Implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis with multi-device GPU/TPU support.
- mip-NeRF - Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.
- RegNeRF - Official implementation of RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs.
- JaxNeuS - Implementation of [NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction](https://lingjie0206.github.io/papers/NeuS/)
- Big Transfer (BiT) - Implementation of Big Transfer (BiT): General Visual Representation Learning.
- JAX RL - Implementations of reinforcement learning algorithms.
- gMLP - Implementation of Pay Attention to MLPs.
- MLP Mixer - Minimal implementation of MLP-Mixer: An all-MLP Architecture for Vision.
- Distributed Shampoo - Implementation of Second Order Optimization Made Practical.
- NesT - Official implementation of Aggregating Nested Transformers.
- XMC-GAN - Official implementation of [Cross-Modal Contrastive Learning for Text-to-Image Generation](https://arxiv.org/abs/2101.04702).
- FNet - Official implementation of FNet: Mixing Tokens with Fourier Transforms.
- GFSA - Official implementation of Learning Graph Structure With A Finite-State Automaton Layer.
- IPA-GNN - Official implementation of Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks.
- Flax Models - Collection of models and methods implemented in Flax.
- Protein LM - Implements BERT and autoregressive models for proteins, as described in [Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences](https://www.biorxiv.org/content/10.1101/622803v1.full) and ProGen: Language Modeling for Protein Generation.
- Slot Attention - Reference implementation for Differentiable Patch Selection for Image Recognition.
- Vision Transformer - Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
- FID computation - Port of mseitzer/pytorch-fid to Flax.
- ARDM - Official implementation of Autoregressive Diffusion Models.
- D3PM - Official implementation of Structured Denoising Diffusion Models in Discrete State-Spaces.
- Gumbel-max Causal Mechanisms - Code for Learning Generalized Gumbel-max Causal Mechanisms, with extra code in GuyLor/gumbel_max_causal_gadgets_part2.
- Latent Programmer - Code for the ICML 2021 paper Latent Programmer: Discrete Latent Codes for Program Synthesis.
- SNeRG - Official implementation of Baking Neural Radiance Fields for Real-Time View Synthesis.
- Spin-weighted Spherical CNNs - Adaptation of Spin-Weighted Spherical CNNs.
- VDVAE - Adaptation of Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images, original code at openai/vdvae.
- MUSIQ - Checkpoints and model inference code for the ICCV 2021 paper MUSIQ: Multi-scale Image Quality Transformer
- AQuaDem - Official implementation of Continuous Control with Action Quantization from Demonstrations.
- Combiner - Official implementation of Combiner: Full Attention Transformer with Sparse Computation Cost.
- Dreamfields - Official implementation of the ICLR 2022 paper Progressive Distillation for Fast Sampling of Diffusion Models.
- GIFT - Official implementation of Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent.
- Light Field Neural Rendering - Official implementation of Light Field Neural Rendering.
- Sharpened Cosine Similarity in JAX by Raphael Pisoni - A JAX/Flax implementation of the Sharpened Cosine Similarity layer.
- GNNs for Solving Combinatorial Optimization Problems - A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.
- DETR - Flax implementation of DETR: End-to-end Object Detection with Transformers using Sinkhorn solver and parallel bipartite matching.
Haiku
- AlphaFold - Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.
- Adversarial Robustness - Reference code for Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples and Fixing Data Augmentation to Improve Adversarial Robustness.
- Bootstrap Your Own Latent - Implementation for the paper Bootstrap your own latent: A new approach to self-supervised Learning.
- Gated Linear Networks - GLNs are a family of backpropagation-free neural networks.
- Glassy Dynamics - Open source implementation of the paper Unveiling the predictive power of static structure in glassy systems.
- MMV - Code for the models in Self-Supervised MultiModal Versatile Networks.
- Normalizer-Free Networks - Official Haiku implementation of NFNets.
- NuX - Normalizing flows with JAX.
- OGB-LSC - This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
- Persistent Evolution Strategies - Code used for the paper Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies.
- Two Player Auction Learning - JAX implementation of the paper [Auction learning as a two-player game](https://arxiv.org/abs/2006.05684).
- WikiGraphs - Baseline code to reproduce results in WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset.
Trax
- Reformer - Implementation of the Reformer (efficient transformer) architecture.
NumPyro
- lqg - Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper Putting perception into action with inverse optimal control for continuous psychophysics
Equinox
- Sampling Path Candidates with Machine Learning - Official tutorial and implementation from the paper Towards Generative Ray Path Sampling for Faster Point-to-Point Ray Tracing.
<a name="videos" />
Videos
- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://youtu.be/z-WSrQDXkuM) - JAX's core design, how it's powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- [JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne](https://slideslive.com/38923687/jax-accelerated-machinelearning-research-via-composable-function-transformations-in-python) - JAX intro presentation in Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
<a name="papers" />
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc). Papers implemented in JAX are listed in the Models/Projects section.
- [Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018.](https://mlsys.org/Conferences/doc/2018/146.pdf) - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries.
- XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python. Mohammadmehdi Ataei, Hesam Salehipour. arXiv 2023. - White paper describing the XLB library: benchmarks, validations, and more details about the library.
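The differentially private SGD paper listed above relies on `vmap` for per-example gradients and on JIT compilation for speed. The sketch below illustrates that general recipe (per-example gradient clipping followed by Gaussian noise) on a toy linear model. It is not the paper's code: the model, learning rate, clipping norm, and noise multiplier are illustrative assumptions, and no privacy accounting is performed.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a linear model on a single example (illustrative model)
    return (x @ w - y) ** 2


@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, clip_norm=1.0, noise_mult=1.0):
    # Per-example gradients: vmap over the batch axis of xs and ys, sharing w
    grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
    # Clip each example's gradient to an L2 norm of at most clip_norm
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    grads = grads / jnp.maximum(1.0, norms / clip_norm)
    # Sum the clipped gradients, add Gaussian noise scaled to the clipping bound, then average
    noise = noise_mult * clip_norm * jax.random.normal(key, w.shape)
    noisy_grad = (grads.sum(axis=0) + noise) / xs.shape[0]
    return w - lr * noisy_grad


key = jax.random.PRNGKey(0)
xs = jnp.ones((4, 3))
ys = jnp.zeros(4)
w = jnp.ones(3)
w_new = dp_sgd_step(w, xs, ys, key)
```

Computing every example's gradient in one vectorized call, rather than looping over the batch in Python, is what lets the whole step compile into a single XLA program.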
<a name="tutorials-and-blog-posts" />
Tutorials and Blog Posts
- Using JAX to accelerate our research by David Budden and Matteo Hessel - Describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural network building blocks from scratch with the basic JAX operators.
- Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh - A gentle introduction to JAX and to using it to implement linear regression, logistic regression, and neural network models, and to apply them to real-world problems.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 - Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- [Meta-Learning in 50 Lines of JAX by Eric Jang](https://blog.evjang.com/2019/02/maml-jax.html) - Introduction to both JAX and Meta-Learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang - Concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey - Ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey - Implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna - Understand how autodiff works using JAX.
- [From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke](https://sjmielke.com/jax-purify.htm) - Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey - Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
- Evolving Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can power the next generation of scalable neuroevolution algorithms.
- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz - Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.
- Deterministic ADVI in JAX by Martin Ingram - Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.
- Evolved channel selection by Mat Kelcey - Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.
- Introduction to JAX by Kevin Murphy - Colab that introduces various aspects of the language and applies them to simple ML problems.
- Writing an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.
- How to add a progress bar to JAX scans and loops by Jeremie Coullon - Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module.
- Get started with JAX by Aleksa Gordić - A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
- Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit - A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.
- Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar - A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.
- [Deep Learning tutorials with JAX+Flax by Phillip Lippe](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html) - A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.
- Achieving 4000x Speedups with PureJaxRL - A blog post on how JAX can massively speed up RL training through vectorisation.
- Simple PDE solver + Constrained Optimization with JAX by Philip Mocz - A simple example of solving the advection-diffusion equations with JAX and using it in a constrained optimization problem to find initial conditions that yield a desired result.
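Several tutorials above, such as the MCMC sampler post, share one pattern: write a single step as a pure function and let `jax.lax.scan` compile the whole loop. The sketch below applies that pattern to a random-walk Metropolis sampler. The standard-normal target, the proposal step size of 0.5, and the chain length of 5,000 are illustrative assumptions, not taken from any listed tutorial.

```python
from functools import partial

import jax
import jax.numpy as jnp


def log_prob(x):
    # Unnormalised log-density of a standard normal target (illustrative choice)
    return -0.5 * jnp.sum(x ** 2)


def rwm_step(state, key, step_size=0.5):
    # One random-walk Metropolis step; the carry holds the position and its log-density
    x, lp = state
    key_prop, key_acc = jax.random.split(key)
    x_prop = x + step_size * jax.random.normal(key_prop, x.shape)
    lp_prop = log_prob(x_prop)
    # Accept with probability min(1, p(x_prop) / p(x)), compared in log space
    accept = jnp.log(jax.random.uniform(key_acc)) < lp_prop - lp
    x = jnp.where(accept, x_prop, x)
    lp = jnp.where(accept, lp_prop, lp)
    return (x, lp), x  # new carry, plus the sample to record


@partial(jax.jit, static_argnames="num_samples")
def sample(key, x0, num_samples=1000):
    # One PRNG key per step; lax.scan turns the Python-level loop into a compiled loop
    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(rwm_step, (x0, log_prob(x0)), keys)
    return samples


samples = sample(jax.random.PRNGKey(0), jnp.zeros(2), num_samples=5000)
```

`num_samples` is marked static because it determines an array shape, which must be known when the function is traced.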
<a name="books" />
Books
- JAX in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.
<a name="community" />
Community
- JaxLLM (Unofficial) Discord
- [JAX GitHub Discussions](https://github.com/google/jax/discussions)
Contributing
Contributions welcome! Read the contribution guidelines first.