JAX, M.D. End-to-End Differentiable, Hardware Accelerated, Molecular Dynamics in Pure Python

ORAL

Abstract

A large fraction of computational science involves simulating the dynamics of particles that interact via pairwise or many-body interactions. These simulations, called Molecular Dynamics (MD), span a vast range of subjects from physics to drug discovery. Most MD software involves significant use of handwritten derivatives and code duplication across C++, FORTRAN, and CUDA. This is reminiscent of the state of machine learning (ML) before automatic differentiation became popular. Here, we bring the substantial software advances that have taken place in ML to MD. JAX MD is an end-to-end differentiable MD package, written entirely in Python, that can be just-in-time compiled to CPU, GPU, or TPU. JAX MD lets researchers iterate extremely quickly and easily incorporate ML into their workflows. Moreover, since all of the simulation code is in Python, researchers have unprecedented flexibility in setting up experiments. JAX MD also allows researchers to take derivatives through whole simulations and to seamlessly incorporate neural networks into simulations. In this presentation we explore the architecture of JAX MD and its capabilities through several vignettes. Code is available at github.com/google/jax-md, along with a Colab notebook containing the experiments from the presentation.
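The abstract's central idea, obtaining forces by automatic differentiation of an energy function rather than from handwritten derivatives, can be sketched in plain JAX. This is an illustrative sketch only, not the JAX MD API itself (jax_md packages this pattern in its `space`, `energy`, and `simulate` modules); all function names below are our own.

```python
import jax
import jax.numpy as jnp

def lennard_jones(r, sigma=1.0, epsilon=1.0):
    # Lennard-Jones pair potential as a function of interparticle distance r.
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

def total_energy(positions):
    # Total potential energy: sum the pair potential over all distinct pairs.
    n = positions.shape[0]
    diffs = positions[:, None, :] - positions[None, :, :]
    dist2 = jnp.sum(diffs ** 2, axis=-1)
    mask = ~jnp.eye(n, dtype=bool)
    # Replace the zero self-distances before the sqrt so gradients stay finite.
    r = jnp.sqrt(jnp.where(mask, dist2, 1.0))
    pair_energy = jnp.where(mask, lennard_jones(r), 0.0)
    return 0.5 * jnp.sum(pair_energy)  # each pair was counted twice

# Forces come for free: F = -dE/dR, jit-compiled for CPU/GPU/TPU.
forces = jax.jit(jax.grad(lambda R: -total_energy(R)))

key = jax.random.PRNGKey(0)
R = jax.random.uniform(key, (8, 3), minval=0.0, maxval=3.0)
F = forces(R)  # (8, 3) array of per-particle forces
```

Because `total_energy` is an ordinary Python function of the positions, the same `jax.grad` machinery can be applied again to differentiate through an entire trajectory, which is the "derivatives through whole simulations" capability the abstract describes.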

Presenters

  • Sam Schoenholz

    Google Brain, Google Inc.

Authors

  • Sam Schoenholz

    Google Brain, Google Inc.

  • Ekin Dogus Cubuk

    Google Brain, Google Inc.