# JAX Crash Course - Accelerating Machine Learning code!

## Metadata

- **Channel:** AssemblyAI
- **YouTube:** https://www.youtube.com/watch?v=juo5G3t4qAo
- **Date:** 25.06.2022
- **Duration:** 26:38
- **Views:** 14,248
- **Source:** https://ekstraktznaniy.ru/video/13072

## Description

Learn how to get started with JAX in this Crash Course. JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Get your Free Token for AssemblyAI Speech-To-Text API 👇 https://www.assemblyai.com/?utm_source=youtube&utm_medium=referral&utm_campaign=yt_pat_43

Colab: https://colab.research.google.com/drive/1_-e5MUrGfS7r1veKiVADuiLI_9e1UAih?usp=sharing
Website: https://jax.readthedocs.io/en/latest/index.html
JAX blog post: https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/

▬▬▬▬▬▬▬▬▬▬▬▬ CONNECT ▬▬▬▬▬▬▬▬▬▬▬▬

🖥️ Website: https://www.assemblyai.com
🐦 Twitter: https://twitter.com/AssemblyAI
🦾 Discord: https://discord.gg/Cd8MyVJAXd
▶️  Subscribe: https://www.youtube.com/c/AssemblyAI?sub_confirmation=1
🔥 We're hiring! Check our open roles: https://www.assemblyai.com/careers

▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬▬

00:00 Intro & Outline
01:22 What is JAX
02:55 Speed comparison
05:00 Drop-in Replacement 

## Transcript

### Intro & Outline [0:00]

Welcome to this JAX crash course. If you do anything with deep learning or scientific computing in Python, you should definitely learn about JAX, because it can significantly speed up your code and has lots of other great features. Just look at this first example, where we compare normal NumPy with JAX NumPy: if we look at the execution times, JAX is over 100 times faster in this example. So this is super cool. Today I show you what JAX is, then we go over all the essential things we can do with it, and in the end we also talk about why we cannot use JAX everywhere and where we have to be careful with it. So let's get started.

All right, first let's have a quick look at what we discuss today. First we look at what JAX is, then we do a little speed comparison, then we learn how we can use it as a replacement for NumPy, and then we look at the four essential functions: we learn how to speed up computations with jit, how to do automatic differentiation with grad, automatic vectorization with vmap, and automatic parallelization with pmap. I also show you a short example training loop, so how this could look in JAX. Then we discuss what the catch is, so where we have to be careful and why we cannot use it everywhere, and we end with a short summary. So first of all, what is JAX?

### What is JAX [1:22]

JAX is developed by Google, and it is Autograd and XLA brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python and NumPy programs, so with it we can differentiate, vectorize, parallelize, and do just-in-time compilation. In simpler words, JAX is NumPy on steroids, or NumPy on the CPU, GPU, and TPU, but it's actually much more than that. For example, we also get automatic differentiation, and this is super helpful if you do anything with deep learning or scientific computing.

Now let's have a quick look at what XLA is. If we follow the link, we actually land on the TensorFlow website (and by the way, I put the link to this Google Colab in the description, so you find it below the video). XLA stands for Accelerated Linear Algebra, and it is the foundation that makes JAX so powerful. It's also developed by Google, and it is a domain-specific, graph-based, just-in-time compiler for linear algebra. You don't have to understand what this means in detail; you just have to know that it is the underlying foundation, and that it significantly improves execution speed and lowers memory usage by fusing low-level operations. We take a more detailed look at the just-in-time compiler in a few moments.
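
Because the transformations are composable, they can be stacked freely. Here is a minimal sketch that assumes a recent JAX install. The function `f` is a hypothetical example, not code from the video. `grad` builds the derivative, `vmap` maps it over a batch, and `jit` compiles the composed result with XLA.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function: f(x) = x**3
    return x ** 3


df = jax.grad(f)                       # derivative: f'(x) = 3 * x**2
batched_df = jax.vmap(df)              # apply df to every element of a batch
fast_batched_df = jax.jit(batched_df)  # compile the composed function with XLA

xs = jnp.array([1.0, 2.0, 3.0])
print(fast_batched_df(xs))             # values 3, 12 and 27, i.e. 3 * x**2
```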

### Speed comparison [2:55]

Now, first of all, let's do a simple speed comparison. Here we have simple NumPy code: a function that does this computation, and then we pass in a NumPy array and time it. NumPy doesn't have GPU support, so this runs on the CPU, and in a few moments we should see how long it takes. All right, here we have it: this is the time of the fastest run. Now let's look at the same code in JAX NumPy. JAX NumPy is JAX's implementation of NumPy, and its API is almost identical. We can also use jit, the just-in-time compiler, to transform the function so that it is just-in-time compiled, and we convert the NumPy array into a JAX NumPy array. If we run the same code with JAX, it is significantly faster. In this case I'm also running on the CPU, but this alone already brings much better performance. By the way, we call `block_until_ready` because JAX uses asynchronous execution, so be aware of this.

Now let's switch the Google Colab runtime to a GPU, run this again, and look at the performance with GPU support. All right, now we get an execution time of only 3.74 milliseconds, which is a significant performance boost. This time NumPy was also a little bit faster, but JAX is still almost a hundred times faster, which is huge.
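
Here is a minimal sketch of this kind of benchmark. It assumes NumPy and JAX are installed. The function body is hypothetical, because the video's exact computation isn't reproduced here. Timings depend entirely on the hardware, so none are quoted. The warm-up call keeps compilation time out of the measurement. `block_until_ready()` waits for JAX's asynchronous dispatch to finish before the clock stops.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def f_np(x):
    # Hypothetical element-wise computation (a stand-in for the video's function)
    return np.mean(np.tanh(x) * x ** 2 + np.sin(x))


def f_jnp(x):
    # The same computation written against jax.numpy (near-identical API)
    return jnp.mean(jnp.tanh(x) * x ** 2 + jnp.sin(x))


f_jit = jax.jit(f_jnp)  # just-in-time compile with XLA

x_np = np.random.default_rng(0).standard_normal(1_000_000).astype(np.float32)
x_jnp = jnp.asarray(x_np)  # put the data on JAX's default device (CPU, GPU or TPU)

f_jit(x_jnp).block_until_ready()  # warm-up: the first call traces and compiles


def bench(fn, arg, repeats=10):
    """Average wall-clock seconds per call."""
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(arg)
        # JAX dispatches asynchronously, so wait for the result before stopping the clock
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
    return (time.perf_counter() - start) / repeats


print(f"NumPy:     {bench(f_np, x_np) * 1e3:.2f} ms")
print(f"JAX + jit: {bench(f_jit, x_jnp) * 1e3:.2f} ms")
```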

### Drop-in Replacement for NumPy [5:00]

So now let's learn how we can use JAX NumPy as a drop-in replacement for NumPy. One use case is exactly that: use JAX as a drop-in replacement for NumPy to get GPU support. This often works without problems because the API is almost identical. Here we have simple NumPy code that uses a few functions: `linspace`, then NumPy's sine and cosine, and then we plot the result with Matplotlib. Now the same code with JAX NumPy. We import `jax.numpy as jnp`, and the functions are the very same, so the same API, except that we now call `linspace`, `sin`, and `cos` from JAX NumPy. If we run this, we get the same plot, so it produces the same values. This is how we can use it as a drop-in replacement, and this alone can give you a great performance boost, because we then get GPU and TPU support.

The one big difference to be aware of is that JAX arrays are immutable; we talk more about why this is the case in a moment. If we create a NumPy array and then change one of its values, this works. But if we do the same with a JAX NumPy array and run it, we get a TypeError saying that the object does not support item assignment. So JAX arrays are immutable; be aware of this.
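
Here is a minimal sketch of the drop-in idea and of the immutability pitfall. The `.at[...].set(...)` update is JAX's functional alternative to in-place assignment. It isn't shown in the video. The array sizes are arbitrary choices.

```python
import numpy as np
import jax.numpy as jnp

# Drop-in replacement: the same calls exist in numpy and jax.numpy
x_np = np.linspace(0, 10, 1000)
sin_np, cos_np = np.sin(x_np), np.cos(x_np)

x_jnp = jnp.linspace(0, 10, 1000)
sin_jnp, cos_jnp = jnp.sin(x_jnp), jnp.cos(x_jnp)
# (plotting with matplotlib accepts both kinds of arrays and is omitted here)

# NumPy arrays are mutable ...
a_np = np.arange(10)
a_np[0] = 100

# ... JAX arrays are not: item assignment raises a TypeError
a_jnp = jnp.arange(10)
try:
    a_jnp[0] = 100
except TypeError as err:
    print("TypeError:", err)

# Functional update: returns a new array and leaves a_jnp unchanged
b_jnp = a_jnp.at[0].set(100)
print(int(a_jnp[0]), int(b_jnp[0]))  # 0 100
```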

### jit(): just-in-time compiler [6:56]

And now the next thing we look at is jit, so how we can speed up our computations with jit. Jit, or just-in-time compilation, is a method of executing code that lies between interpretation and ahead-of-time compilation. Ahead-of-time compilation is, for example, what we get with C++. With just-in-time compilation, the code is compiled when we run it the first time, so at runtime it is compiled into a fast executable. The first call can therefore be a little slower, but after that the compiled code is cached and is much faster than normal Python code.

It works by using the `jit` function, which we import from JAX; I already did this up here with `from jax import jit`. We apply it to a function, so this is also known as a function transformation. Here we have a normal function that does some computation. I time it with a normal NumPy array, then with a JAX NumPy array. Then we apply `jit` to this function, call the result `jitted`, call it like a normal function with the JAX NumPy array, and compare. If we run this and wait a few moments, we see the results. This is the normal NumPy execution, this is with a JAX NumPy array on the GPU, and this is the jitted version, which is the fastest of all. This is how we get this huge performance boost. We can use jit this way, by passing the function as an argument, or as a decorator: we decorate a function with `@jit`, which has the same effect. Again, if we run this, we get a very fast execution time.

Now let's look at how this works. jit and the other JAX transforms work by tracing a function to determine its effect on inputs of a specific shape and type. This tracing is very important, so let's look at an example. Here we have a function decorated with `@jit`. Inside it we have a few print statements, then an operation, a JAX NumPy dot product, and then we return the result. We create two NumPy arrays, run this cell, and call the function. Notice that the print statements don't print the data; instead, they print the tracer objects that stand in for it. When we run the function again, we no longer see these tracer objects, only the output, which is the same result as in the first run. We only need the tracers the very first time. These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basically, tracers are stand-ins that encode the shape and data type of the arrays, but they are agnostic to the values. This is also very important, and I will come back to it in a moment. The recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and data type without re-executing the Python code, which is why we no longer see these objects in the second run. So this is roughly how jit works.
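
Here is a sketch of both ways to apply jit and of what tracing looks like. `g` is a hypothetical stand-in for the timed function. `f` mirrors the print-and-dot-product example described above. The global `trace_count` is a deliberate impure side effect, used only to make tracing visible. It increments only while JAX traces `f`. The cached second call leaves it at 1, and a new input shape forces a retrace that brings it to 2.

```python
import numpy as np
import jax
import jax.numpy as jnp


def g(x):
    # Hypothetical computation standing in for the function timed in the video
    return jnp.sum(x ** 2 + jnp.sin(x))


g_jitted = jax.jit(g)  # style 1: jit applied as a function transformation

trace_count = 0  # illustration only: counts how often f is traced


@jax.jit  # style 2: jit applied as a decorator
def f(x, y):
    global trace_count
    trace_count += 1          # Python side effect: runs only while JAX traces f
    print("Running f():")
    print(f"  x = {x}")       # prints a tracer (shape/dtype), not the values
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4)).astype(np.float32)
y = rng.standard_normal(4).astype(np.float32)

out1 = f(x, y)      # first call: traces (prints tracers) and compiles
out2 = f(x, y)      # same shapes and dtypes: cached executable, nothing printed
out3 = f(x[:2], y)  # new input shape: f is traced and compiled again
```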

### Limitations of JIT [11:32]

So now let's have a quick look at a few limitations of jit. This is also very important, because it's why we cannot simply use jit everywhere and speed up everything in our code. This sentence is key: "Because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values." Like I said before, tracing is agnostic to the values. So a function like this one cannot be jitted. It has a control flow statement, an if statement that depends on the value of x: if x is greater than zero we return this, otherwise that. If we apply jit to this function and run it, we get a very long and mysterious error: `ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected`. This might sound confusing, but basically you just have to know that you can't use control flow statements that depend on the value, so be very careful with this.

To make jit work, you have to implement a so-called pure function. For this I again recommend reading the part of the JAX documentation about pure functions, but basically it means the following. First, you can't have control flow statements that depend on the values, like I just showed you. Second, the function cannot use or change global state, so it only uses variables inside the function and nothing outside of its scope: we can't access global variables, and we can't change them. Third, there can be no I/O stream, so no printing, asking for input, or accessing the time. And fourth, it cannot have a mutable function as an argument. These points are very important to keep in mind if you want to apply the jit just-in-time compiler. Be very careful to implement pure functions, because otherwise you can get untracked side effects: they can change the results of your computation without you knowing, and that can be super dangerous. So make sure you understand how jit works and how to implement pure functions. Once you know this, jit is super powerful for speeding up your code, like I showed you before. But again, be aware of these side effects.
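
Here is a sketch of the failure and two common workarounds, `jnp.where` and `jax.lax.cond`. Neither is shown in the video, so treat them as one possible approach. The last part illustrates the untracked-side-effect problem with a hypothetical global `scale`. Its value is captured when the function is traced, so changing it later silently has no effect.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Python `if` on the *value* of x: fine in plain Python, not under jit
    if x > 0:
        return x
    else:
        return 2 * x


jit_failed = False
try:
    jax.jit(f)(10.0)
except TypeError as err:
    # JAX raises a ConcretizationTypeError (or a subclass, depending on the
    # version); JAX's type errors derive from TypeError
    jit_failed = True
    print(type(err).__name__)


@jax.jit
def f_where(x):
    # Option 1: compute both branches and select element-wise
    return jnp.where(x > 0, x, 2 * x)


@jax.jit
def f_cond(x):
    # Option 2: structured control flow that JAX can trace
    return lax.cond(x > 0, lambda v: v, lambda v: 2 * v, x)


# Impure pitfall: a global read inside a jitted function is captured at trace time
scale = 1.0  # hypothetical global


@jax.jit
def impure(x):
    return x * scale


before = impure(2.0)  # 2.0
scale = 10.0
after = impure(2.0)   # still 2.0: the cached trace never sees the new global
```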

### grad(): Automatic Gradients [14:35]

Another great feature of JAX is automatic differentiation with grad. Most people only know JAX as NumPy on the GPU or for its just-in-time compiler, but this is another very powerful feature, especially for scientific computing or deep learning. With it we can very simply calculate gradients, or second- and third-order gradients, and so on. The difference between JAX's grad and typical deep learning libraries is that grad follows the underlying math more closely. It transforms a function into a new function that computes its derivative, instead of attaching backpropagation to tensor objects, and with JAX this can be much faster.

The grad function works with scalar-valued functions, and again it is a function transformation. A scalar-valued function is a function that maps scalars or vectors to scalars. Let's look at an example. We write `from jax import grad`, and here we have a function that does some computation. It takes an input, which can be, for example, a NumPy array or a JAX NumPy array. Applying grad to the function gives us the gradient of f with respect to x, and we can apply it multiple times, so this gives the second derivative of f with respect to x. We can do it like this, or simply reuse the result from the first step; both give the same. Then I execute this and print the results: this is the original function, this is the first derivative, and this is the second-order derivative. If you compare this and do the math yourself, you see these are the correct results. So this is how simply you can calculate gradients with the grad function.

Here is another, slightly longer function with an if statement and a for loop; it computes a rectified cube. Notice that the if statement here depends on the value. This is different from before with jit: if we applied jit to this function, it would not work, but for the gradient this is okay, and the result is still correct. If we calculate the gradient and again compare with the math, it is correct.

In this example, like I said, grad calculated the gradient with respect to x. If a function has multiple parameters, by default grad calculates the gradient with respect to the first parameter, so if we simply call grad on this function, it is with respect to x. If we want it with respect to y instead, we pass the `argnums` argument and set it to 1, since indices start at zero. We could also pass multiple indices, and then it computes the gradients with respect to both x and y. If we run this, we see the results.

So this is how we apply it to scalar-valued functions, but we can also differentiate vector-valued functions and calculate the Jacobian and the Hessian. Vector-valued means the function maps vectors to vectors, and the analog of the gradient is the Jacobian. For this we can use two transformations, `jacfwd` or `jacrev`, which correspond to forward-mode and reverse-mode differentiation. Both should give you the same result; they are just two different ways of implementing it, and sometimes one can be faster than the other. I won't go into more detail here, so have a look at the official documentation for that. The way it works is that we import these from JAX. Here we have a function that maps a vector to another vector; we apply our Jacobian transformation to the function, create an input, and call the Jacobian. If we run this and compare the result to the actual math, you see that it is correct. The way to get a Hessian is to first call grad and then apply the Jacobian to the result. Here again we have an example function: we apply our Hessian to the function, give it an input, run it, and compare with the math to see that it is correct. So this is how we can calculate derivatives for scalar-valued and vector-valued functions.
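
Here is a sketch that collects the derivative examples described above. All functions here (`f`, `rectified_cube`, `g`, `h`, `s`) are hypothetical stand-ins. They are chosen so that the results are easy to check by hand, and the comments give the expected values. The Hessian is built the way the transcript describes it, as a Jacobian of the gradient. `jax.hessian` should give the same result.

```python
import jax
import jax.numpy as jnp


# Scalar-valued function (hypothetical): f(x) = x^3 + 2x^2 - 3x + 1
def f(x):
    return x ** 3 + 2 * x ** 2 - 3 * x + 1


dfdx = jax.grad(f)        # f'(x)  = 3x^2 + 4x - 3
d2fdx = jax.grad(dfdx)    # f''(x) = 6x + 4  (grad applied twice)


def rectified_cube(x):
    # Python `if`/`for` on the value of x work under grad (not under jit)
    if x < 0.0:
        return 0.0 * x
    r = 1.0
    for _ in range(3):
        r = r * x
    return r               # x**3 for x >= 0


def g(x, y):
    return x ** 2 * y + 3.0 * y


dg_dx = jax.grad(g)                    # default argnums=0: w.r.t. x -> 2xy
dg_dy = jax.grad(g, argnums=1)         # w.r.t. y -> x^2 + 3
dg_both = jax.grad(g, argnums=(0, 1))  # tuple with both partial derivatives


# Vector-valued function R^2 -> R^2, Jacobian in forward and reverse mode
def h(v):
    return jnp.array([v[0] ** 2 * v[1], 5.0 * v[0] + jnp.sin(v[1])])


J_fwd = jax.jacfwd(h)
J_rev = jax.jacrev(h)


# Hessian of a scalar function: Jacobian of its gradient
def s(v):
    return v[0] ** 2 * v[1] + v[1] ** 3


H = jax.jacfwd(jax.grad(s))

v0 = jnp.array([1.0, 2.0])
print(f(1.0), dfdx(1.0), d2fdx(1.0))     # values 1, 4, 10
print(jax.grad(rectified_cube)(2.0))     # 12 (= 3 * 2^2)
print(dg_dx(2.0, 3.0), dg_dy(2.0, 3.0))  # 12, 7
print(J_fwd(v0))                         # [[4, 1], [5, cos(2)]]
print(H(v0))                             # [[4, 2], [2, 12]]
```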

### vmap(): Automatic Vectorization [19:57]

Let's have a look at another great feature of JAX: automatic vectorization with vmap, which stands for vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function's primitive operations for better performance. Let's look at an actual example to make this clearer. Here we have a function that convolves a vector x with weights w. It uses a for loop over the elements of x, takes the dot product of a slice of x with the weights w, and appends the result to the output. Now suppose we would like to apply this function to a batch of weights and a batch of vectors, so we simply stack a couple of vectors. The manual way is a simple loop over the batch: we loop over the batch size, access each element, and apply the convolve function to it. This works, but with the vmap function it's much simpler and also much faster: we simply apply vmap to the convolve function and call it. If we run this and print both results, the first is from the manual loop and the second is from vmap. We get the same result, but with vmap the code is much shorter and much faster. So this is another very useful feature.
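
Here is a sketch of the convolve example described above, modeled on the one in the JAX documentation. The exact code in the video may differ. It shows the manual batching loop and the `vmap` version. It also uses `in_axes=(0, None)`, which goes beyond the video, to batch only `x` while sharing one `w`.

```python
import jax
import jax.numpy as jnp


def convolve(x, w):
    # Sliding-window dot product of x with a length-3 kernel w, via a Python loop
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)


x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])
print(convolve(x, w))  # values 11, 20, 29

# A batch of vectors and a batch of weights
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])


def manually_batched_convolve(xs, ws):
    # The manual way: loop over the batch dimension in Python
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)


# vmap pushes the batch loop down into the primitive operations
auto_batch_convolve = jax.vmap(convolve)

# in_axes=(0, None): batch over x, share the same w for every example
batch_x_only = jax.vmap(convolve, in_axes=(0, None))
```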

### pmap(): Automatic Parallelization [21:48]

I also want to quickly mention another very useful feature: automatic parallelization with pmap. With it we can, for example, do distributed training and distribute our computations across multiple TPUs, which makes them much faster. I won't go into more detail here; I just want to mention that this is possible in JAX by simply applying the pmap function. Again, have a look at the documentation for the details, and just know that this feature exists.
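
Here is a minimal pmap sketch that assumes only that JAX is installed. A CPU-only machine usually exposes a single device, so the code runs there but does not actually parallelize. On a multi-GPU or TPU host, the leading axis is split across devices. `normalize` is a hypothetical per-device function, and `psum` shows a cross-device collective.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # typically 1 on a plain CPU runtime


def normalize(x):
    # Hypothetical per-device computation on that device's slice of the data
    return (x - x.mean()) / (x.std() + 1e-6)


# pmap maps over the leading axis, whose size must not exceed the device count
xs = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
out = jax.pmap(normalize)(xs)

# Collectives use a named axis; psum sums a value across all devices
total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="devices"),
                 axis_name="devices")(jnp.ones(n_devices))
```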

### Example Training Loop [22:18]

These were all the essential JAX features I wanted to show you, and the last code I want to show you is a simple example training loop with JAX. Here we implement a simple linear regression model. We create some sample inputs for x and y, and if we plot them, this is how our data looks. Now let's approximate it with a linear regression model. For this we create a model function that computes the weight w times x plus a bias b. Then we define a loss function. For linear regression we want the mean squared error, so we first apply the model to get the prediction and then call `jnp.mean` on the prediction minus the actual value, squared. We also define an update function, which is the important part. To update the weights we apply gradient descent: we start with theta and move in the direction opposite to the gradient, and the learning rate is our step size. To get the gradient, as we learned before, we simply apply grad to the loss function. To do the training, we initialize the values of w and b, then call the update step a hundred times, and we get the values for theta, so for w and b. If we run this (after running the previous cell first), the training might take a few moments for the 1000 steps, and then we get the result: for w we get 3 and for b we get -1. If we plot the approximated model, it closely matches the data, so this works. This is how we can do a simple training loop with JAX.
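
Here is a sketch of the training loop described above. The sample data is hypothetical. It is generated from y = 3x - 1 plus noise, so the learned values should land near the w = 3 and b = -1 reported in the video. The initial guess, the learning rate of 0.1, and the 1000 steps are assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical sample data: y = 3x - 1 plus a little noise
rng = np.random.default_rng(0)
xs = rng.normal(size=100).astype(np.float32)
ys = (3.0 * xs - 1.0 + 0.1 * rng.normal(size=100)).astype(np.float32)


def model(theta, x):
    # Linear model: weight w times x plus bias b
    w, b = theta
    return w * x + b


def loss_fn(theta, x, y):
    # Mean squared error between prediction and target
    prediction = model(theta, x)
    return jnp.mean((prediction - y) ** 2)


@jax.jit
def update(theta, x, y, lr=0.1):
    # One gradient-descent step: move against the gradient with step size lr
    return theta - lr * jax.grad(loss_fn)(theta, x, y)


theta = jnp.array([1.0, 1.0])  # initial guess for [w, b]
for _ in range(1000):
    theta = update(theta, xs, ys)

w, b = theta
print(f"w = {float(w):.2f}, b = {float(b):.2f}")  # close to w = 3, b = -1
```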

### What’s the catch? [24:38]

This all sounds super exciting, but now let's briefly talk about the catch. I already touched on this with the just-in-time compiler: we need to follow the functional paradigm, which means implementing pure functions and understanding what that means, because otherwise we can get untracked side effects, and this is super dangerous. So be very careful if you want to apply the jit transform. We also have to be aware that we cannot modify arrays in place, and with JAX we need explicit random number handling. For this I also recommend reading this article in the documentation, but basically we cannot use a stateful pseudo-random number generator like in NumPy, and we have to be very careful when we want to reproduce our results. So if you care about random numbers, make sure to read this article as well.

So overall, what's the conclusion? I think JAX is a very promising project, and it has been steadily growing in popularity despite the learning curve. It's still considered an early, experimental framework, but I'm very excited to see what's coming in the future. If you just want to use it as a drop-in replacement for NumPy, it shouldn't be too hard, and you don't have to worry about a lot of things; this works most of the time without problems. The autograd feature is very useful for scientific computing and also for deep learning, and you can get a huge performance improvement. But of course, as I said, you have to understand pure functions and how jit works so that you don't get side effects. Overall, I think it's really cool. Let me know in the comments if you like this tutorial, what you think about JAX, and whether you want to use it in your projects. I hope you enjoyed this tutorial, and I hope to see you in the next one. Bye!
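
Here is a short sketch of the explicit random number handling mentioned in the catch section. The PRNG state is a key that you pass around, and `jax.random.split` produces fresh subkeys from it. Reusing the same key reproduces the same numbers, which is what makes results reproducible. The seed 42 is an arbitrary choice.

```python
import jax

key = jax.random.PRNGKey(42)          # explicit PRNG state instead of a hidden global seed
key, subkey = jax.random.split(key)   # split off a fresh subkey before each use
a = jax.random.normal(subkey, (3,))

same = jax.random.normal(subkey, (3,))  # reusing a key reproduces the same numbers

key, subkey = jax.random.split(key)   # a new subkey gives new numbers
b = jax.random.normal(subkey, (3,))
```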
