
Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning
By Aritra Roy Gosthipaty and Ritwik Raha (PyImageSearch)


In this tutorial, you will learn the basics of the JAX library, including how to install and use it to perform numerical computation and machine learning tasks using NumPy-like syntax and GPU acceleration.

This lesson is the 1st of a 3-part series on Learning JAX in 2023:

Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning (today’s tutorial)
Learning JAX in 2023: Part 2 — JAX’s Power Tools grad, jit, vmap, and pmap
Learning JAX in 2023: Part 3 — A Step-by-Step Guide to Training Your First Machine Learning Model with JAX

To learn how to get started with JAX, just keep reading.


🙌🏻 Introduction

As deep learning practitioners, it can be tough to keep up with all the new developments. New academic papers and models are always coming out; there’s a new framework to learn every few years. Recently, many people have been talking about JAX, a new numerical computing library that can make your code run faster.

Many people have asked us to create a course about JAX, so we decided to take on the challenge. In this series, we’ll not only teach you about JAX, but also how to learn and understand new concepts. We’ll keep the language simple and avoid using jargon, but if you need help understanding anything, please let us know, and we’ll do our best to help.

Once you complete this course, you’ll be able to understand and work with code written in JAX/Flax. Major companies like Google Research, Hugging Face, and OpenAI are already using JAX heavily, so this is a valuable skill to have. Let’s get started and learn all about it!

Configuring Your Development Environment

To follow this guide, you need to have the JAX library installed on your system. JAX is written in pure Python, but it depends on XLA, which needs to be installed as the jaxlib package (from: jax repository).

Luckily, jaxlib and jax are pip-installable. We also install NumPy and autograd, which we use for comparison later in this tutorial:

$ pip install jaxlib
$ pip install numpy
$ pip install autograd
$ pip install jax


Having Problems Configuring Your Development Environment?

Figure 1: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.

All that said, are you:

Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?

Then join PyImageSearch University today!

Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.

And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!

🤔 What Is JAX?

JAX is the combination of autograd and XLA. Before diving into the nitty-gritty of JAX, let us look into autograd and XLA briefly.

Note: The sections on autograd and XLA are meant to provide a more holistic understanding of the principles on which JAX was built. They are optional for getting started with JAX; feel free to skip ahead to “What Is JAX (revisited)?”

autograd

Gradients quite literally run the deep learning world. There are several ways to compute the gradients (derivatives) of a function:

Manual: We use our calculus knowledge to derive the derivatives by hand. The problem with this approach is that it is slow and error-prone; it would take a Deep Learning researcher a lot of time to derive a model’s derivatives by hand.
Symbolic: A program manipulates symbols to mimic the manual process. The problem with this approach is expression swell: the derivative of an expression can be exponentially longer than the expression itself (think of repeated applications of the chain rule), which quickly becomes difficult to track.
Numeric: We approximate the derivatives with the finite differences method. This is easy to implement, but it only approximates the derivative and is sensitive to the choice of step size.
Automatic: The star ⭐️ of the show.
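To make the numeric approach concrete, here is a small sketch of our own that approximates the derivative of $f(x) = x^2$ with a central finite difference and compares it with the derivative derived by hand, $2x$. The helper names and the step size `h` are illustrative choices: a large step adds truncation error, while a tiny step amplifies floating-point round-off.

```python
def square(x):
    # The function we want to differentiate: f(x) = x**2
    return x ** 2


def central_difference(f, x, h=1e-5):
    # Approximate f'(x) with the central finite difference
    # (f(x + h) - f(x - h)) / (2h); h is an illustrative step size
    return (f(x + h) - f(x - h)) / (2 * h)


for point in [1.0, 2.0, 3.0]:
    approx = central_difference(square, point)
    exact = 2 * point  # derivative derived by hand: f'(x) = 2x
    print(f"x = {point}: numeric = {approx:.6f}, exact = {exact}, "
          f"error = {abs(approx - exact):.2e}")
```

The result is close to, but not guaranteed to equal, the exact derivative; automatic differentiation avoids this approximation error.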

Automatic differentiation (autodiff) is the type of differentiation we all love and use when training our deep neural networks. In a previous tutorial, we covered the math and code behind autodiff. If you are interested in how it works under the hood, we urge you to read that tutorial first and then return to this one.
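To demystify what an autodiff engine does, here is a deliberately tiny reverse-mode sketch in plain Python, in the spirit of the micrograd implementation from our earlier series. The `Value` class and its methods are our own illustration, not autograd’s internals: each operation records its inputs and local derivatives, and `backward()` applies the chain rule in reverse topological order.

```python
class Value:
    """A scalar that records how it was computed so gradients can flow back."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # Values this one was computed from
        self._local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        # d(a + b)/da = 1 and d(a + b)/db = 1
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        # d(a * b)/da = b and d(a * b)/db = a
        return Value(self.data * other.data, (self, other),
                     (other.data, self.data))

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse
        order, visited = [], set()

        def visit(node):
            if id(node) not in visited:
                visited.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                # Accumulate: a node used several times gets several contributions
                parent.grad += local * node.grad


x = Value(4.0)
y = x * x + x * 3.0  # y = x**2 + 3x
y.backward()
print(y.data, x.grad)  # 28.0 11.0
```

Because `x` appears in several operations, its gradient is accumulated across all paths, giving $2x + 3 = 11$ at $x = 4$.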

autograd is a Python package that performs automatic differentiation on native Python and NumPy code. Its code base is fairly simple.

Two points to note here:

There is a light wrapper, autograd.numpy, around the native NumPy codebase. This allows users to write NumPy-like code while harnessing the power of automatic differentiation.
autograd.grad and autograd.elementwise_grad perform the actual automatic differentiation.

We will now look at a few code snippets to demonstrate how autograd works.

Let’s start by importing the necessary packages.

Note: autograd’s numpy module is imported as anp to distinguish from np (original NumPy) and jnp (jax numpy).

# Import the necessary packages
from autograd import numpy as anp
from autograd import grad
from autograd import elementwise_grad as egrad

We will first define a simple function and then compute the function’s gradient using automatic differentiation. Next, we find the function’s gradient at a particular point (scalar) and for a list of points (vector).

We build a function that takes a value and returns its square, $f(x) = x^2$.

def func_square(x):
    # Return the square of the input
    return x**2

# Build a scalar input and pass it to the
# square function
x = 4.0
squared_x = func_square(x=x)
print(f"x => {x}\nx**2 => {squared_x}")

>>> x => 4.0
>>> x**2 => 16.0

The function that we have defined is $f(x) = x^2$. We know that the function’s derivative is $f'(x) = 2x$.

We can achieve this derivative by applying the autograd.grad function. Let us see how that works.

# Compute the derivative of the square function
grad_func = grad(func_square)

point = 1.0
# Retrieve the gradient of the function at a particular point
print(f"Gradient of square func at {point} => {grad_func(point)}")

>>> Gradient of square func at 1.0 => 2.0

The above code snippet allows us to calculate the gradient of a scalar function. Next, let us see how to do the same with vectors.

# Let's pass a vector to the square function
vector = anp.arange(1, 10, dtype=anp.float32)
out_vector = func_square(vector)

# Iterate over the vector and its output
for v, o in zip(vector, out_vector):
    print(f"Value at point {v} => {o}")

>>> Value at point 1.0 => 1.0
>>> Value at point 2.0 => 4.0
>>> Value at point 3.0 => 9.0
>>> Value at point 4.0 => 16.0
>>> Value at point 5.0 => 25.0
>>> Value at point 6.0 => 36.0
>>> Value at point 7.0 => 49.0
>>> Value at point 8.0 => 64.0
>>> Value at point 9.0 => 81.0

What happens if we send the entire vector to our gradient function?

try:
    out_vector = grad_func(vector)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

>>> Type of exception => TypeError
>>> Exception => Grad only applies to real scalar-output functions. Try jacobian, elementwise_grad or holomorphic_grad.

Let us do what the exception suggests and use elementwise_grad to vectorize the gradient computation.

# Let us now vectorize the gradient code
egrad_func = egrad(func_square)

try:
    out_vector = egrad_func(vector)
    for v, o in zip(vector, out_vector):
        print(f"Grad at point {v} => {o}")
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

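With the elementwise_grad function, we can vectorize the gradient computation. Because func_square acts on each element independently, we get its derivative, $2x$, at every point: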

>>> Grad at point 1.0 => 2.0
>>> Grad at point 2.0 => 4.0
>>> Grad at point 3.0 => 6.0
>>> Grad at point 4.0 => 8.0
>>> Grad at point 5.0 => 10.0
>>> Grad at point 6.0 => 12.0
>>> Grad at point 7.0 => 14.0
>>> Grad at point 8.0 => 16.0
>>> Grad at point 9.0 => 18.0

Automatic differentiation is at the very heart of Deep Learning. Any framework that supports differentiable programming lets users learn patterns from data through backpropagation.

Learn more about automatic differentiation and differentiable programming in our blog post series:

Automatic Differentiation Part 1: Understanding the Math
Automatic Differentiation Part 2: Implementation Using Micrograd

XLA

It is safe to say that the fields of Deep Learning (DL) and Machine Learning (ML) are built on an enormous amount of linear algebra. From start to finish, most of the computation consists of linear algebra operations such as matrix multiplications.

What if we told you there is a compiler in town that can make Linear Algebra operations more efficient?

Enter XLA. XLA stands for Accelerated Linear Algebra. It is a domain-specific compiler for linear algebra that optimizes operations, for example by fusing several of them into a single kernel. The same program can be compiled for CPUs, GPUs, and TPUs without changing the code.
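In JAX you rarely call XLA directly; wrapping a function in `jax.jit` asks JAX to trace it and hand it to XLA for compilation on whatever backend is available. The sketch below assumes only that `jax` and NumPy are installed; `scaled_matmul` is a hypothetical example function, and the same code should run unchanged whether `jax.devices()` reports a CPU, GPU, or TPU.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scaled_matmul(a, b):
    # A small linear-algebra workload: matrix multiply, then scale and shift
    return 0.5 * (a @ b) + 1.0


# jax.jit traces the function (once per input shape/dtype) and compiles it with XLA
fast_scaled_matmul = jax.jit(scaled_matmul)

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

print("Available devices:", jax.devices())
result = fast_scaled_matmul(a, b)

# The compiled version agrees with a plain NumPy computation
expected = 0.5 * (np.asarray(a) @ np.asarray(b)) + 1.0
print(np.allclose(result, expected))
```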

👀 What Is JAX (revisited)?

Understanding what autograd and XLA do gives us a basic intuition about JAX.

JAX is a high-performance, numerical computing library incorporating composable function transformations.


That sounds intimidating, but let us break it down. From autograd, JAX inherits a NumPy-like API and an automatic differentiation engine, which make it a convenient numerical computing library.

The inclusion of the XLA compiler makes JAX a highly performant numerical computing library incorporating composable function transformations.

We will talk about what composable function transformation means in an upcoming blog post.
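As a small preview, the sketch below shows what “composable” means in practice: `jax.grad`, `jax.vmap`, and `jax.jit` each take a function and return a new function, so they can be stacked. The function `square` is our own example, mirroring the autograd func_square used earlier.

```python
import jax
import jax.numpy as jnp


def square(x):
    # f(x) = x**2, so f'(x) = 2x
    return x ** 2


# grad: turns f into a function that returns f'(x) for a scalar x
d_square = jax.grad(square)

# vmap: maps the scalar derivative over a batch of inputs
batched_d_square = jax.vmap(d_square)

# jit: compiles the composed function with XLA
fast_batched_d_square = jax.jit(batched_d_square)

print(d_square(3.0))                                # 6.0
print(fast_batched_d_square(jnp.arange(1.0, 4.0)))  # [2. 4. 6.]
```

Notice that vmap over grad plays the role that elementwise_grad played in autograd.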

⬇️ Import JAX

Let’s get to know JAX hands-on. We have already installed JAX on our system, so let’s import it to get started.

import jax

📚 Understanding the Components: API Layering of JAX

Before we start multiplying matrices and backpropagating through them, let us take a moment to understand the various components of JAX. When starting with a new library, it is always good practice to learn its basic API design.

This tutorial was written using JAX 0.3.25. The JAX API is layered: jax.numpy provides a high-level abstraction, and jax.lax provides a low-level one.

While jax.numpy mirrors the original NumPy package, jax.lax is a wrapper around Google’s XLA compiler.

Note: Did you notice that lax is an anagram of xla? 🤯

If you head over to the official JAX API documentation, you will see several sub-packages and sub-topics with their APIs listed.

The most used APIs are the following:

jax.numpy
jax.lax

The key function transformations in JAX’s API design are:

Just-in-time compilation (jit)
Automatic differentiation (grad)
Vectorization (vmap)
Parallelization (pmap)

We will discuss these topics and sub-packages with corresponding code snippets as we go through the tutorial. Let us first import them into our work environment.

import numpy as np

import jax
from jax import numpy as jnp
from jax import make_jaxpr
from jax import grad, jit, vmap, pmap

💯 Numerical Computation in JAX

This section will take us through the most used APIs of JAX: jax.numpy and jax.lax. Before diving in, we must note that JAX is not a Deep Learning (DL) framework. Instead, it is a numerical computation library; DL just happens to be one kind of numerical computation.

To make numerical computation easy, JAX provides a NumPy API that mirrors the API of yet another very powerful numerical computation library (yes, you guessed it, NumPy 😁).

What makes JAX stand out is jax.lax, its wrapper around the XLA compiler. The jax.numpy functions are themselves implemented on top of jax.lax primitives. This makes JAX code not only device agnostic but also JIT-compilable.
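One way to see that jax.numpy sits on top of jax.lax is to inspect the intermediate representation (a jaxpr) that JAX records when it traces a function, using the make_jaxpr function imported earlier. In the sketch below, `scaled_sum` is our own example; the exact printout varies between JAX versions, but it should list jax.lax primitives such as `mul` and `reduce_sum` rather than NumPy calls.

```python
import jax.numpy as jnp
from jax import make_jaxpr


def scaled_sum(x):
    # Written with the NumPy-like API...
    return jnp.sum(x * 2.0)


# ...but traced into jax.lax primitives
jaxpr = make_jaxpr(scaled_sum)(jnp.arange(3.0))
print(jaxpr)
```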

Being device agnostic means that the same code can run on different hardware (CPUs, GPUs, and TPUs). With JIT compilation, the same code can often run much faster and more efficiently. This is why JAX is often called NumPy on steroids.

jax.numpy

In this section, we learn how to write NumPy-like code using jax.numpy.

# Build an array of 0 to 9 with the `jax.numpy` API
array = jnp.arange(0, 10, dtype=jnp.int8)
print(f"array => {array}")

>>> array => [0 1 2 3 4 5 6 7 8 9]

This is a great thing to have. Anyone with a fair amount of NumPy knowledge does not need to learn much that is new: we can often port NumPy programs to JAX just by adding the extra j (np → jnp). Thanks to Python’s duck typing, jax.numpy can act as a drop-in replacement for a lot of NumPy code, with a few important exceptions that we look at next.
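Here is a minimal sketch of that duck typing, assuming both NumPy and JAX are installed: the same function body runs against either module because `np` and `jnp` expose the same names. `normalize` and its `xp` parameter are hypothetical names used only for illustration.

```python
import numpy as np
import jax.numpy as jnp


def normalize(xp, x):
    # xp is either the numpy or the jax.numpy module; the code is identical
    return (x - xp.mean(x)) / xp.std(x)


data = [1.0, 2.0, 3.0, 4.0]
np_result = normalize(np, np.asarray(data, dtype=np.float32))
jnp_result = normalize(jnp, jnp.asarray(data, dtype=jnp.float32))

print(np_result)
print(jnp_result)
print(np.allclose(np_result, jnp_result))  # True
```

This works only while the code avoids the NumPy features that JAX handles differently, such as in-place assignment.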

Let’s now look into some differences. The first one is the type of the array object that jax.numpy returns.

print(type(array))

>>> <class 'jaxlib.xla_extension.DeviceArray'>

The DeviceArray is the JAX equivalent of numpy.ndarray. However, the two are not exactly the same. For example, a DeviceArray can live on a CPU, GPU, or TPU, while a numpy.ndarray always lives in host (CPU) memory.
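The sketch below shows how to move data between the two worlds. It assumes a recent JAX release: starting with JAX 0.4.1, DeviceArray was replaced by the unified `jax.Array` type, so `type(...)` prints a different class name than the 0.3.25 output shown earlier. `np.asarray` copies a JAX array back to host memory as a regular `numpy.ndarray`, and `jax.device_put` places a NumPy array on the default device.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax_array = jnp.arange(5)

# In recent JAX releases, every JAX array is an instance of jax.Array
print(isinstance(jax_array, jax.Array))  # True

# Copy the data back to host memory as a plain NumPy array
host_array = np.asarray(jax_array)
print(type(host_array))  # <class 'numpy.ndarray'>

# Place a NumPy array on JAX's default device (CPU, GPU, or TPU)
device_array = jax.device_put(np.ones(3, dtype=np.float32))
print(device_array.dtype, jax.devices()[0].platform)
```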

We have seen that the jax.numpy wrapper mirrors the NumPy Python library well, but a few stark differences remain between the two.

A major one is that DeviceArrays are immutable, unlike numpy.ndarrays. We illustrate this using the following code snippet.

jax_array = jnp.arange(1, 10, dtype=jnp.int8)
numpy_array = np.arange(1, 10).astype(np.int8)

# In-place assignment works for a NumPy array...
try:
    numpy_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

# ...but fails for a JAX array
try:
    jax_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

>>> Type of exception => TypeError
>>> Exception => '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

To work around this, JAX provides the .at[].set() syntax.

Note: The update does not happen in place. Instead, it returns a new DeviceArray with the requested change and leaves the original untouched. This is only true outside JIT; inside JIT-compiled code, the compiler can in most cases perform the update in place (from: GitHub Discussion).

This design decision was taken to make functions pure (we will discuss pure functions in an upcoming blog post).

# .at[].set() returns a new array; jax_array itself is not modified
mutated_jax_array = jax_array.at[2].set(200)
print(f"Original Array => {jax_array}")
print(f"Mutated Array => {mutated_jax_array}")

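>>> Original Array => [1 2 3 4 5 6 7 8 9]
>>> Mutated Array => [ 1 2 -56 4 5 6 7 8 9]

Why -56 instead of 200? The array has dtype int8, which can only hold values from -128 to 127, so in the JAX version used here 200 wraps around to 200 - 256 = -56. Use a wider dtype, such as int32, if you need to store larger values.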
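Besides `.set()`, the `.at[]` property supports other functional updates such as `.add()` and `.multiply()`, and it accepts slices as well as single indices. The sketch below uses an `int32` array (our choice, so the values fit comfortably) and confirms that the original array is left untouched.

```python
import jax.numpy as jnp

x = jnp.arange(1, 6, dtype=jnp.int32)  # [1 2 3 4 5]

y = x.at[0].set(100)       # replace one element
z = x.at[1:3].add(10)      # add 10 to a slice
w = x.at[-1].multiply(2)   # multiply the last element by 2

print(x)  # [1 2 3 4 5]  (unchanged)
print(y)  # [100   2   3   4   5]
print(z)  # [ 1 12 13  4  5]
print(w)  # [ 1  2  3  4 10]
```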

Another key difference is how JAX handles out-of-bounds indexing.

try:
    print("Indexing 1000th position of a NumPy array…")
    print(numpy_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)

>>> Indexing 1000th position of a NumPy array…
>>> IndexError
>>> index 1000 is out of bounds for axis 0 with size 9

try:
    print("Indexing 1000th position of a JAX array…")
    print(jax_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)

>>> Indexing 1000th position of a JAX array…
>>> 9

👀 What happened here?

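In JAX, out-of-bounds indices are clamped when reading: an index past the end of the array is clamped to the last valid position, so jax_array[1000] returns the last element, 9, without raising an error. Unlike NumPy, JAX does not raise here, because raising an error from code running on an accelerator can be difficult or impossible. This is a caveat we need to watch out for, since it can make our code fail silently.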
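The sketch below contrasts the default behaviour with an explicit alternative, following the JAX “Sharp Bits” documentation: out-of-bounds reads are clamped, out-of-bounds updates are dropped, and `.at[idx].get(mode="fill", fill_value=...)` lets you choose a sentinel value instead. The sentinel `-1` is an arbitrary choice, and since this behaviour has evolved across JAX versions, treat it as a sketch to verify against your installed release.

```python
import jax.numpy as jnp

x = jnp.arange(1, 10)  # [1 2 3 4 5 6 7 8 9]

# Reading out of bounds: the index is clamped to the last valid position
print(x[1000])  # 9

# Updating out of bounds: the update is silently dropped
print(x.at[1000].set(0))  # [1 2 3 4 5 6 7 8 9]

# Reading with an explicit fill value makes the out-of-bounds case visible
print(x.at[1000].get(mode="fill", fill_value=-1))  # -1
```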

jax.lax

Let’s talk a little bit about jax.lax now. While the NumPy API makes it easier for you to enter the world of JAX, jax.lax is what powers all of the library’s functionality.

jax.lax is a library of primitive operations that underpins libraries such as jax.numpy.


While jax.numpy is a high-level abstraction that makes it easier to write code, jax.lax is lower level: it is more powerful but also much stricter.

For example, jax.lax does not perform automatic type promotion. This is demonstrated using the following code snippets.

# Checking the lenient `jax.numpy` API
try:
    print(jnp.add(1, 2.0))
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

>>> 3.0

# Checking the stricter `jax.lax` API 😭
try:
    jax.lax.add(1, 2.0)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

>>> Type of exception => TypeError
>>> Exception => lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).


Summary

Great job on completing the first part of the tutorial on JAX! In this tutorial, we covered the background and origins of JAX, highlighting its parent libraries, autograd and XLA. We also explored the API layering of JAX and looked in detail at its two most commonly used APIs: jax.numpy and jax.lax.

Now, it’s time to move on to the next part of the tutorial, which will focus on functional transformations in JAX. These transformations, such as grad, jit, vmap, and pmap, are essential tools in the JAX toolkit and allow you to optimize your code for better performance and efficiency.

Finally, in the third and final part of the tutorial, we will put everything we’ve learned to the test by training a model from scratch using JAX. This will be a great opportunity to apply the concepts and techniques covered in the first two parts of the tutorial and see the power of JAX in action. So buckle up because the next tutorial will be an exciting and resource-filled adventure!

We would like to acknowledge the detailed review and discussion from Jake Vanderplas.

References

What is Automatic Differentiation?
You don’t know JAX
Why You Should (or Shouldn’t) be Using Google’s JAX in 2023
The Sharp Bits 🔪 — JAX documentation
Training a Simple Neural Network, with tensorflow/datasets Data Loading
jax/README.md at main · google/jax · GitHub
GitHub – google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
JAX Crash Course – Accelerating Machine Learning code!

Citation Information

A. R. Gosthipaty and R. Raha. “Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning,” PyImageSearch, P. Chugh, S. Huot, K. Kidriavsteva, and A. Thanki, eds., 2023, https://pyimg.co/uwe1j

@incollection{ARG-RR_2023_JAX1,
  author = {Aritra Roy Gosthipaty and Ritwik Raha},
  title = {Learning {JAX} in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning},
  booktitle = {PyImageSearch},
  editor = {Puneet Chugh and Susan Huot and Kseniia Kidriavsteva and Abhishek Thanki},
  year = {2023},
  url = {https://pyimg.co/uwe1j},
}


We used Jarvislabs.ai, a GPU cloud, for all the experiments.