# Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention

## Metadata

- **Channel:** Yannic Kilcher
- **YouTube:** https://www.youtube.com/watch?v=r_UBBfTPcF0
- **Date:** 24.04.2024
- **Duration:** 37:16
- **Views:** 59,960

## Description

Google researchers achieve supposedly infinite context attention via compressive memory.

Paper: https://arxiv.org/abs/2404.07143

Abstract:
This work introduces an efficient method to scale Transformer-based Large Language Models (LLMs) to infinitely long inputs with bounded memory and computation. A key component in our proposed approach is a new attention technique dubbed Infini-attention. The Infini-attention incorporates a compressive memory into the vanilla attention mechanism and builds in both masked local attention and long-term linear attention mechanisms in a single Transformer block. We demonstrate the effectiveness of our approach on long-context language modeling benchmarks, 1M sequence length passkey context block retrieval and 500K length book summarization tasks with 1B and 8B LLMs. Our approach introduces minimal bounded memory parameters and enables fast streaming inference for LLMs.

Authors: Tsendsuren Munkhdalai, Manaal Faruqui, Siddharth Gopal

Links:
Homepage: https://ykilcher.com
Merch: https://ykilcher.com/merch
YouTube: https://www.youtube.com/c/yannickilcher
Twitter: https://twitter.com/ykilcher
Discord: https://ykilcher.com/discord
LinkedIn: https://www.linkedin.com/in/ykilcher

If you want to support me, the best thing to do is to share out the content :)

If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: https://www.subscribestar.com/yannickilcher
Patreon: https://www.patreon.com/yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n

## Contents

### [0:00](https://www.youtube.com/watch?v=r_UBBfTPcF0) Segment 1 (00:00 - 05:00)

Hello there! Today we're going to look at "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention". The specific technique we're going to look at is called Infini-attention, and it's by researchers at Google. Infini-attention promises to scale Transformer-based large language models to infinitely long input with bounded memory and computation. That has, of course, been a dream for a very long time: Transformer models typically have a limited-size context window, and there are good reasons for that. This paper says: hey, what if we did not do that? What if we didn't have a finite attention window? What if we could process infinitely long sequences?

How do they do that? Quoting the abstract: "A key component in our proposed approach is a new attention technique dubbed Infini-attention. The Infini-attention incorporates a compressive memory into the vanilla attention mechanism and builds in both masked local attention and long-term linear attention mechanisms in a single Transformer block." So Infini-attention has this thing called the compressive memory, and it uses that compressive memory to store old stuff from which you can then retrieve things. We're going to look at what this technique does, go into the details of how it works, and look a little bit at the experiments, though honestly you can already tell right now what the experiments are going to show: it performs quite well at long sequences.

Before we dive in, I don't know why, but how about we just draw a recurrent neural network, for no particular reason? Imagine you have a piece of data, and it goes into some sort of processing block, which produces some intermediate value; in a recurrent neural network this is called a hidden state. You can have a multi-layer recurrent neural network, and each layer will produce its own hidden state. Then the next piece of data (the next word or the next token) is combined with that hidden state, the next layer is combined with its hidden state, and so on, and that again outputs hidden states. This is how recurrent neural networks work, and we've known them for a long, long time. They can consume infinite context, if you will, because every computation they do gets stored into the hidden state, and the next step accesses that hidden state, or incorporates it into its computation, to produce the next hidden state. So in a technical sense, each next block can, through this hidden state, in theory have access to all the data that came before it. But, well, this is certainly no attention mechanism, no Transformer; this is completely different. I sometimes just really like to draw recurrent neural network architectures for no particular reason at all.

So now, on a totally unrelated topic, let's discuss this paper. What does it actually do? It introduces this Infini-attention mechanism, a new attention mechanism. Just to recall, since I haven't done this for a long time, what the attention mechanism actually is, very briefly: I have a sequence of tokens, usually in language modeling, and I want to transform that sequence into a next-layer representation. Each token is represented by some sort of vector, and I want to transform these vectors, one layer up, into new vectors; I can do that in multiple layers, and that represents the computation. Now, in a classical neural network I would build a so-called dense layer, where each of the upper units is connected by a weight to each of the lower ones,
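The recurrent update sketched above can be written in a few lines. This is a generic illustration, not anything from the paper; the `tanh`, the weight shapes, and the random inputs are all my assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden = 8, 16

# Static parameters of a single recurrent layer (randomly initialized here).
W_x = rng.normal(size=(d_hidden, d_in)) * 0.1
W_h = rng.normal(size=(d_hidden, d_hidden)) * 0.1

def rnn_step(h_prev, x_t):
    """One step: the new hidden state mixes the previous state with the input."""
    return np.tanh(W_h @ h_prev + W_x @ x_t)

# Consume an arbitrarily long sequence with a fixed-size hidden state.
h = np.zeros(d_hidden)
for x_t in rng.normal(size=(1000, d_in)):  # 1000 tokens; could be any length
    h = rnn_step(h, x_t)

print(h.shape)  # the state stays (16,) no matter how long the sequence is
```

The point is the loop: the state has a fixed size, so in principle the sequence can be as long as you like.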

### [5:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=300s) Segment 2 (05:00 - 10:00)

and those weights build a weighted sum, followed by a nonlinearity somewhere, and that gives the output signal. That is not the attention mechanism. What the attention mechanism does is compute a thing called queries. The queries Q are vectors that essentially act as retrieval requests; they're called queries because they're like what you type into a search engine. It also produces a set of keys K. Both of these are produced from the lower representation, because you obviously don't have the upper representation yet. So you produce queries and keys, and then you compute what is sometimes called the routing matrix or the attention matrix by taking the inner product of each query with each key. Since we have five tokens here, that gives five queries and five keys, so a 5x5 matrix, which we usually call the attention matrix A. The queries index one axis and the keys the other, and each entry describes how well a query matches a given key.

The attention matrix is then normalized with a softmax operation across the keys, so for each output position you get a weighted sum. You might have heard that before; I just said that a dense layer in a classic network is also a weighted sum. What's different here is that the weights of the sum are computed dynamically from the matrix A, whereas in a classic neural network they are statically learned as parameters of the connections. So we get this routing matrix A, apply softmax to it, and then aggregate the actual values of the tokens (these embeddings) in a dynamically computed weighted sum to produce the output at each position.

In a mathematical sense, we usually write it like this: we have a signal X and we produce queries as some function of X, usually Q = W_Q X, maybe a linear function or a small neural network. The same goes for the keys, produced from the same input signal, and by the way the same goes for the values. So we produce queries, keys and values from the same data (sometimes from different data; queries can come from a different source than keys and values). In any case, we then compute the product of queries and keys, take the softmax over the key dimension, and multiply by the values:

X^{l+1} = softmax(Q K^T) V

The inner part, Q K^T, is the matrix A I described before, and the multiplication by V is the weighted aggregation of the values for the next layer's X. That's classic attention.

Now, the big cost is building this matrix: for five tokens we already have a memory consumption of 25 entries, and as you make the sequence longer, this goes up quadratically. That is a problem, and it's why current attention mechanisms max out at a few thousand tokens: a thousand tokens is already a million entries in that matrix. A lot of things have been tried over the years to get that down, but none have been really successful. So why is it hard? Shouldn't some linear algebra trick do it? The answer is no, and the reason is this softmax: the normalization step really needs the matrix A to be globally available. There are techniques to do it row by row, because the softmax is across a
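A minimal numpy sketch of exactly this computation; the dimensions and random weights are illustrative, not from the paper:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 5, 16                      # 5 tokens, model dimension 16
X = rng.normal(size=(n, d))       # lower-layer representations

# Queries, keys, values are all (here linear) functions of the same input X.
W_Q, W_K, W_V = (rng.normal(size=(d, d)) * d**-0.5 for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

A = softmax(Q @ K.T / np.sqrt(d), axis=-1)  # the 5x5 routing/attention matrix
X_next = A @ V                              # dynamically weighted sum of values

print(A.shape)       # (5, 5): quadratic in sequence length
print(X_next.shape)  # (5, 16)
```

Note that each row of `A` sums to one after the softmax, which is exactly the "dynamically computed weighted sum" described above.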

### [10:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=600s) Segment 3 (10:00 - 15:00)

row, but in that case you're just trading off computation against memory; you still have quadratic complexity in computation no matter how you distribute it. Then there are attempts at what's called linear attention. The easiest form says: how about I just multiply Q and K, then multiply that by V, and that's it? That's called linear attention; it's also called fast weights; it's also called "I invented Transformers in the '90s". It really doesn't work, except for very simple tasks, because everything is linear: this is essentially X multiplied by itself a bunch of times with a few weighting factors, and it doesn't do too much. Then we've had a whole bunch of work where people slap nonlinearities on these things: if we just apply the correct nonlinearity to Q and K, we can kind of decompose the attention. There are a bunch of kernel papers along those lines; thinking back, I covered a lot of papers on linear Transformers trying to get around this quadratic bottleneck. I think the particular one I'm remembering is either the Reformer or the Performer, which said: since this is an inner product, there should be some kernelization trick such that, in another space, it's a true inner product. If we find the function that transforms it into that space, then by some RKHS equivalence math (which I don't claim to understand) it should be equivalent. But it's not, and Reformer and Performer don't deliver either. So a lot of things have been tried, and Infini-attention is doing a bit of that too.

The other thing that has been tried, and something they compare against extensively here, is Transformer-XL. Transformer-XL takes a different approach. It says: if we have a long sequence, why don't we just split it up into segments? Within each segment we do attention with quadratic complexity, then we keep the segment's final attention states, and the next segment can also attend to them; its queries can go into those cached states. However, they didn't really train through that connection; they just threw the last states in and let the next segment access them. That worked for some things, but it's not really satisfactory, because you're just learning to attend to something; there's no explicit plan for what to pass forward, so it was more of a hack. But a lot of things have come up over the years.

So let's look at how Infini-attention does it. This is the core of their attention. Like Transformer-XL, their attention mechanism consists of two parts. One part does regular attention; if you cannot read it too well, it says "causal scaled dot-product attention and PE", where PE means position embeddings. So this is regular old attention: you have your queries, your keys and your values, and they go in here. This is multi-head attention, by the way; if you don't know what that is, I've covered "Attention Is All You Need", where it's discussed. And then there's a second part, over here. Ultimately, you can see these are added together, so the signal from the regular attention is augmented by a signal from the memory; in fact it's a weighted sum. There's a parameter beta, which can even be learned during training, that controls how much you trade off
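To make the linear-attention point concrete: dropping the softmax lets you reassociate the matrix product, so you never materialize the n x n matrix. A sketch, with ELU+1 as the nonlinearity (a common choice in this literature; everything here is illustrative):

```python
import numpy as np

def phi(x):
    """ELU(x) + 1, a common feature map in linear-attention papers."""
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d = 1000, 16
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

# Quadratic order: (phi(Q) phi(K)^T) V materializes an n x n matrix...
out_quadratic = (phi(Q) @ phi(K).T) @ V

# ...but matrix multiplication is associative, so phi(Q) (phi(K)^T V) gives
# the same result while only ever holding a d x d "memory" matrix.
memory = phi(K).T @ V            # (16, 16), independent of sequence length
out_linear = phi(Q) @ memory

print(np.allclose(out_quadratic, out_linear))  # True
```

That d x d matrix is exactly the kind of fixed-size summary the compressive memory below accumulates; the price is that there is no softmax between Q and K anymore.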

### [15:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=900s) Segment 4 (15:00 - 20:00)

the two. So on top of the normal signal we also get this side signal, which represents data coming from the compressive memory, and the compressive memory is supposed to contain all the information from the past. As we go along through these segments, we try to build up this memory so that when we do attention using the queries, we can look at the current sequence (that's what regular attention is doing), but using the same queries we can also look back into that memory.

How does that work? Well, first we have to build up the memory; that's one part of this paper. Second, we have to actually retrieve from that memory; that's the other part. Let's deal with the naive approach first and then see how the paper actually does it. Could you just use Q to attend to something? What you would need is a compressed key space of the past and compressed values, separately, because keys are first multiplied by queries, then there is a softmax, and then the result is multiplied by values. You can't just take Q, multiply it with some single blob, and expect to get something meaningful out; you need keys and values separately, because there is a nonlinearity in the middle. So if you want to build up a memory, you'd better have keys in the memory and corresponding values. However, that is incredibly tricky, because you somehow need to store individual values with corresponding keys to retrieve them by, and what you actually want is something you can push more and more data into. That something is a matrix. What they're doing here is very related to the concept of an associative memory, and it works a bit like this.

There's a matrix M. I want to remember a vector v, so I define a key k for that v. Then I take the outer product v k^T; v is a vector, k is a vector, so v k^T gives me a rank-one matrix. I have M, which is of the same size, and I add the two together, which gives me M' = M + v k^T. Now, there is an assumption, and the assumption is that the things I store, the v's and the k's, are generally roughly orthogonal to each other, because they're high-dimensional. That is actually quite likely in high dimensions: if you have a thousand vectors in an 8,000-dimensional space, that space is mostly empty, and things tend to be rather orthogonal to each other. So if I now want to store something else, say v2, I define a key k2 (any key by which I want to retrieve that value), add its outer product again, and get M' = M + v k^T + v2 k2^T. Now, when I want to retrieve something, I come with the same key k and multiply it into M':
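A quick numerical check of this associative-memory idea; the dimension and the random unit-norm keys are my choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 1024  # high dimension: random keys are then nearly orthogonal

def unit(v):
    return v / np.linalg.norm(v)

# Values we want to remember, and the keys we define for them.
v1, v2 = rng.normal(size=d), rng.normal(size=d)
k1, k2 = unit(rng.normal(size=d)), unit(rng.normal(size=d))

# Store: accumulate rank-one outer products into one matrix.
M = np.zeros((d, d))
M += np.outer(v1, k1)
M += np.outer(v2, k2)

# Retrieve: multiply by a stored key. Because k1.k1 = 1 and k1.k2 ~ 0,
# M @ k1 = v1 (k1.k1) + v2 (k2.k1) ~ v1.
v1_hat = M @ k1
cos = np.dot(unit(v1_hat), unit(v1))
print(cos)  # close to 1: we got v1 back, up to cross-talk noise
```

The more vectors you pack in, the larger the cross-talk terms get, which is exactly the "the more I store, the less precise this gets" caveat.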

### [20:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=1200s) Segment 5 (20:00 - 25:00)

M' k = M k + v k^T k + v2 k2^T k (let's multiply from the right, because I totally thought ahead). Now, M can always be decomposed into rank-one terms of some kind, and assuming their keys are generally orthogonal to k (or we simply start with an empty matrix), the M k term falls away. Also consider this: v k^T is an outer product, but k2^T k is an inner product, and the inner product between k2 and k is approximately zero, so that term falls away too. And we can normalize the keys, so the inner product of k with itself is one, and the whole thing comes out to approximately v. So I built up this memory by adding things together, plus, plus, plus, using keys that I defined for given values, and when I come back with the same key, I get out approximately that vector again. Obviously, the more I store, the less precise this gets. We're going to use a very similar principle here to build up this compressive memory of the past; that's the store part. And retrieval is largely going to work as I just said: we come with the same or similar keys, or in this case with queries, to grab stuff out.

All right, so how does this ultimately work? Pretty much like I just said. Here is the normal attention, which they call dot attention: softmax(Q K^T) times the values. For the compressive memory, they first show how they retrieve: they take the query, pass it through some nonlinearity, multiply that by this memory M, and that becomes A_mem, ultimately their output. Now don't let the notation confuse you. Regular Transformer papers use A, like I did, for the matrix of attention weights; here, A is actually the output, already multiplied by V. You can see that up above, where A is already the output that then goes into a dense layer to compute the actual layer output (the attention mechanism is usually followed by a dense layer). So there is this memory component and this dot-attention component, and they're mixed by the learned gating scalar beta. So we've already neared the core of the technique: this A_mem just comes about by multiplying a function of Q with this M, so it is, in a sense, an associative memory like we saw before. We go in with queries and grab out a signal. Now that signal, how is it actually
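Putting retrieval and the gate together, here is a sketch of mixing the memory readout with the local dot attention. The fixed gate value, ELU+1 as the nonlinearity, and the omission of the paper's normalization term are all simplifications on my part:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sigma(x):  # ELU + 1, the paper's choice of nonlinearity
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d = 5, 16
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
M = rng.normal(size=(d, d)) * 0.1   # compressive memory from past segments

A_dot = softmax(Q @ K.T / np.sqrt(d)) @ V  # local attention (current segment)
A_mem = sigma(Q) @ M                        # associative retrieval of the past

beta = 0.0  # learned gating scalar in the paper; fixed here, sigmoid(0) = 0.5
g = 1.0 / (1.0 + np.exp(-beta))
A = g * A_mem + (1.0 - g) * A_dot           # weighted mix of the two signals

print(A.shape)  # (5, 16): already the output, not an attention matrix
```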

### [25:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=1500s) Segment 6 (25:00 - 30:00)

built up? It's built up like this, and you may be familiar with it: we update the memory by taking the last memory and adding something to it. What do we add? A product of a function of the keys and the values. In fact, let's first see what happens if I retrieve. They don't write it out here, but what happens if I combine these two formulas? If I substitute the memory update into the retrieval (using M_{s-1}), I get, leaving away the bottom part, which is just normalization: a function of Q times a function of K, times the values, plus the earlier memory term. And would you look at that: it's exactly what we've seen before from the whole linear attention literature. This is supposed to approximate softmax attention in some sense; it's an inner product between two things, not a joint softmax, and past literature has shown that approximating softmax like this isn't really viable, at least as far as I know. So this is just linearly combining all of these things together. Sure, there's a nonlinearity applied elementwise, that's fine, but the computation itself is completely linear, and it relies on the hope that, for the correct nonlinearity, this approximates softmax in some space. In fact, if I remember correctly, softmax cannot be perfectly replaced by this unless you go to an infinite-dimensional space with your sigma function, which you can do. So it was a bit weird to me, because all you see, again and again, are the same principles: if you multiply it out, this is just a technique we've seen before from linear attention papers, rebranded. And now we're accumulating it in a memory, accumulating a bunch of these, but that doesn't change the fact. So they're just using the queries and keys they compute for the regular attention mechanism, and they're using the current queries to retrieve from a compressed version of all the past key-value combinations, relying on this hope that, if we choose a good nonlinearity, it sort of approximates a softmax. I personally would be skeptical. One trick here is fairly neat, though, inspired by some previous work: they don't actually store the values directly like this. If you multiply out the update, what you actually add to the memory is sigma(K)^T times (V minus sigma(K) M_{s-1}). Notice the difference between the retrieval formula, which is sigma(Q) M, and this, which is sigma(K) M. It's almost like they use the keys as queries to retrieve from the memory. And why do that? Because sigma(K)^T V is what you want to store in the memory now, but first you look up whether you have already stored it in the past, because if the same thing is already in there, you don't want to store it twice. That's why they first retrieve from
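The update just described, written out in numpy. This is a sketch of the delta-rule style write (retrieve with the keys, subtract, then add), with the paper's normalization term omitted; shapes and values are illustrative:

```python
import numpy as np

def sigma(x):  # ELU + 1
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
n, d = 5, 16
K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d))
M_prev = rng.normal(size=(d, d)) * 0.1  # memory from the previous segment

# First "retrieve" with the keys themselves: what the memory would
# currently return for exactly these keys.
already_stored = sigma(K) @ M_prev

# Delta-rule update: only add the part not already in the memory.
M_new = M_prev + sigma(K).T @ (V - already_stored)

# Multiplied out, this is the plain "store" term minus a "lookup" term:
M_check = M_prev + sigma(K).T @ V - sigma(K).T @ (sigma(K) @ M_prev)
print(np.allclose(M_new, M_check))  # True
```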

### [30:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=1800s) Segment 7 (30:00 - 35:00)

the memory, this part right here, so they can subtract it and not store the same thing twice; they only store each thing once. That's pretty neat, and it keeps the memory less cluttered, in a sense.

All right, so that's essentially it. We've figured out how they store stuff, via an associative memory like this, and how they retrieve stuff, by using the queries with the associative memory. And how does that flow into the attention mechanism? It flows by means of a linear attention mechanism. So not only does this compress the whole past, it also uses a linear attention mechanism to include the past in the attention, and I somehow feel only one of those two tricks would actually deliver on the promise of infinite-length attention. Well, maybe not infinite: linear attention is not infinite, but what you can fit into a compressive memory is probably not infinite either.

Now, how does that all fit together in a model? They take a segment, go through layers of computation, and in each layer they produce this thing right here, the memory we've talked about, some intermediate state that the next segment can attend to, and in fact will attend to, using its queries, in this linear-attention fashion. So this is the diagram of the new, novel Infini-attention mechanism that allows Transformers to have infinite attention; this is a new architecture that has been newly invented, and is novel, and has certainly not been seen before, and the whole concept, in fact, is new. And they compare it to Transformer-XL, which you actually can't see because of the settings of the iPad, I think, but here they show there's a crucial difference: because of the way Transformer-XL does things, it can only attend in a stepwise fashion, meaning that from any given layer you can only look back a few segments into the past.

So, interesting stuff. I feel like we've gone through the architecture a little bit; I hope that was at least somewhat clear. They have to put everything into an associative memory, because that's something you can keep adding to, on top, on top, on top. However, because they put it into an associative memory, they have to somehow decouple the keys from the queries and put the keys and the values together into one thing, and that just doesn't directly work, because there should be a softmax between the queries and the keys. You can do this trick where you say: well, if we have the correct nonlinearities on both ends, here and here, then that kind of approximates a softmax, and I'll leave it to past literature whether that is actually true and worthwhile. I personally would be quite skeptical of this method, but I'm absolutely open to believing that it works.

They do some tests. I don't want to go too much into the experiments here; I'm very sure the experiments show fabulous performance on long-context tasks. Their related work does mention some stuff; I think the word "recurrent" is even mentioned once, but in some other context, so I don't know. You tell me what you think. Again, I'm super open to believing that this actually does work and is the way to go for infinite
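To tie the pieces together, a toy single-head streaming loop over segments, this time including a normalization vector z as in the paper so the memory read stays a bounded weighted average. The gate is fixed rather than learned, and this is my sketch, not the authors' code:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sigma(x):  # ELU + 1
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
d, seg_len, n_segs = 16, 5, 20
W_Q, W_K, W_V = (rng.normal(size=(d, d)) * d**-0.5 for _ in range(3))
g = 0.5  # stands in for sigmoid(beta) with a learned gate beta

M = np.zeros((d, d))   # bounded compressive memory, reused across segments
z = np.full(d, 1e-6)   # normalization vector (tiny epsilon avoids 0/0)

outputs = []
for _ in range(n_segs):
    X = rng.normal(size=(seg_len, d))          # next segment of tokens
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V

    A_dot = softmax(Q @ K.T / np.sqrt(d)) @ V  # local quadratic attention
    A_mem = (sigma(Q) @ M) / (sigma(Q) @ z)[:, None]  # normalized memory read
    outputs.append(g * A_mem + (1 - g) * A_dot)

    # Delta-rule write: add only what the memory doesn't already return,
    # then update the normalizer.
    M = M + sigma(K).T @ (V - (sigma(K) @ M) / (sigma(K) @ z)[:, None])
    z = z + sigma(K).sum(axis=0)

out = np.concatenate(outputs)
print(out.shape, M.shape)  # (100, 16) (16, 16): memory stays bounded
```

However long the stream gets, the state carried between segments is just M and z, which is the whole "bounded memory and computation" claim in miniature.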

### [35:00](https://www.youtube.com/watch?v=r_UBBfTPcF0&t=2100s) Segment 8 (35:00 - 37:00)

attention, but I'm very skeptical. For one, it uses linearized attention mechanisms, which in the past have just not turned out to deliver on their promises. Second, it uses this compressive memory approach where, as you go along, you don't really learn how to store; you just store in a deterministic fashion, which also means you have very little control over what you store and how you store it. There's just a predefined formula: everything that happens goes into the memory like this, and even for retrieval you can't help but retrieve in the one way they give you. You cannot go back into that infinite context and check something in a fine-grained way. That makes me doubt it, because once you get into the really long context game, you would essentially have to store things very selectively; you can't just compress the entire past equally into such a small state. And lastly, it's essentially a recurrent neural network without the benefit of a recurrent neural network, which is backpropagation through time, by which it would actively learn what to remember. So it gets the drawbacks of recurrent neural networks, namely that everything has to go through that one hidden state in the middle, without the benefits. All of these things combined make me doubtful. On the other hand, I am super happy that people are trying out new stuff on very long attention, even infinite attention, and so on. Make up your own minds. I'll leave a link to the paper in the description. Tell me what you think, and stay hydrated. Bye-bye!

---
*Source: https://ekstraktznaniy.ru/video/12003*