Efficient and Modular Implicit Differentiation (Machine Learning Research Paper Explained)

Yannic Kilcher · 11.06.2021 · 18,095 views · 629 likes


Video description
#implicitfunction #jax #autodiff

Many problems in machine learning involve loops of inner and outer optimization. Finding update steps for the outer loop is usually difficult, because of the need to differentiate through the inner loop's procedure over multiple steps. Such loop unrolling is very limited and constrained to very few steps. Other papers have found solutions around unrolling for very specific, individual problems. This paper proposes a unified framework for implicit differentiation of inner optimization procedures without unrolling and provides implementations that integrate seamlessly into JAX.

OUTLINE:
0:00 - Intro & Overview
2:05 - Automatic Differentiation of Inner Optimizations
4:30 - Example: Meta-Learning
7:45 - Unrolling Optimization
13:00 - Unified Framework Overview & Pseudocode
21:10 - Implicit Function Theorem
25:45 - More Technicalities
28:45 - Experiments

ERRATA:
- Dataset Distillation is done with respect to the training set, not the validation or test set.

Paper: https://arxiv.org/abs/2105.15183
Code coming soon

Abstract: Automatic differentiation (autodiff) has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. More recently, differentiation of optimization problem solutions has attracted widespread attention with applications such as optimization as a layer, and in bi-level problems such as hyper-parameter optimization and meta-learning. However, the formulas for these derivatives often involve case-by-case tedious mathematical derivations. In this paper, we propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In our approach, the user defines (in Python in the case of our implementation) a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, we leverage autodiff of F and implicit differentiation to automatically differentiate the optimization problem. Our approach thus combines the benefits of implicit differentiation and autodiff. It is efficient as it can be added on top of any state-of-the-art solver and modular as the optimality condition specification is decoupled from the implicit differentiation mechanism. We show that seemingly simple principles allow to recover many recently proposed implicit differentiation methods and create new ones easily. We demonstrate the ease of formulating and solving bi-level optimization problems using our framework. We also showcase an application to the sensitivity analysis of molecular dynamics.

Authors: Mathieu Blondel, Quentin Berthet, Marco Cuturi, Roy Frostig, Stephan Hoyer, Felipe Llinares-López, Fabian Pedregosa, Jean-Philippe Vert

Links:
TabNine Code Completion (Referral): http://bit.ly/tabnine-yannick
YouTube: https://www.youtube.com/c/yannickilcher
Twitter: https://twitter.com/ykilcher
Discord: https://discord.gg/4H8xxDF
BitChute: https://www.bitchute.com/channel/yannic-kilcher
Minds: https://www.minds.com/ykilcher
Parler: https://parler.com/profile/YannicKilcher
LinkedIn: https://www.linkedin.com/in/yannic-kilcher-488534136/
BiliBili: https://space.bilibili.com/1824646584

If you want to support me, the best thing to do is to share out the content :)

If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: https://www.subscribestar.com/yannickilcher
Patreon: https://www.patreon.com/yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n

Contents (8 segments)

Intro & Overview

Hello there. Today we're going to look at Efficient and Modular Implicit Differentiation by researchers of Google Research. On a high level, this paper extends what you know from frameworks like TensorFlow, PyTorch, or JAX in terms of automatic differentiation: it extends it to multi-level optimization procedures. This paper makes it possible to differentiate through an inner optimization loop without having to unroll that inner loop and without having to implement the optimization procedure in a differentiable way. This has been done before for single instances of problems, always with specific derivations for that particular problem, but this paper provides a unified framework for doing it. It's a bit of a technical paper, and we won't go into it in too technical a mode, because I'm also not the biggest expert on the methods used here. I just wanted to raise a bit of awareness that this exists, because the ability to backpropagate through inner optimization procedures, and even other things, in a unified way without having to unroll unlocks a bunch of research that has been quite cumbersome so far, and it could be interesting to a lot of people. They do provide code, and they show that many special instances that have been derived in the past, and also a bunch of new ones, are just instances of their framework and can sometimes be solved much more easily with it. They even provide some approximation guarantees. Interesting to us is going to be a little bit of the insight of why and how this works, and the fact that it exists. So let's jump in; they say that

Automatic Differentiation of Inner Optimizations

automatic differentiation has revolutionized machine learning. It allows expressing complex computations by composing elementary ones in creative ways and removes the burden of computing their derivatives by hand. This is absolutely true. If you look at old papers in deep learning, half the paper would be spent on deriving the gradients of the architecture that was just proposed, so you could actually implement it. Now we have autodiff, which means the frameworks simply do this by themselves: you just compose a bunch of functions and call gradient on them. This is a big part of what has spurred the deep learning revolution in the past few years, at least from an implementation point of view. I don't think a lot of architectures would have happened if people always had to derive the gradients by hand. It's kind of obvious how to do this if you know the backprop algorithm, but still, it is a big helper. Now, as I said, this paper extends the concept, the spirit, of autodiff to a much larger class of applications. They say: more recently, differentiation of optimization problem solutions has attracted widespread attention, with applications such as optimization as a layer, and in bi-level problems such as hyperparameter optimization and meta-learning. So the key here is differentiation of optimization problem solutions. I have an inner optimization problem, I obtain a solution, and I want to backpropagate not only through the solution itself but actually through the path that led me to finding that solution. Meta-learning is a good example, and hyperparameter optimization of course as well. There are many different tasks in meta-learning, but I've done a video on one of them, called iMAML. It's an extension of MAML, where the "ML" stands for meta-learning and the "i" stands for implicit,

Example: Meta-Learning

which is of course going to be related to the implicit differentiation we do right here. The "implicit" stands for the fact that we can implicitly derive the gradient; we don't have to go through the whole unrolling. So in iMAML there is a setting where you have multiple tasks. You have a data set, and there is task one, task two and task three. Maybe this is classifying food by taste, this one by calories, and this one by some other nutrient, or by color, or something like this. Now, this should all happen with the same architecture of neural network, simply solving different tasks. Obviously the different tasks are going to have different optima, different local optima, and from deep learning we of course know that these are never in the same place (there are many local optima), but let's just pretend for a moment we knew these were the three optima. The task of meta-learning is: can we find an initialization that is really good, such that if we fine-tune on any of these tasks, once we get data for it, we can learn it really quickly? So if we choose this point as an initialization, it's going to take us a while to get to any of these solutions. However, if we choose this other point as our initialization, we're at each of them pretty quickly. And in fact, if a new task comes along that is similar to the other ones, let's say one here that's kind of similar, on the same hyperplane or whatnot, you can see that we're also there fairly quickly. So the question is: how do we find the blue point? Obviously we don't know where the green points are, and they're non-deterministic anyway. The answer is we start with a guess, like this one, and we move that point step by step in a better direction, just as we do with gradient descent. However, how do we know what a good direction is? For that, we need to know: how good is this initialization?
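To make "how good is this initialization" concrete, here is a purely hypothetical scalar sketch of a MAML-style objective (the names and the quadratic task losses are my own, not from the paper or from iMAML): each task takes one inner gradient step from the shared init, and the meta-gradient is the derivative of the summed post-adaptation losses with respect to that init.

```python
# Hypothetical toy: two tasks with quadratic losses L_i(w) = (w - a_i)^2 / 2.
# Each task takes one inner gradient step from a shared init w0, then we
# evaluate the post-adaptation losses. The meta-gradient is d/dw0 of the sum.

def inner_step(w0, a, lr):
    # One step of gradient descent on L(w) = (w - a)^2 / 2
    return w0 - lr * (w0 - a)

def meta_loss(w0, tasks, lr):
    # Sum of post-adaptation losses over all tasks
    return sum((inner_step(w0, a, lr) - a) ** 2 / 2 for a in tasks)

def meta_grad(w0, tasks, lr):
    # Chain rule through the single unrolled inner step:
    # w' = w0 - lr*(w0 - a)  =>  dw'/dw0 = (1 - lr)
    # dL/dw0 = (w' - a) * (1 - lr), summed over tasks
    return sum((inner_step(w0, a, lr) - a) * (1 - lr) for a in tasks)

tasks, lr, w0 = [1.0, 3.0], 0.1, 0.0
analytic = meta_grad(w0, tasks, lr)
eps = 1e-6
numeric = (meta_loss(w0 + eps, tasks, lr) - meta_loss(w0 - eps, tasks, lr)) / (2 * eps)
print(abs(analytic - numeric) < 1e-4)
```

Even in this toy, evaluating the objective already requires stepping through each task's inner problem; with many inner steps, that chain rule grows accordingly.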
Well, in order to do that, we actually need to run the optimization procedure. So we do that, and we see that it leads us in this direction. We optimize for a different task, and that leads us in that direction. And now we get an idea: hey, maybe if all the tasks go in the same direction, it would be good if we also went in that direction. Specifically, what we want is the gradient, with respect to our initialization, of the solution of a particular task given that initialization. Now, the solution itself is of course the result of an optimization procedure. So you have an inner optimization procedure that you want to backpropagate through, and what you usually have to do is unroll

Unrolling Optimization

that optimization procedure. So think of gradient descent: here are your weights, and what you do is subtract the learning rate times the gradient. At step t, that is w_{t+1} = w_t minus the learning rate times the gradient, with respect to the weights, of f of x and w_t. That's your standard gradient descent, and all of that gives you w_{t+1}. Now you do another step of gradient descent: minus, again, a gradient with respect to this quantity (maybe it's a different data point, maybe the same), giving w_{t+2}. It already gets complicated, because now this quantity, which is the whole expression from above, appears twice. And if you do another step, of course, that quantity is going to replicate and appear everywhere, and the autodiff framework has to keep track of all of it. So if you actually write this down from the very first step, you can unroll all of it into one big expression that gives you the end of the optimization procedure, the end of gradient descent, given the beginning. You can do that, and TensorFlow or PyTorch can keep track of it; it's just going to be a big expression and really slow. Further, you need to actually implement the gradient descent procedure as a differentiable procedure, which is usually not done. Especially in TensorFlow and PyTorch, the optimization procedures are sort of outside of the autodiff framework. In JAX it's a bit different, but in TensorFlow and PyTorch the optimization procedures, for good reason, aren't themselves differentiable, so you'd have to reimplement them in a differentiable way. All of that is fairly cumbersome, and people have asked themselves: can we do better? Especially in this technique called iMAML, people have found that, instead of unrolling, something else works if we regularize this objective in a suitable way. So we add some sort of regularizer here.
Then we can calculate this outer gradient without having to go through the whole unrolling step. You can imagine a similar situation with hyperparameter optimization, if you actually want to do gradient descent on your hyperparameter. So you have some sort of validation set, and you want to minimize your loss on that validation set with respect to your hyperparameter lambda. And the solution you plug in there is the weights you get by minimizing, with respect to the weights, your loss function on the training set (this is all in green and looks horrible, but okay, I think that's it; oh, and we need a lambda right here). So for a given hyperparameter lambda we find the best weights, and we want the lambda such that those weights, the ones that came from the training data set, give us the best validation loss. We do this right now with grid search, but we could definitely imagine doing it with gradient descent if we could get a gradient for that hyperparameter. But that requires us to backpropagate through this inner optimization procedure, through the actual learning of the neural network. Given that neural networks usually train in thousands or millions of steps, unrolling that is not going to be an option. TensorFlow is good, but it's not that good. Technically it can keep track of it, but it's just not going to be feasible. So for many of these problems, people have devised individual solutions: given very strict requirements and the exact problem formulation, we do have solutions where we don't have to unroll.
However, these are case by case, much like the old papers on neural networks where every time you had to derive your gradients by hand: every one of these papers has to derive how they apply their conditions, how they apply the Karush-Kuhn-Tucker conditions in order to get the implicit gradient, and so on. And this paper is, for these derivations, what autodiff was for those old papers.
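Before moving on, the unrolling described above can be sketched concretely. This is a hypothetical scalar example (names and setup are mine, not the paper's): T steps of gradient descent on f(w) = (w - theta)^2 / 2, with the derivative of the final iterate with respect to theta carried through every step of the loop, which is exactly what an autodiff framework would store and replay.

```python
# Hypothetical sketch of unrolling: T steps of gradient descent on
# f(w) = (w - theta)^2 / 2, differentiated w.r.t. theta by carrying a
# tangent (forward-mode style) through every step of the loop.

def unrolled_solution_and_grad(theta, w0=0.0, lr=0.1, steps=50):
    w, dw_dtheta = w0, 0.0
    for _ in range(steps):
        # Update rule: w_{t+1} = w_t - lr * (w_t - theta)
        # Differentiating the update w.r.t. theta gives the tangent recursion:
        dw_dtheta = (1 - lr) * dw_dtheta + lr
        w = (1 - lr) * w + lr * theta
    return w, dw_dtheta

w_T, grad = unrolled_solution_and_grad(theta=2.0)
# The tangent recursion has the closed form 1 - (1 - lr)^T
expected = 1 - 0.9 ** 50
print(abs(grad - expected) < 1e-12)
```

Note that the derivative had to visit every one of the 50 intermediate iterates; with thousands or millions of steps, storing and traversing all of them is what makes naive unrolling impractical.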

Unified Framework Overview & Pseudocode

They go on; they say the formulas for these derivatives involve case-by-case, tedious mathematical derivations. In this paper, they propose a unified, efficient and modular approach for implicit differentiation of optimization problems. In their approach, the user defines, in Python in the case of their implementation, a function F capturing the optimality conditions of the problem to be differentiated. Once this is done, they leverage autodiff of F and implicit differentiation to automatically differentiate the optimization problem. So what you do is: you don't specify the gradient of the optimization procedure; you specify a function that captures the optimality conditions of the problem to be differentiated. And if that function is differentiable, then this framework can do its magic to give you the gradient through the optimization procedure. So we shift away from the optimization procedure itself having to be differentiable, to only the specification of the optimality conditions having to be differentiable, which is a huge gain. This can actually be done in many ways (you can choose your solver and so on), but we'll go through the very basics right here. This example is ultimately what it ends up looking like, and it is a hyperparameter optimization problem, as we saw: ridge regression. In ridge regression you have a data set and labels. X is a matrix where each row, I think, is a data point, and y is a vector of numeric labels. What you want to do is find weights w such that Xw approximately equals y; that is linear regression, of course. Now, in ridge regression you additionally have a regularization on w, so it is easiest to specify the problem via its loss.
So what you want is that the residual is small, but also that w has a small norm: you want Xw minus y to be small, and you want the norm of w to be small as well. This is a common regularization technique. Wanting the norm of w to be small sort of means that your line stays rather flat, so if you have a bunch of outliers, they won't affect your approximation too much. The important part is that there is a hyperparameter right here, the regularization constant, and this hyperparameter is a matter of choice. With this framework we can run gradient descent on that hyperparameter, and the way we do it is the following. We actually start down here with this ridge_solver. This is the inner optimization, the solver of the ridge regression. Now, ridge regression has a closed-form solution; we can pose it as a linear problem. You take X transpose times X, add the regularization constant times a diagonal identity matrix, and set up the linear system (X transpose X plus theta times I) w = X transpose y; in the code the hyperparameter is called theta, in our notation it was lambda. If you solve this system for w, you get the direct solution of ridge regression. There's no gradient descent here, but it would be totally fine if this solver contained gradient descent. The next thing you have to do is specify the optimality conditions. In this case, we're essentially going to repeat the loss function of ridge regression. As you can see, the optimality conditions depend on what's called x here, which is what we call w, and on theta, which is your hyperparameter.
So you can see this is just the loss: you multiply w by X and subtract y, which is called the residual, and this here is its squared norm. So in our loss function up here we'd have squared L2 norms everywhere. This term is the regularization, and the one half is for easier differentiation; we don't have it above, but it doesn't matter. So this is simply the loss function of ridge regression, and you can imagine more complicated things. Now, given a loss function, what you need to supply is a function that is zero when optimality is met. And that's pretty easy: the gradient of the loss function is exactly such a function. The gradient of the loss function is zero whenever the inner problem is solved to optimality. So whenever the ridge regression is solved to optimality, the gradient of this loss function is zero. Now we have all the ingredients. What we can do is use their custom decorator right here to say: here is the optimality condition, F, on this inner optimization problem. And if you do this, then you can just backpropagate through it. Here you can see that you can take the Jacobian of the ridge solver, at lambda equals 10 for example. You can simply take derivatives through the inner optimization procedure, because you have supplied this optimality condition, without having to backpropagate through the inner procedure itself. I hope this was a little bit clear. So again: you need to specify the inner procedure, which is this thing here (in our meta-learning case this would be the inner gradient descent), and you need to specify the optimality condition, which in the easy case is simply derived from a loss function: the optimality condition is the gradient of the loss function, and the problem is optimal whenever that is zero.
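The same pattern can be sketched in plain Python for a hypothetical one-dimensional ridge problem. The names ridge_solver, optimality_condition, and implicit_grad are my own for illustration, not the paper's API, and the hypergradient formula comes from the implicit function theorem discussed below:

```python
# Hypothetical 1-D ridge regression, mirroring the paper's pattern: an inner
# solver plus an optimality condition F (the gradient of the ridge loss),
# from which dw*/dlam follows without unrolling any solver steps.

def ridge_solver(lam, xs, ys):
    # Closed-form inner solution: w* = sum(x*y) / (sum(x*x) + lam)
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    return sxy / (sxx + lam)

def optimality_condition(w, lam, xs, ys):
    # F(w, lam) = gradient of the ridge loss in w; zero at the optimum
    return sum(x * (x * w - y) for x, y in zip(xs, ys)) + lam * w

def implicit_grad(lam, xs, ys):
    # Implicit function theorem: dw/dlam = -(dF/dw)^(-1) * dF/dlam
    w = ridge_solver(lam, xs, ys)
    dF_dw = sum(x * x for x in xs) + lam
    dF_dlam = w
    return -dF_dlam / dF_dw

xs, ys, lam = [1.0, 2.0], [1.0, 3.0], 0.5
analytic = implicit_grad(lam, xs, ys)
eps = 1e-6
numeric = (ridge_solver(lam + eps, xs, ys) - ridge_solver(lam - eps, xs, ys)) / (2 * eps)
print(abs(analytic - numeric) < 1e-6)
```

Nothing in implicit_grad looks inside ridge_solver; it only evaluates derivatives of the optimality condition at the returned solution, which is exactly the decoupling the framework provides.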
Then you supply the optimality condition via the custom annotation on the function, and from then on you can simply treat that inner function as if it were any other thing you could backpropagate through. Very cool. Now they go into the whole math behind this, and I don't want to go too much

Implicit Function Theorem

into the math, but all of this essentially comes from the implicit function theorem. If you have this optimality condition F (you may have noticed it needs to be zero at the optimum), this is what's called a root, and the root is specified like this: you have an inner problem whose solution depends on theta, and you have the optimality condition that depends on that solution and can also depend on the parameter theta itself. If you have a construct like this, then under some regularity conditions on F, the implicit function theorem tells you that, in essence, you can express the gradients of these things with respect to each other. From this you can get the derivative of the inner solution, locally, without having to backpropagate through the procedure by which you found it. It's an implicit gradient because the solution is defined implicitly as a function of the other argument. If you take the total derivative of this expression, you can use the chain rule to arrive at the expression down here. Differentiating with respect to the first argument, you get the chain rule in theta: you differentiate F with respect to its first argument, and then you also differentiate that first argument itself; then you differentiate F with respect to the second argument, which is just theta, of course. Now you can see we've ended up with only partial derivatives of simple arguments. So we need three things, ultimately. The thing we want is the gradient of the solution of the inner optimization procedure. If we reorder a bit, you can see that the other things we need are the number zero (that's easy) and two derivatives of F, both of which are just simple partial derivatives with respect to the arguments of F.
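A minimal hypothetical instance of the theorem, with all names my own: F(x, theta) = x^2 - theta has the root x*(theta) = sqrt(theta), the solver below is a black-box Newton iteration, and the implicit gradient never looks inside it.

```python
# Hypothetical minimal instance of the implicit function theorem:
# F(x, theta) = x^2 - theta has root x*(theta) = sqrt(theta).

def solve_root(theta, x=1.0):
    # Newton's method on F(x) = x^2 - theta; how we got to the root is
    # irrelevant to the gradient computation below.
    for _ in range(50):
        x = x - (x * x - theta) / (2 * x)
    return x

def implicit_grad(theta):
    # dx*/dtheta = -(dF/dx)^(-1) * dF/dtheta, evaluated at the root:
    # dF/dx = 2x, dF/dtheta = -1  =>  dx*/dtheta = 1 / (2x)
    x = solve_root(theta)
    return 1.0 / (2.0 * x)

theta = 4.0
print(abs(solve_root(theta) - 2.0) < 1e-12)
print(abs(implicit_grad(theta) - 0.25) < 1e-12)
```

The result matches d sqrt(theta)/dtheta = 1/(2 sqrt(theta)) at theta = 4, even though the solver is an opaque loop.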
If F is differentiable, then we can get those things, and that's the exact shift I talked about before: instead of the optimization procedure having to be differentiable, only the optimality condition needs to be differentiable, which is a much easier requirement, and again we can use the autodiff frameworks for it. As long as we can specify F in terms of functions of the framework, we're good. Obviously the inner function here is fully differentiable, because it's the loss of ridge regression. The only tricky thing is that capital F is actually the gradient of that function, so what we need is for the framework to be able to differentiate the gradient again: the derivative of capital F is a second derivative of lowercase f. But frameworks can usually do this, and this loss function is certainly twice differentiable. And then it's just a linear system, as you can see down here: this is B, this is J. So what you have to do is solve the linear system A times J equals B, and whatever comes out is your gradient; you can use any classic linear solver for that. To repeat: you obtain A and B by using autodiff on the optimality conditions, and then you simply solve a linear system to get the gradient of the solution of the inner optimization problem, without ever having to unroll that inner optimization procedure, without having to backpropagate through the steps by which you arrived at that inner optimum. That's the cool trick right here. And they can not only do this with a

More Technicalities

root; they can also do this with optimality conditions that are specified as fixed points. So whenever the optimal solution of the inner problem has the property of being a fixed point of some function T, you can also use this method. I think they provide two different decorators, one custom root and one custom fixed point, and from there you go. They discuss the technicalities of what they need. They never actually calculate these Jacobians fully, because they could become pretty big; they only need Jacobian-vector products and vector-Jacobian products, and they go into the technicalities of how they obtain those. The cool thing is that this fully integrates with the autodiff framework. They talk about pre-processing and post-processing mappings: what if we don't need the solution of the inner problem itself, but a function of it, and so on; this can all be taken care of by the autodiff framework itself. They say: our implementation is based on JAX, and JAX enters the picture in at least two ways. We lean heavily on JAX within our implementation, and we integrate the differentiation routines introduced by our framework into JAX's existing autodiff system. In doing the latter, we override JAX's default autodiff behavior, e.g. of differentiating transparently through an iterative solver's unrolled iterations. So if you stick this in, you can just differentiate through these things as if they were any other differentiable function in JAX. Very cool. The last thing: here are all the different things that reduce to their method.
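As a hypothetical sketch of the fixed-point formulation just mentioned (names mine, not the library's): minimizing f(w) = (w - 2)^2 / 2 subject to w in [0, 1] has its constrained optimum exactly at the fixed point of a projected gradient step, which is the form a custom fixed-point decorator would consume.

```python
# Hypothetical projected-gradient fixed point: minimize f(w) = (w - 2)^2 / 2
# subject to w in [0, 1]. The constrained optimum w* = 1 is exactly the
# fixed point of T(w) = proj(w - lr * f'(w)).

def project(w, lo=0.0, hi=1.0):
    # Euclidean projection onto the interval [lo, hi]
    return max(lo, min(hi, w))

def T(w, lr=0.5):
    # One projected gradient step; f'(w) = w - 2
    return project(w - lr * (w - 2.0))

w = 0.0
for _ in range(100):
    w = T(w)

print(abs(w - 1.0) < 1e-9)   # converged to the constrained optimum
print(abs(T(w) - w) < 1e-9)  # w is a fixed point of T
```

Once the optimum is characterized as w = T(w), the same implicit machinery can differentiate it, e.g. with respect to parameters hidden inside f or the constraint set.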
If you go and look, they give a lot of different examples of other techniques that reduce to their method. Specifically, we've seen these simple optimization procedures, but you can also have proximal methods in the inner optimization problem. You can do things like a projected gradient fixed point, which is maybe important for something like adversarial examples, where you have to minimize a function but at the same time stay within some convex set, so you always project back onto that set. Now we can backpropagate through the procedure of finding an adversarial example. Very cool. And they even give bounds, because you cannot ever exactly calculate these things, so they give bounds on how far off you are. And lastly they do

Experiments

experiments, and these are just more examples. Their first experiment is pretty straightforward: hyperparameter optimization of multiclass SVMs. In a support vector machine you generally have a hyperparameter, and here that hyperparameter is sort of the strength of the regularization, how much you trade off margin versus slack. I believe that's right; I haven't done SVMs in a long time, especially multiclass. You need to maximize the margin while staying within the probability simplex, because it's multiclass, so it's a constrained inner problem, and you would like to find the best trade-off hyperparameter for the SVM with respect to an outer validation set. So that's a problem with two levels, and they can do it right here. They can also do dictionary learning. In dictionary learning you need to somehow obtain the dictionary, and then you optimize using the dictionary. So you have some sort of data point, maybe an image, you map it onto entries in a dictionary, you use those entries to do something, and then you have some kind of loss. However, you can't optimize the mapping functions and the dictionary itself at the same time; it becomes unstable. So what people usually do is alternating optimization. Here, you can actually backpropagate through the inner problem and find the dictionary elements as a function: which dictionary elements would most optimally solve the outer problem? Lastly, this is data set distillation.
They want to find the optimal data set of size 10: give me one image per class, such that if I train a neural network (or whatever) on that data set of 10 images, I get the best possible validation loss. That is an optimization problem, so what you do is start with 10 random images, train your classifier, measure it on the validation set (or the test set), and then backpropagate through the whole thing to update the data set itself. In the end, you end up with the optimal data set. You can see that this is also a two-level optimization problem, maybe with some constraints. I think this is a very cool idea; honestly, it probably existed before, but you can now do this easily. And lastly, they have these molecular dynamics experiments, where they want to see: if we change the size of these molecules, how do all of these quantities change? Again, this reduces to a quite complex inner problem. But I think the point of all of this is: if you have a problem with this outer and inner optimization structure, and you want to use backpropagation for the outer problem through the inner problem, give this method a try. It's pretty cool. If you're interested in the more technical aspects, give it a read. And that was it from me. I wish you a pleasant rest of the day. Bye-bye.
