Grokking: Generalization beyond Overfitting on small algorithmic datasets (Paper Explained)
29:47

Yannic Kilcher · 06.10.2021 · 82,778 views · 2,918 likes

Video description
#grokking #openai #deeplearning

Grokking is a phenomenon when a neural network suddenly learns a pattern in the dataset and jumps from random chance generalization to perfect generalization very suddenly. This paper demonstrates grokking on small algorithmic datasets where a network has to fill in binary tables. Interestingly, the learned latent spaces show an emergence of the underlying binary operations that the data were created with.

OUTLINE:
0:00 - Intro & Overview
1:40 - The Grokking Phenomenon
3:50 - Related: Double Descent
7:50 - Binary Operations Datasets
11:45 - What quantities influence grokking?
15:40 - Learned Emerging Structure
17:35 - The role of smoothness
21:30 - Simple explanations win
24:30 - Why does weight decay encourage simplicity?
26:40 - Appendix
28:55 - Conclusion & Comments

Paper: https://mathai-iclr.github.io/papers/papers/MATHAI_29_paper.pdf

Abstract: In this paper we propose to study generalization of neural networks on small algorithmically generated datasets. In this setting, questions about data efficiency, memorization, generalization, and speed of learning can be studied in great detail. In some situations we show that neural networks learn through a process of “grokking” a pattern in the data, improving generalization performance from random chance level to perfect generalization, and that this improvement in generalization can happen well past the point of overfitting. We also study generalization as a function of dataset size and find that smaller datasets require increasing amounts of optimization for generalization. We argue that these datasets provide a fertile ground for studying a poorly understood aspect of deep learning: generalization of overparametrized neural networks beyond memorization of the finite training dataset.
Authors: Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin & Vedant Misra

Contents (11 segments)

Intro & Overview

Hi there! Today we'll look at "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets" by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin and Vedant Misra of OpenAI. At a high level, this paper presents a phenomenon that the researchers call grokking, where a neural network suddenly generalizes, way past the point of overfitting on a dataset. You train the network and it completely overfits: the training loss is down and training accuracy is 100%, but it doesn't generalize at all to the validation set. Then, when you keep training, at some point the network just snaps into generalizing on the datasets they study, reaching 100% accuracy on the validation set. This is extremely interesting. As you can see, the paper was presented at a workshop at ICLR 2021, which means it is sort of a work in progress, so there are still a lot of unclear things about this phenomenon. As I understand it, it's a phenomenological paper that just says: look, here is something interesting that we found. I think it's pretty cool. So we'll dive into the paper and look at this phenomenon. They do dig into it a little bit, into what's happening here, and try to come up with some

The Grokking Phenomenon

explanation. So the basic premise of grokking is the graph you see on the left right here. It is a little bit pixelated, but I hope you can still see what's happening. The red curve is the training accuracy, and on the x-axis you have the number of optimization steps. This is a log scale, which is important to notice. Naturally, after a few steps the training accuracy shoots up to 100%. We'll get to what datasets these are in a second, but it's important to see that the network can in fact fit the training data extremely well; it just overfits. The validation accuracy, if you can see it, has a little bump here, but then it goes down again, and I don't know whether we should even regard that as a real bump. In any case, it stays down. Then, after orders of magnitude more steps (the axis goes 10^2, 10^3, 10^4, 10^5 steps), it shoots up and the network starts to generalize as well. This is very interesting because it essentially means that you keep on training for a long time, and when all hope is lost, the network still at some point generalizes. Now why is this happening? As I understand it, the network doesn't often drop back out of generalization, though I haven't actually seen this investigated by running for many more steps. But it seems that once the network is generalizing, with 100% validation accuracy, it doesn't fall out of that again. So the question is: what's happening here? Why does it happen all of a sudden, and what makes

Related: Double Descent

it work? For that, it's important to understand a related, and probably connected, phenomenon in deep learning called double descent. The double descent graph looks somewhat similar. On the x-axis you have the number of parameters in a neural network, and on the y-axis you have, let's say, loss. I'm not sure, but I think most double descent plots actually show loss. If you look at the training loss, then as you increase the number of parameters in your neural network, you fit the training data better and better. So you get a curve that goes down and then just stays at zero: there's zero training loss. Every point on this line is a neural network with a given number of parameters that has been optimized to convergence. That's important to remember: the graph on the left showed a single network during optimization, while this graph shows many different networks, all of which have been trained to convergence. Now look at the validation loss. At first it might come down together with the training loss. Then, in classic machine learning fashion, as the number of parameters goes up you start to overfit, and the validation loss goes up again because you start memorizing the training dataset. At the point where the number of parameters roughly equals the number of training data points (let's call it n), you have a really bad validation loss, because you're just remembering the training data. However, if you increase the number of parameters beyond that point, scaling up your neural network even more, the validation loss comes down again and can actually end up lower than it was with too few parameters. So there is a regime beyond overfitting, where you have more parameters than data points, and interestingly, neural networks there can achieve generalization, in fact better generalization, than comparable underparameterized models. That flies in the face of classical statistical intuition, but we know this phenomenon exists. So we already knew that the training loss can be perfect and we can still have generalization. As for grokking, I'm going to guess that the people studying double descent simply didn't look quite as far: they ran training to convergence for some number of steps and then looked at the validation loss, so they would probably have stopped somewhere between 10^3 and 10^4 steps. This research asks what happens if we let training run for a really long time. Then the validation accuracy shoots up as well, and it seems you can make this happen under a lot of conditions. So now it's worth looking at what kind of datasets we're dealing with here. The

Binary Operations Datasets

datasets are synthetic: binary operation tables. The datasets they consider are tables of the form a ∘ b = c, where a, b and c are discrete symbols with no internal structure and ∘ is a binary operation. Examples of binary operations include addition, composition of permutations, bivariate polynomials and many more. In fact, they list some examples down here: addition and multiplication, but also more complicated things like polynomials taken modulo a prime number, division, and so on. To create a dataset, you construct a table over a number of these symbols and define the binary operation simply by filling in that table. If this were a + b and a and b were numbers, then a + b = c: if a is 1 and b is 2, then c is 3, and so on. But you can define this as many different things. A lot of the experiments in this paper use the group S5, the group of all permutations of five elements, which has 5! = 120 elements. So your table here is 120 by 120, and the operation is composition of permutations: any permutation of five elements composed with another permutation gives you yet another permutation of five elements. You construct this table and then cross out a few cells. Those crossed-out cells are what the network should predict: you train the network on the data you have and predict the cells you crossed out. This way you can measure exactly how good the network is. There is effectively no noise in the data; it's all very well defined. A human approaches this with a logical mind and tries to figure out the rule. A neural network can simply remember the training data, but then it will not generalize to the hidden cells, because it cannot memorize those. So if a neural network generalizes here, it must have somehow learned the rule, and that's pretty interesting. There are a number of quantities to keep in mind. First, the operation itself, because some operations are more complicated for these networks to learn than others. Second, the dataset size, i.e. the size of the binary table itself, which here is 120 by 120. Third, how many cells are left out, i.e. the training data fraction: the fraction of the table that is filled in for the network to learn from. All three of these play a crucial role in the grokking phenomenon and
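The table construction described above can be sketched in a few lines of Python. This is a minimal illustration under our own assumptions, not the authors' code. Permutations are stored as tuples, and `(a ∘ b)(i) = a[b[i]]` is one of the two usual composition conventions; the other one just transposes the table. The helper names `s5_table` and `split_table` are hypothetical. A model would only see the integer indices 0 to 119, which fits the idea of symbols with no internal structure.

```python
import itertools
import random

def s5_table():
    """Composition table of S5, the group of all permutations of five elements."""
    elements = list(itertools.permutations(range(5)))  # 5! = 120 symbols
    index = {perm: k for k, perm in enumerate(elements)}
    table = {}
    for a, pa in enumerate(elements):
        for b, pb in enumerate(elements):
            composed = tuple(pa[pb[i]] for i in range(5))  # (a ∘ b)(i) = a[b[i]]
            table[(a, b)] = index[composed]  # closure: the result is again an element of S5
    return elements, table

def split_table(table, train_fraction, seed=0):
    """Keep a random train_fraction of cells visible; the hidden cells form the validation set."""
    cells = sorted(table.items())
    rng = random.Random(seed)
    rng.shuffle(cells)
    n_train = int(round(train_fraction * len(cells)))
    return cells[:n_train], cells[n_train:]

elements, table = s5_table()
train, val = split_table(table, train_fraction=0.3)
print(len(elements), len(table), len(train), len(val))  # 120 14400 4320 10080
```

With `train_fraction=0.3`, 4,320 of the 14,400 cells stay visible and the other 10,080 are the ones the model must predict.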

What quantities influence grokking?

when and how it appears. For example, here they trained neural networks on the S5 group, the permutations of five elements, until they reached generalization. They simply run training and measure how long it takes a network to reach 99% validation accuracy or higher. That's the quantity plotted on the left, and the answer is something like between 10^5 and 10^6 steps. They measure this as a function of (you might not be able to read it) the training data fraction, i.e. how much of the table is filled in. You can see pretty clearly that if I only give it 20% of the data, some runs do not generalize within this number of steps. Would they generalize if you optimized for even longer? Who knows, honestly. But as soon as you give it 30% of the data, the runs generally do generalize, though they take something like 10^5 steps to do so. As you increase the training data fraction, the snap to generalization happens faster and faster. As I understand it, memorizing the training data always happens in a fairly similar number of steps, but then at some later point the network just snaps and completely generalizes to the validation set. This is really interesting. So one takeaway is that more training data helps, which we knew. The other is that they try to figure out which parts of the optimization algorithm make grokking happen, and they find that weight decay is one of the big drivers. They try a lot of different things: full batch versus mini-batch, with and without dropout, modulating the learning rate, and so on. But weight decay seems to be one of the biggest contributors to how fast these networks generalize: the network generalizes much sooner if you turn weight decay up. They also observe that if your binary operation is symmetric, grokking happens much faster than for non-symmetric operations. This might just be a property of these networks: something like a Transformer might be somewhat invariant to that symmetry, so that one data point is essentially two data points in disguise, or there is only half as much to learn. Interpret this however you want, but I think this is not as
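The number plotted on the left, steps until validation accuracy first reaches 99%, is easy to compute from logged accuracy curves. Below is a minimal sketch, assuming accuracies are logged at a list of step counts. The helper names are hypothetical. `grokking_gap` measures the orders of magnitude between fitting the training set and generalizing; it is our own summary statistic, not one defined in the paper.

```python
import math

def steps_to_threshold(steps, accuracies, threshold=0.99):
    """Return the first logged step whose accuracy is >= threshold, or None if never reached."""
    for step, acc in zip(steps, accuracies):
        if acc >= threshold:
            return step
    return None

def grokking_gap(steps, train_acc, val_acc, threshold=0.99):
    """log10(steps to generalize / steps to fit the training set); None if either never happens."""
    t_train = steps_to_threshold(steps, train_acc, threshold)
    t_val = steps_to_threshold(steps, val_acc, threshold)
    if t_train is None or t_val is None:
        return None
    return math.log10(t_val / t_train)

# Toy curves with made-up numbers shaped like the plot, NOT results from the paper.
steps = [10, 100, 1_000, 10_000, 100_000, 1_000_000]
train_acc = [0.10, 0.90, 1.00, 1.00, 1.00, 1.00]
val_acc = [0.01, 0.02, 0.01, 0.01, 0.995, 1.00]
print(steps_to_threshold(steps, val_acc))       # 100000
print(grokking_gap(steps, train_acc, val_acc))  # 2.0
```

Because the steps are logged on a roughly logarithmic grid, the result is only as precise as the logging interval. That coarseness is the same reason the plot uses a log axis.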

Learned Emerging Structure

important as weight decay. And why do I highlight this? Because down here they analyze a network that has learned to generalize like this. On the right you see a t-SNE projection of the output layer weights from a network trained on modular addition. The lines show the result of adding eight to each element, and the colors show the residue of each element modulo 8. The lines are obviously drawn in by the authors, but if you follow a line, each step is adding eight, adding eight. So the rule used to generate the data is clearly present in the network's weights. This is a strong indication that the network has not just memorized the data but has in fact discovered the rule behind it, and we never incentivized the network to learn these rules. That's the wild part. There are architectures where you specifically tell the network: look, there is a rule behind this, I want you to figure it out. You could do symbolic regression, or build an internal graph and reason over it. But no, here we just train neural networks, and it turns out they can learn these rules. So why do I relate this to the
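The structure those lines trace can be enumerated directly. This is a small sketch that assumes modular addition on 0 to p - 1 with a hypothetical modulus p = 97; the transcript does not state the modulus. Repeatedly adding 8 walks through the elements. The residue modulo 8, which sets the colors, stays constant until the walk wraps around p. This only illustrates the arithmetic behind the plot and does not reproduce the t-SNE projection.

```python
def add_k_chain(start, k, p):
    """Follow x -> (x + k) mod p from `start` until the walk returns to `start`."""
    chain = [start]
    x = (start + k) % p
    while x != start:
        chain.append(x)
        x = (x + k) % p
    return chain

p = 97  # hypothetical modulus, chosen only for illustration
chain = add_k_chain(0, 8, p)
colors = [x % 8 for x in chain]  # residue mod 8, the quantity the caption says sets the colors

print(len(chain))   # 97: since gcd(8, 97) == 1, the walk visits every element
print(colors[:14])  # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7]
```

Along each run of +8 steps the color is fixed. It changes only when the sum wraps past p (96 + 8 = 104 ≡ 7), which is consistent with lines of a single color in the figure.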

The role of smoothness

double descent phenomenon? I've heard the authors of those papers speak about their hypothesis for why double descent happens, and this is a bit mixed with my own hypothesis as well. They mention weight decay, for example, as one possible explanation. Say I have a bunch of data points right here and I want to do regression on them. If I just do linear regression, I get one line. It's fairly flat and fairly robust, because it has only a couple of parameters. If I start to add parameters, I get to a point where I have a good number of parameters, maybe this polynomial, which is still fairly robust; you can see how it might generalize to new data. So this dark blue one would be somewhere here on the double descent curve, where the validation loss goes down together with the training loss. But when I keep adding parameters, then classically I start overfitting right here, and the curve will not generalize to any point in between the training points; it will just shoot up. So the green one corresponds to the point where I just start to interpolate the training data. But what happens if I go on, with even higher-order polynomials or bigger neural networks? At that point, at least these authors argue (do I have another color? this one), you get a curve that has a lot of parameters but uses them to smoothly interpolate the training data. This curve is quite complicated in terms of how many numbers you need to describe it, and it has a lot of freedom: it can take any shape it wants as long as it interpolates the training data. Yet it chooses to be smooth, because of a combination of SGD training it and weight decay. Weight decay prevents any of these numbers from getting too big, and therefore prevents a wildly out-of-whack curve. So weight decay in fact smooths the curve, and that makes the model generalize really well, because the smooth curve generalizes reasonably to new data points in between the training points. This data point here is still fairly well represented by the purple curve; in this particular case it is represented even better than by the dark blue curve. So those authors argue that weight decay might be an important contributor to why overparameterized networks generalize. It's interesting that the authors of the grokking paper find the same thing: if they use weight decay, grokking appears to happen much faster. I don't know exactly how they define grokking; I'm just going to call grokking the moment when the validation accuracy suddenly snaps from near 0 to 100% on these datasets. Again, these are algorithmic datasets, so we don't know what happens on other data. I think they do run experiments where they add noise to some of the data, and I think they find that with noise it's way more difficult. I'm not sure though, maybe
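The polynomial picture can be reproduced numerically on a toy problem. This is a minimal sketch that uses polynomial regression in place of a neural network, with made-up data. We fit a degree-9 polynomial (10 coefficients) to 6 points, so the model is overparameterized. We then add an L2 penalty `lam * ||w||^2`, which for plain gradient descent is the same thing as weight decay. The helper names (`ridge_fit`, `solve`, `train_sse`) are our own. A larger penalty yields coefficients with a smaller norm, at the cost of a larger training error.

```python
def poly_features(x, degree):
    """Feature vector [1, x, x^2, ..., x^degree]."""
    return [x ** d for d in range(degree + 1)]

def solve(a, b):
    """Solve the square system a @ w = b by Gaussian elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    w = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        w[r] = (m[r][n] - sum(m[r][c] * w[c] for c in range(r + 1, n))) / m[r][r]
    return w

def ridge_fit(xs, ys, degree, lam):
    """Minimize sum_i (f(x_i) - y_i)^2 + lam * ||w||^2 via (X^T X + lam I) w = X^T y."""
    X = [poly_features(x, degree) for x in xs]
    k = degree + 1
    xtx = [[sum(row[i] * row[j] for row in X) + (lam if i == j else 0.0)
            for j in range(k)] for i in range(k)]
    xty = [sum(row[i] * y for row, y in zip(X, ys)) for i in range(k)]
    return solve(xtx, xty)

def norm(w):
    return sum(v * v for v in w) ** 0.5

def train_sse(w, xs, ys):
    """Sum of squared training errors of the polynomial with coefficients w."""
    return sum((sum(c * x ** d for d, c in enumerate(w)) - y) ** 2 for x, y in zip(xs, ys))

# Made-up toy data: 6 points and 10 coefficients, so many polynomials fit them exactly.
xs = [-1.0, -0.6, -0.2, 0.2, 0.6, 1.0]
ys = [0.0, 0.5, -0.3, 0.4, -0.2, 0.1]
for lam in (1e-3, 1.0):
    w = ridge_fit(xs, ys, degree=9, lam=lam)
    print(f"lam={lam:g}  ||w||={norm(w):.3f}  train SSE={train_sse(w, xs, ys):.4f}")
```

Solving the normal equations in closed form stands in for running gradient descent with weight decay to convergence. The minimizer of the same penalized objective is what that process would approach.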

Simple explanations win

I'm confusing papers here. But what might be happening is interesting: by imposing this smoothness together with the overparameterization, we may be biasing these networks toward simple solutions. If I have very few training data points, if most of the cells here are blacked out, the simplest solution is simply to remember the training data. However, as I get more and more training data points, which give me more and more information about a potential underlying rule, it becomes simpler to learn the underlying rule than to remember the training data. Remembering the training data becomes the harder option. So here's what I think might be happening. Training always happens on the same data: you simply sample the same things over and over again and train on them. I think you kind of jump around in your optimization procedure; you can see there are some bumps in the training accuracy here. So you jump around a bit (that's a song, no?), and in your loss landscape there might be many local minima where you remember the training data perfectly, and you jump around between them. In all of them you remember the training data perfectly. One of them also fits the training data, but that solution is just so much simpler that you stay there. This is not a great way of visualizing it, so think of it like this: here are the minima of the loss on the data alone, and on top there is another loss, for example the weight decay loss. That extra loss is pretty similar for all of these minima, but for one of them it's much lower, because that solution is so much simpler. So you jump around between those minima until you reach this one, where the loss that comes on top is so much lower that you stay there. It's like: wow, I found such an easy solution, I'm not going to leave it again. So the big question, of course, is how and why something like
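Yannic's picture of many equally good minima, with one favored by the extra weight decay term, has an exactly solvable analogue in underdetermined linear regression. This is a toy stand-in for a network, not something from the paper. With 2 equations and 3 unknowns, every `w_min + t * null` fits the data perfectly, where `null` spans the null space, so the data loss is zero for all `t`. Adding an L2 penalty then selects the minimum-norm solution `w = X^T (X X^T)^{-1} y`. The matrix, the penalty weight `lam` and the helper names below are hypothetical.

```python
def min_norm_solution(X, y):
    """Minimum-norm w with X w = y for a 2x3 X with independent rows: w = X^T (X X^T)^{-1} y."""
    g = [[sum(a * b for a, b in zip(X[i], X[j])) for j in range(2)] for i in range(2)]  # X X^T
    det = g[0][0] * g[1][1] - g[0][1] * g[1][0]
    alpha = [(g[1][1] * y[0] - g[0][1] * y[1]) / det,  # (X X^T)^{-1} y via the 2x2 inverse
             (g[0][0] * y[1] - g[1][0] * y[0]) / det]
    return [X[0][k] * alpha[0] + X[1][k] * alpha[1] for k in range(3)]

def cross(u, v):
    """Cross product; for the two rows of X it spans the null space of X."""
    return [u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0]]

def data_loss(X, y, w):
    return sum((sum(a * b for a, b in zip(row, w)) - t) ** 2 for row, t in zip(X, y))

def weight_penalty(w, lam=1e-2):
    return lam * sum(v * v for v in w)

X = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]]  # hypothetical data: 2 equations, 3 weights
y = [1.0, 2.0]
w_min = min_norm_solution(X, y)
null = cross(X[0], X[1])

# Several "minima" that all fit the data perfectly; only the penalty tells them apart.
candidates = {t: [a + t * b for a, b in zip(w_min, null)] for t in (-1.0, -0.5, 0.0, 0.5, 1.0)}
for t, w in candidates.items():
    print(f"t={t:+.1f}  data loss={data_loss(X, y, w):.1e}  penalty={weight_penalty(w):.4f}")
best_t = min(candidates, key=lambda t: data_loss(X, y, candidates[t]) + weight_penalty(candidates[t]))
print("selected t:", best_t)  # 0.0, the minimum-norm solution
```

Because `w_min` lies in the row space of `X`, it is orthogonal to `null`. The penalty therefore grows as `||w_min||^2 + t^2 ||null||^2`, and the combined objective always prefers `t = 0`.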

Why does weight decay encourage simplicity?

SGD plus weight decay, plus potentially other drivers of smoothness in these models, corresponds to simplicity of solutions. Simplicity of solutions is something we humans have built in: we ask what the rule behind this is. We essentially assume there is a simple rule and try to find it, because a simple explanation for what's happening makes our life much easier. The interesting part is that weight decay, or something similar happening in these neural networks, is essentially doing the same thing even though we don't tell it to. Understanding this will be, I think, quite an important task for the near future. Also, maybe weight decay isn't exactly the right thing, and there is some other constraint we can impose that encourages simple solutions, in the sense of simplicity we care about, even more. And once we have that, there's this age-old argument: do these things actually understand anything? Well, in this case, I'm sorry, but if the rule is essentially built into the weights of the neural network, you can say that the network has in fact learned the rule behind this binary operation. So who are we to say these networks don't understand anything at that point? It also gives us the opportunity to train these networks and then, from the structure of their latent spaces, parse out rules behind data that we don't know yet. We let the networks fit the data and then parse out the underlying rules, maybe physical laws, maybe social phenomena. Oh yeah, here

Appendix

okay, there is an appendix where they list the binary operations they tried, the models, and the optimization setup. They use a Transformer with two layers and four attention heads, so it's not a big model, and the datasets aren't super complicated either, but it's pretty cool to see this phenomenon. Again, with real-world data, bigger networks and noisy data, it's probably not going to happen as drastically. They also say that as you increase the size of the dataset (where is that?), the phenomenon gets harder and harder to see: if the entire dataset is bigger, grokking is tougher to observe. And here is the experiment I mentioned, with outliers, i.e. noisy data points. The x-axis is the fraction of correctly labeled data points, and as you increase it, grokking happens more often, or reaches a better validation accuracy. I don't know if you can read this, but the runs down here have too many outliers, and with too many outliers the validation accuracy either just stays at zero or only turns up quite late. Okay, here is an example of one of these binary operation tables that is a little bit larger; I don't know if it's one of the 120 by 120 ones. This is what would be presented to the network, and they say: we invite the reader to guess which operation is represented here. Well, have fun, dear reader. All right, so this was it
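The outlier experiment can be sketched as label corruption on the training cells. This is a minimal sketch under our own assumptions, because we don't know exactly how the paper chose or corrupted its outliers. Here a random `outlier_fraction` of training cells each get a uniformly random wrong answer. The "fraction of correctly labeled data points" on the plot then corresponds to `1 - outlier_fraction`. The helper `add_outliers` and the mod-7 addition table are hypothetical.

```python
import random

def add_outliers(cells, outlier_fraction, num_symbols, seed=0):
    """Return a copy of cells ((a, b), c) where a random fraction of answers c are replaced by wrong symbols."""
    rng = random.Random(seed)
    noisy = list(cells)
    n_bad = int(round(outlier_fraction * len(noisy)))
    for i in rng.sample(range(len(noisy)), n_bad):
        (a, b), c = noisy[i]
        wrong = rng.randrange(num_symbols - 1)  # draw from the num_symbols - 1 wrong answers
        if wrong >= c:
            wrong += 1                          # skip over the correct answer c
        noisy[i] = ((a, b), wrong)
    return noisy

p = 7  # hypothetical small table: addition modulo 7
cells = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
noisy = add_outliers(cells, outlier_fraction=0.2, num_symbols=p)
n_wrong = sum(c != c_noisy for (_, c), (_, c_noisy) in zip(cells, noisy))
print(n_wrong, "of", len(cells), "labels corrupted")  # 10 of 49
```

Drawing from the `num_symbols - 1` wrong answers and shifting past `c` guarantees that every selected cell really becomes an outlier. The corrupted count is therefore exactly `round(outlier_fraction * len(cells))`.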

Conclusion & Comments

from me for the grokking paper. As I said, this seems like work in progress, but I think it's pretty cool work in progress, and it raises a lot of questions. I wonder how they found this. Did they just forget to turn off their computer, come back in the morning, and go: whoopsie-doopsie, it generalized? Though if you build these kinds of datasets, I guess you already have something in mind. In any case, that was it for me. Tell me what you think is going on in neural networks, or whether there is a super easy Occam's razor explanation that I'm missing. I'll see you next time. Bye-bye!
