Expire-Span: Not All Memories are Created Equal: Learning to Forget by Expiring (Paper Explained)


Yannic Kilcher · 24.05.2021 · 10,691 views · 376 likes


Video description
#expirespan #nlp #facebookai

Facebook AI (FAIR) researchers present Expire-Span, a variant of Transformer-XL that dynamically assigns expiration dates to previously encountered signals. Because of this, Expire-Span can handle sequences of many thousands of tokens while keeping the memory and compute requirements at a manageable level. It matches or outperforms baseline systems while consuming far fewer resources. We discuss its architecture, advantages, and shortcomings.

OUTLINE:
0:00 - Intro & Overview
2:30 - Remembering the past in sequence models
5:45 - Learning to expire past memories
8:30 - Difference to local attention
10:00 - Architecture overview
13:45 - Comparison to Transformer XL
18:50 - Predicting expiration masks
32:30 - Experimental Results
40:00 - Conclusion & Comments

Paper: https://arxiv.org/abs/2105.06548
Code: https://github.com/facebookresearch/transformer-sequential

ADDENDUM: I mention several times that the gradient signal of the e quantity only occurs inside the R ramp. By that I mean the gradient stemming from the model loss; the regularization loss also acts outside the R ramp.

Abstract: Attention mechanisms have shown promising results in sequence modeling tasks that require long-term memory. Recent work investigated mechanisms to reduce the computational cost of preserving and storing memories. However, not all content in the past is equally important to remember. We propose Expire-Span, a method that learns to retain the most important information and expire the irrelevant information. This forgetting of memories enables Transformers to scale to attend over tens of thousands of previous timesteps efficiently, as not all states from previous timesteps are preserved. We demonstrate that Expire-Span can help models identify and retain critical information and show it can achieve strong performance on reinforcement learning tasks specifically designed to challenge this functionality.
Next, we show that Expire-Span can scale to memories that are tens of thousands in size, setting a new state of the art on incredibly long context tasks such as character-level language modeling and a frame-by-frame moving objects task. Finally, we analyze the efficiency of Expire-Span compared to existing approaches and demonstrate that it trains faster and uses less memory.

Authors: Sainbayar Sukhbaatar, Da Ju, Spencer Poff, Stephen Roller, Arthur Szlam, Jason Weston, Angela Fan

Links:
TabNine Code Completion (Referral): http://bit.ly/tabnine-yannick
YouTube: https://www.youtube.com/c/yannickilcher
Twitter: https://twitter.com/ykilcher
Discord: https://discord.gg/4H8xxDF
BitChute: https://www.bitchute.com/channel/yannic-kilcher
Minds: https://www.minds.com/ykilcher
Parler: https://parler.com/profile/YannicKilcher
LinkedIn: https://www.linkedin.com/in/yannic-kilcher-488534136/
BiliBili: https://space.bilibili.com/1824646584

If you want to support me, the best thing to do is to share out the content :)

If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: https://www.subscribestar.com/yannickilcher
Patreon: https://www.patreon.com/yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n

Table of contents (9 segments)

Intro & Overview

Hello there! Today we're going to look at "Not All Memories are Created Equal: Learning to Forget by Expiring" — the system is also known as Expire-Span. It's by Sainbayar Sukhbaatar, Da Ju, Spencer Poff, Stephen Roller, Arthur Szlam, Jason Weston, and Angela Fan of Facebook AI Research and LORIA. On a high level, the authors propose a modification to the Transformer attention mechanism that potentially allows these systems to include much longer context spans. The way they do it is that they don't attend to all of the context; instead, in an autoregressive way, at each time step they decide: is this particular time step worth remembering, and if so, for how long? After a while these memories of the past expire and are dropped, and the system can learn by itself which things are important to remember for the future and which aren't. It has some good properties and some limitations; it's very strong on tasks where you explicitly have to remember individual things for a long period of time. So we'll dive into the system right here. It's a pretty simple idea, I think, and it appears to work on the tasks they present. As always, if you like this, don't hesitate to share it out and tell all your friends about it — I'm sure they are very interested. The authors say that attention mechanisms have shown promising results in sequence modeling tasks that require long-term memory. However, not all content in the past is equally important to remember. They propose Expire-Span, a method that learns to retain the most important information and expire the irrelevant information. This forgetting of memories enables Transformers to scale to attend over tens of thousands of previous time steps efficiently, as not all states from previous time steps are preserved.
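To make the mechanism concrete, here is a minimal sketch of that bookkeeping (my own illustration, not the paper's code; it uses a hard drop rule, whereas training uses a softened version discussed later in the video):

```python
def surviving_memories(spans, t):
    """Indices of past states still attendable at time step t.

    spans[i] is the expiration span e_i predicted at the moment hidden
    state i was produced: state i stays in memory while t - i < e_i,
    and is dropped for good afterwards.  (Hard-drop simplification;
    the paper trains with a soft ramp instead.)
    """
    return [i for i, e in enumerate(spans) if i <= t and t - i < e]

# Four states whose spans were fixed at creation time:
spans = [2, 5, 1, 4]
assert surviving_memories(spans, 1) == [0, 1]  # nothing has expired yet
assert surviving_memories(spans, 3) == [1, 3]  # states 0 and 2 have expired
```

The key property — and, as discussed later, the key limitation — is that each state's lifetime is fixed at the moment it is created.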

Remembering the past in sequence models

So again, this is the core idea. If you have a sequence model like a Transformer — and in this case we consider an autoregressive, decoder-only sequence model, which means that for the next token to predict we only care about the past, not the future — every token can attend to its past. Now if you want to predict the fourth token, the attention mechanism has to attend to three things in the past; if you want to predict the fifth token, it has to attend to the previous one, but also all the other previous ones — four things in the past. You see what's coming: the longer your sequence gets, the more things you need to attend to in the past, which gives us the traditional O(n^2) computation and memory requirements that attention mechanisms have. With very long sequences this becomes a problem, because you always need to attend to everything in the past.

Imagine the sequence is the sentence "the cat sat on the mat". Not all words, they say, are equally important. For example, if you wanted to predict the word "mat", it would be pretty easy to do so even if you don't remember that the word "the" is in front of it. The words "sat on" seem pretty important: to sit on something is a good indication that there is a mat there, or a chair, or something like that. Those seem worth remembering, while the word "the" is maybe not as important, and "cat" might be semi-important. We would like a system that learns to forget and remember the correct words. If we only remember the more important pieces of information and discard, in this case, the word "the", then we also have one less thing to attend to. The goal is: if we can get the number of important things down, then it won't be O(n^2) but something like O(n*m), where m is the size of the memory that we keep. This work doesn't have an explicitly sized memory; rather, it does the following.
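To put rough numbers on the O(n^2) versus O(n*m) claim (the sizes here are made up purely for illustration):

```python
def full_attention_cost(n):
    # Causal attention: token t scores against its t predecessors, so the
    # total number of query-key pairs over the sequence is 0 + 1 + ... + (n-1).
    return n * (n - 1) // 2

def expiring_attention_cost(n, m):
    # If expiration keeps at most m past states alive at any time, each of
    # the n steps scores against at most m surviving memories.
    return n * m

n, m = 10_000, 128  # illustrative sequence length and surviving-memory bound
assert expiring_attention_cost(n, m) < full_attention_cost(n)
```

For these numbers, full causal attention computes about 50 million query-key pairs, while the expiring version computes about 1.3 million — the gap only widens as n grows.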

Learning to expire past memories

It goes over every element in the sequence and, through a bunch of layers, gives you a prediction for each one. So every element in the sequence first gives you a hidden state h and a prediction y: that is h1 and y1. Then you go to the next element, which — attending to the last layer — gives you h2, from which it predicts y2, and so on. In each layer, the future attends to the past, and that gives you a prediction; the attention is over these hidden states h. Now what this model does is add one component: at each time step it doesn't only predict the output of that time step (if there even is an output), it also predicts a number they call e. e is the expiration duration of that particular memory: e is produced every time from h, and it tells you how long you should remember that particular h. Here, for example, h3 also attends to h1 (I forgot to draw this in). Let's say that e1 is 2, saying this particular memory should be valid for two time steps — I'm not going to need it longer than that. Now the next sequence token comes in, and h4 is produced, of course, by attending to the past: you would attend to h3, to h2, and — because you attend to all of the past — to h1. But because h1 has already expired, you can't: the system drops h1, and you can no longer attend to it. So this is different

Difference to local attention

from just a fixed window. What people previously did was something like local attention, where you say: I have a window of size l, say four, and if I predict this token I can attend to the past four things; if I predict the next one, I can attend to its past four things, and so on. That is different in the sense that with a fixed window, again, everything has the same importance — you just limit how far you can look back. This works to an extent, but if there is something really important far back, you will forget it no matter what. In Expire-Span, however, that important thing can say, "I have an expiration date of a million billion," so for a million billion future time steps, things will be able to attend to that important piece of information — while the next thing can say, "I expire immediately; this is not worth remembering for the future." So I hope you get the principle. They also have a drawing here of how these hidden states are produced.

Architecture overview

These hidden states are produced naturally by forward-propagating through the model, and for each hidden state one expiration date is produced. In the future, when I want to produce the next hidden state — the next output of the next layer — I look at the past and only consider the things whose expiration date hasn't passed yet. Anything else — like this one here, whose expiration date was just too short — doesn't go into the attention mechanism. So this is a dynamic way of deciding how long a memory should last.

Now you can immediately see the weakness: at the moment you produce the signal, you already have to know how long it's going to be valid. That is certainly the case for some things you have to remember — when you come across a name in a story, you might think, "Okay, I'm going to remember that piece of information very well, because it's probably going to be important" — but not for all. Sometimes something you thought wasn't important — say, a word in a sequence of text that doesn't seem too important when you read it — all of a sudden, because of what comes later, becomes super important: that password becomes something you shouldn't forget. These are effects the system cannot handle. The system can only decide, at the moment it consumes the token, how important it is and for how long to remember it, independent of what happens in the future.

You might already know a system that learns to remember things over long periods of time: the long short-term memory (LSTM) cell, or generally recurrent neural networks, which have an internal state and at each point decide how to update that state. So this sits in between the Transformer, where you cannot decide at all how important things are and what to remember — you remember all of it, or a fixed part of it — and the LSTM, which dynamically updates its internal memory at every single time step, so it can make remembering something depend even on the future. As I said, this is done mostly for computational reasons, because with LSTMs you have to train one step after the other and backpropagate through time, whereas here you can still get away with a bit of parallelism — I think, at least. Though I would argue, if I could extend this: at the point where something expires, I would maybe build in something where the system can decide to take it back into memory, so that the system can revise its own predictions about how important each of the memories is. And if you look at this from

Comparison to Transformer XL

a computational point of view: they base their work off Transformer-XL, so Transformer-XL is the baseline here. What Transformer-XL does is take long sequences and consider blocks of those sequences, and they do the same here: you chunk the sequence into blocks. For each of the elements you output a vector, which is its hidden state. Transformer-XL does the attention in block one just as it would regularly, then in block two, block three, and so on — it chunks the sequence and handles the blocks individually. However, in block two, in order to look back — because we always want to look back, we want to remember things — you take the hidden states you produced in block one and put them into a little register, so to speak, to make them available for the next block; there is a stop-gradient right there. So when the next block wants to produce, for example, the hidden state of one of its elements, it can attend to the sequence elements in its own block (because you consider the block as a whole), but it can also attend to those cached states. Again, you produce the hidden states of that block from every element in it, and those in turn become available for the next block to attend to. You can even carry forward multiple blocks like this, so block three can attend to the last two blocks. You can't do this infinitely, of course, otherwise you run into the same problems again, but at least this handles a bit of the backprop issue — and the cached states don't need to attend to each other, so you don't have n squared over the whole sequence. If the block size is B and the memory holds at most N states, the quadratic blow-up happens only inside the block, which is way smaller than the whole sequence length. You can even compress these memories in Transformer-XL: you can max-pool, you can learn to compress them, and so on. So this is the system they base off of.

They also consider sequences in these blocks where, inside the block, it's just regular attention, and then you can attend to the past as you would in Transformer-XL — except that some of those past memories are forgotten. Here, maybe these are forgotten, and maybe this one too, and by the time you are here, one more has expired — so there is a lot less stuff around. You get away with a smaller memory, and you can potentially increase how far you can look back into the past given a limited set of slots. So I hope it is a bit clear how they do it: they go block by block, and in each block they look back and build this memory that the next block can also attend to — but, unlike Transformer-XL, the memory only contains things that have not yet expired, and the expiration is determined at the moment the hidden state is produced. In fact, the expiration is pretty simple: you take the hidden state produced by the network and simply perform a

Predicting expiration masks

logistic regression on top of it. The logistic regression gives you something in the range 0 to 1, and you multiply that by L, where L is the maximum possible length of remembering. These are all design choices. The sigmoid function used in logistic regression is a rather steep function: there is a region where you go up quite quickly, but there are also large regions where it's all or nothing. So I'm going to guess that this function will mostly output either "remember this" or "don't remember this", with maybe some things in the middle — which tells me that this setting of L might be fairly important, and something you'd tune for the task you want to consider.

Another thing they discuss is how to actually implement this, and they implement it via a mask. If you have a bunch of things you could attend to, the way you avoid attending to everything is by masking out elements of the attention matrix. If I draw the same sequence twice, the attention matrix is constructed by an outer product of keys and queries: every cell holds a value for how much this query attends to this key. As you know, in decoder models we already need a mask, because a token cannot attend to future tokens, so the whole upper triangle is already off limits. We usually implement this with a mask, because GPUs aren't very good at dealing with triangular matrices: we just put a mask there and say everything up here is off limits. Now if we also say that some token has an expiration date of two — meaning a nearby token can still attend to it, but a later one cannot — then we go to that later cell and mask it out too: you cannot attend to anything that's expired. What you end up with is a mask where, from some point on, the rest of a memory's column is just masked out. The light squares have a value of 1 and the dark squares a value of 0, meaning you no longer consider those entries in the attention. That's how it's implemented.

If you just do that, though, you have a problem on your hands, because it is not differentiable. The mask simply asks whether this number r is still positive, where r is constructed from e, the expiration duration, t, the current time step, and i, the index where that e was produced: you look back and ask whether the memory is still valid. If r is positive, the memory is still valid; if it becomes negative, the memory is expired and can be removed from the set you attend to. So you construct a mask from all the r's that are positive and use it in the attention, like you already do for masking out future tokens — and that is not differentiable. As the authors say: with such discrete masking, the Expire-Span will not receive any gradient for training; instead, they use a soft masking function that smoothly transitions from 0 to 1. Essentially it works like this: a memory produces a hidden state and says, "I am valid for three steps." What does the mask for that memory look like? It stays at 1 for steps one, two, and three, and then drops off linearly until it reaches 0. There is a min with 1, so it can never be higher than 1, and a max with 0, so it can never be lower than 0, and in between it is governed by a linear rule, where R is a hyperparameter: the length of the ramp, the drop-off region, bounding the mask between 0 and 1; once r is negative, we're in this decreasing regime.

Now, talking about gradients, you can immediately see the following: the module that generates e — the hidden state goes into a neural network, and that generates the expiration date — receives a learning signal, a gradient, from the model loss only during this drop-off. Not before, not after. So these parameters are quite important: e is upper-bounded by L, and the drop-off is modulated by R, and I feel these hyperparameters matter a lot for whether the task can be learned at all. Say there is something in a sequence that you need to remember much later: if L is too short, you will maximally remember it up to some point and then it's gone — you get no training signal for it. Even if L is large enough, your expiring span drops off at some point, and only if that drop-off happens to coincide with the place where the memory is important do you get a learning signal saying, "Hey, maybe you should remember that thing for longer next time, because I'm going to need it." If your expiration prediction ends, and the drop-off is done, before the important moment arrives, then you never get a learning signal that there might be something worth remembering. It's the same problem you get anywhere you deal with long sequences, and it is a problem: ultimately, if you want a general training method where anywhere in the future there could be something important, you're back to the quadratic thing — you technically have to attend to everything in the past, at least a little bit, to keep it differentiable and learn to remember. If you always forget, and later there turns out to be something worth remembering, you no longer know it was there; you'd somehow need a learning signal. I guess you could maybe break this down, not into O(n^2) but into something like O(n log n), where you build up a tree of the past and somehow realize that there is something to remember — maybe without knowing what. This might have been done already. In any case, I just wanted to show that the learning signal here is very small — the window where you can learn something is very narrow — and that means the kinds of tasks this can be applied to are maybe not as many as you would hope.

They also put an L1 penalty onto these expiration values, encouraging the network to rather forget things, in order to keep the predictions small. You don't want the network to say by default that everything is important; only when there is a learning signal that something is important should the network predict high numbers. So ultimately you have a sequence over time, and the network predicts various spans to expire these memories. At first everything just kind of goes down, and then, if some later element really profits from some earlier element in the sequence — and the earlier element's span has come down enough that the later element sits in its ramp portion, this R portion — then you get a learning signal saying, "Hey, maybe you should remember that thing for longer." And then, hopefully, some next thing will also benefit from remembering it and will now be in the ramp region as well, giving another boost to remember it for longer. So this is how you learn: you need a continuously reinforcing signal over different time steps. Truly long-range dependencies, I don't think, are generally learnable with this system on their own; you need these intermediate steps, or some kind of randomness to discover them — this is very close to reinforcement learning.

They also have some practical considerations: because they cache these things, the question is how you even backpropagate through something like this. I said there was a stop-gradient. What you do is cache the hidden states h, and then, as far as I understand, you compute the expiration masking on the fly: the hidden states are cached, and whether to mask them or not is computed on the fly, so you can backpropagate to these variables even in the future, because you have the h's cached. I don't think the backprop flows back to when the hidden states were produced — it can't, because you cached them, so you no longer have the graph available. So they have a bunch of practical considerations there.
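Putting this section together — the span prediction, the soft ramp mask, and the L1 penalty — a minimal NumPy sketch might look like this (L and R are the paper's hyperparameters; the weight vector and the numbers are illustrative assumptions of mine, not the paper's values):

```python
import numpy as np

def predict_span(h, w, L):
    # e = L * sigmoid(w . h): a logistic regression on the hidden state,
    # scaled to the maximum possible span L
    return L / (1.0 + np.exp(-(h @ w)))

def soft_mask(e_i, i, t, R):
    # r = e_i - (t - i): positive while memory i is still valid at time t.
    # The mask is 1 inside the span, 0 once r <= -R, and a linear ramp of
    # length R in between -- the only region where e_i gets a gradient
    # from the model loss.
    r = e_i - (t - i)
    return max(0.0, min(1.0, 1.0 + r / R))

def span_l1_penalty(spans, alpha):
    # L1 regularizer pushing predicted spans down: forget by default
    return alpha * float(np.sum(spans))

# A memory produced at i = 0 that claims to be valid for e = 3 steps,
# with ramp length R = 2:
masks = [soft_mask(3.0, 0, t, 2.0) for t in range(7)]
# t = 0..3 -> 1.0 (valid), t = 4 -> 0.5 (on the ramp), t >= 5 -> 0.0
```

Note that only time steps falling on the ramp (t = 4 in this example) give the span predictor a nonzero model-loss gradient, which is exactly the narrow learning-signal window discussed above.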

Experimental Results

Now they test this in various tasks: there are reinforcement learning tasks, text instruction tasks, character-level language modeling, and collision detection, where you have a video and go frame by frame. These tasks — except, I guess, the language modeling task — are quite constructed, such that you have to remember things over long spans. Particularly interesting, for example, is the one where they take the character-level language model and look at what it learns to remember. If the sentence is "powerful influence in Egypt", they say the model strongly memorizes the two terms "Egypt" and "Alexander". So if you look at "Egypt" — and this is the visualization of the expiration time — it is strongly remembered. If, in the same model, you just replace it with the word "somewhere", all of a sudden the model doesn't remember it anymore; and if you replace it with "Humpty Dumpty", the model again remembers it quite well. This is an indication that the model has in fact learned that if there is something special — and they claim, if it's a name or something like that — it remembers it well. They also say the model keeps rare words in memory, and I'm asking myself: is this just a function of perplexity? Could you just remember the things where the model's perplexity is high — the things you would not have predicted — instead of learning what to remember? I'm going to guess the learned remembering is better, just because it's learned, so it can also remember things that have a high probability but might still be important.

I want to talk just a little bit about the first task, to show you the kind of tasks where this could be good. You have a grid-world reinforcement learning problem: at the start you observe the color of the field you're on — either blue or red — and then you need to walk all the way through a long corridor and go to the correct door, whichever one matches the color you saw at the beginning. The corridor is made long enough that it is too long to fit in the same block — too long to consider in one attention operation at the same time — and this model, they say, learns to remember the correct thing with very little effort. Here you can see the comparison to Transformer-XL. Transformer-XL also has the ability to remember that: it can simply attend to the thing in the past, if given enough memory. So on the axis you have the memory size, and you can see it starts out basically random, because the memory size is too small to actually remember, and as you give it more and more memory it learns to attend to the correct thing in that memory. Expire-Span, however, doesn't have a set memory — with the L1 penalty you can modulate how quickly it forgets things — and these here are just five random samples, I guess, of the same model. You can see it solves the task pretty well, while its effective memory size — if you look at what things it actually does remember — stays relatively low. So it learns to remember the correct thing, which is pretty cool.

However, there are details of how this task was constructed. I already said: if it were just one fixed, very long corridor, this would be unlearnable. If you look at the details in the appendix: the corridor length is sampled from between 3 and 200, and for Expire-Span they set the maximum span to 200 — so it is able to remember, and again this L seems to be an important hyperparameter — and the ramp length to 16. So what does this mean? In this corridor-task reinforcement learning problem, if you only ever sampled corridors 200 steps long, you could in principle learn, because your L is 200; but if your span predictions are too short, you never learn to reach that far, and if they're too long, the L1 penalty makes them shorter and shorter until they eventually come into the region of learning. But here you sample randomly: sometimes the corridor is 3, sometimes 200, sometimes anywhere in between. So you give the model a really nice training signal: wherever it currently has learned to remember things for however long, there's going to be this ramp, and in some training runs the length of the corridor falls exactly into that ramp, which gives it a training signal saying, "Hey, maybe you should remember that thing for longer." Then the ramp sits further out, and some other problem instance falls exactly into it, and so on. So, as in reinforcement learning, it is best — I'm going to argue — if your loss structure guides the model to remember things for longer. Of course, this doesn't work in character-level modeling, but there, I think, text is naturally structured such that if something is important to remember, you will find instances where that need comes after 10 tokens, and you will find instances

Conclusion & Comments

where the need to remember comes after 20, 50, 100 tokens, and so on. So, not for every task, but certainly for many tasks, this might be a good solution. Again, I would advocate adding the ability for the model to refresh these memories — not full LSTM-style, not internally computing and updating an internal state, but just to go there and say: in light of this new evidence, this thing that I wanted to forget might still be quite important. That would be my first extension, and my second extension would be: instead of building some sort of bank that you can attend to, maybe build some sort of tree — some kind of Merkle-tree-ish thing, but not with hashes, with hidden latent variables. I'm sure maybe this has already been done. Okay, that was my two cents on this paper. I think it's a pretty cool paper: if you have problems with super long sequences, and you have a clear structure where it's important to remember a few key pieces of information over long distances, and those distances are somewhat spread out — not only super long distances — this might work wonders. Tell me what you think in the comments, and that was it for me. Bye-bye!
