RWKV: Reinventing RNNs for the Transformer Era (Paper Explained)
1:02:17

Yannic Kilcher · June 2, 2023 · 81,337 views · 2,109 likes

Video description
#gpt4 #rwkv #transformer

We take a look at RWKV, a highly scalable architecture between Transformers and RNNs.

Fully Connected (June 7th in SF) Promo Link: https://www.fullyconnected.com/?promo=ynnc

OUTLINE:
0:00 - Introduction
1:50 - Fully Connected In-Person Conference in SF June 7th
3:00 - Transformers vs RNNs
8:00 - RWKV: Best of both worlds
12:30 - LSTMs
17:15 - Evolution of RWKV's Linear Attention
30:40 - RWKV's Layer Structure
49:15 - Time-Parallel vs Sequence Mode
53:55 - Experimental Results & Limitations
58:00 - Visualizations
1:01:40 - Conclusion

Paper: https://arxiv.org/abs/2305.13048
Code: https://github.com/BlinkDL/RWKV-LM

Abstract:
Transformers have revolutionized almost all natural language processing (NLP) tasks but suffer from memory and computational complexity that scales quadratically with sequence length. In contrast, recurrent neural networks (RNNs) exhibit linear scaling in memory and computational requirements but struggle to match the same performance as Transformers due to limitations in parallelization and scalability. We propose a novel model architecture, Receptance Weighted Key Value (RWKV), that combines the efficient parallelizable training of Transformers with the efficient inference of RNNs. Our approach leverages a linear attention mechanism and allows us to formulate the model as either a Transformer or an RNN, which parallelizes computations during training and maintains constant computational and memory complexity during inference, leading to the first non-transformer architecture to be scaled to tens of billions of parameters. Our experiments reveal that RWKV performs on par with similarly sized Transformers, suggesting that future work can leverage this architecture to create more efficient models. This work presents a significant step towards reconciling the trade-offs between computational efficiency and model performance in sequence processing tasks.
Authors: Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Huanqi Cao, Xin Cheng, Michael Chung, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartlomiej Koptyra, Hayden Lau, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Xiangru Tang, Bolun Wang, Johan S. Wind, Stanislaw Wozniak, Ruichong Zhang, Zhenyuan Zhang, Qihang Zhao, Peng Zhou, Jian Zhu, Rui-Jie Zhu

Links:
Homepage: https://ykilcher.com
Merch: https://ykilcher.com/merch
YouTube: https://www.youtube.com/c/yannickilcher
Twitter: https://twitter.com/ykilcher
Discord: https://ykilcher.com/discord
LinkedIn: https://www.linkedin.com/in/ykilcher

If you want to support me, the best thing to do is to share out the content :)

If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: https://www.subscribestar.com/yannickilcher
Patreon: https://www.patreon.com/yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n

Table of contents (11 segments)

Introduction

Hello, today we're going to look at RWKV, which in its own words is "Reinventing RNNs for the Transformer Era". This is a very interesting project and a very interesting model architecture, because it has some properties of Transformers: notably, it's a very scalable architecture in terms of training, so you can stack it really deep and still train it, and you can parallelize training. At the same time, it avoids the quadratic memory bottleneck that Transformers have by essentially being an RNN. It's a recurrent neural network in the sense that during inference you can compute the output step by step and always use a constant amount of memory, because everything is put into a hidden state. We're going to look at how these two things come together and what the trade-offs between them are. The project is also very interesting because it has been largely developed by one person, or by very few people. Compare that to entire corporations pouring their human resources into Transformers, and still the result is that this model, in some cases (not in all cases), can be very comparable in performance with Transformers, even really big Transformers. And as I said, it is scalable, which RNNs have so far been lacking. We'll go over the paper and the architecture and see what's actually happening. I have some of my own thoughts and opinions on this, and I hope you're with me. But first, let me show

Fully Connected In-Person Conference in SF June 7th

you this: Fully Connected is a conference by Weights & Biases. It's a one-day event with pretty cool speakers: not only the co-founders of Weights & Biases themselves, but also the co-founder of LangChain, the co-founder of Kaggle, Richard Socher from You.com, Chip Huyen of Claypot, and many more cool speakers. If you are in or around the San Francisco area, go give that event a visit. I'm going to put a link and a promo code into the description of this video that will make the tickets cheaper, so the tickets will be 49 bucks instead of what you see right here. That's on June the 7th. Don't forget: June the 7th in San Francisco. It's an in-person event with a maximum capacity, so grab a ticket now. That's it. Thanks to Weights & Biases for giving me this opportunity, and for giving this to you as well; it's very cool. So

Transformers vs RNNs

RWKV stands for (let me not screw this up) Receptance Weighted Key Value: R is for receptance, W is for weight, K is for key and V is for value. These describe the different elements of the model architecture, and we'll go through them. So what are we dealing with? It's a model you can use for lots of things, but in the instance outlined here it is for language modeling. By language modeling we simply mean that we have a piece of text and we want the model to predict the next tokens, or the next word, in the text: from here predict this thing, then this, and so on. Transformers are usually used in this manner. I would stick the entire prefix into a big Transformer, and the Transformer would spit out the next step. It does that with what we call causal attention, which essentially means that every token in the text can attend to all the tokens that come before it; those are inputs to that token. Using attention, that results in a quadratic requirement of compute and memory: if every token attends to everything behind it, that's T(T−1)/2 interactions that need to be considered, so on the order of T². Transformers therefore naturally have a limit, because every token you add adds a memory requirement of order T, and very quickly that scales out of proportion. Their power is traded off by the fact that they can only consider a limited set of tokens at a time.

Recurrent neural networks trade this off differently. If an RNN has to take an input and predict the next thing, it starts at the beginning and puts something into a hidden state, like a little memory; I'm going to represent that as a box. People have forgotten how RNNs work, I figured. A few years ago I could have talked about RNNs and every one of you would have known what I'm talking about; then Transformers became the new thing, and now I don't even have to mention that something is a Transformer, but conversely I do have to explain RNNs. So the RNN puts whatever it learns into a memory, this hidden box. Then it consumes that memory together with the next input and builds a new memory, then consumes that memory and the next input and builds a new memory, and so on, step by step. You can always drop the old memory, because all you need (pun intended) to go forward from here on out is the current memory. So you go step by step and always save stuff into the memory, the hidden state, or whatever you want to call it.

RNNs have the really great property that they only require constant memory to do inference. However, in an individual inference step, say when we are here, we have the memory and we predict the next token from it, and we can only consider the memory and the previous token. That's all. We cannot explicitly consider any token further back, because everything goes through that hidden state, and that bottleneck has usually been one of the downfalls of RNNs. There is the problem of vanishing gradients, plus a couple of other problems: you just can't compress all the information into that one hidden state. RNNs have also been notoriously hard to train, because inference always requires this step-by-step computation, which means you have to do backpropagation through time. That is part of the vanishing gradient problem, but it also means you can't parallelize training. In a Transformer I can input a sequence of 50 tokens, and that gives me 50 training examples which I can train all at the same time, thanks to the causal attention mask. In an RNN, if I have a sequence of 50 tokens, I can still only compute one token's loss at a time, because I can't compute everything at the same

RWKV: Best of both worlds

time. RWKV is going to strike a trade-off in the middle of those two things. People ask, or I asked: is this more of a Transformer or more of an RNN? I've come to the conclusion that it's a ConvNet, and I'm going to explain my reasoning. They also refer to that, by the way, so it's not like I found out something great right here, but in the most basic sense you can think of this thing as a convolutional network across a one-dimensional sequence of tokens. That's going to be my statement, and we'll go through it. They give a bit of an introduction on what this means, which is essentially what I just said: Transformers scale quadratically with sequence length, while RNNs exhibit linear scaling, which is very advantageous. The RWKV model combines the efficient parallelizable training of Transformers with the efficient inference of RNNs, so it has almost two modes that you can switch between. They say they have a linear attention mechanism and, as I said, they formulate the model either as a Transformer or as an RNN. We're going to focus on that linear attention mechanism in just a bit. It's not their fault, because people have been doing it before them, but I think it really stretches the word attention and what it means, to a point where I no longer agree with calling it attention. Again, they're not the first people doing that, so I'm not going to hold them to account here. They say this is the first non-Transformer architecture to be scaled to tens of billions of parameters. One of the properties of this thing is really that you can scale it and train it in parallel, which also means you can pump a lot of data into it, and that's very advantageous. There is a lot to be said here that maybe it's not so much about the architecture we're dealing with, but rather that good models are simply models with architectures that are scalable. So
it's not, maybe, right? That's a hypothesis: how much of the performance of something like GPT-4 is due to the fact that it is a Transformer, and how much is due simply to the fact that it is a scalable architecture you can pump a lot of data into? We don't yet know too well how exactly those things trade off, but there's a good argument to be made that if you can find some other architecture that just scales really well, you might as well reach the same performance. This is an asymptotic complexity table for the different ways of doing attention. The Transformer is the classic mechanism; as we said, it needs quadratic time in T, where T is the length of the sequence we're processing. In terms of space it needs a little bit more, but the leading term is also T². D is, I believe, the dimension of the embeddings or hidden spaces, which is usually a constant factor across all of these. There are various trade-offs, like Reformer, Performer and so on; a lot of these approximate the original Transformer attention mechanism. It's also notable that RWKV is not an approximate attention or anything like this. It doesn't try to approximate the original attention; it replaces it with a different mechanism of considering the past. So what does that mechanism look like? Let's go into it.
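To make the "two modes" claim concrete, here is a minimal sketch of the core weighted sum (the WKV term) that the video arrives at in the time-mixing block. It assumes the form given in the paper, for a single channel: past token i gets weight exp(−(t−1−i)·w + k_i) with a non-negative decay rate w, and the current token gets a separate bonus u instead. The function names are hypothetical, the keys and values are taken as already computed, and the overflow-avoidance tricks a practical implementation of exp would need are omitted. The parallel form revisits the whole past for every output (quadratic work, like attention, but each output is independent of the others), while the recurrent form only carries a running numerator and denominator forward (constant state per step, like an RNN).

```python
import math


def wkv_parallel(w, u, k, v):
    """Time-parallel form of the WKV weighted sum for ONE channel.

    Each output t looks back over all earlier steps explicitly, so the
    total work is O(T^2), but every t can be computed independently.
    w: non-negative decay rate, u: bonus for the current token,
    k, v: per-step keys and values (lists of floats, 0-indexed).
    """
    out = []
    for t in range(len(k)):
        num, den = 0.0, 0.0
        for i in range(t):  # strictly earlier tokens
            # Exponential decay with distance, shifted up or down by the key.
            weight = math.exp(-(t - 1 - i) * w + k[i])
            num += weight * v[i]
            den += weight
        bonus = math.exp(u + k[t])  # current token uses u instead of a decay
        out.append((num + bonus * v[t]) / (den + bonus))
    return out


def wkv_recurrent(w, u, k, v):
    """Recurrent (RNN-like) form: O(T) total work, constant-size state.

    a and b are the running numerator and denominator over the past;
    they are passed forward undivided and only divided for the output.
    """
    a, b = 0.0, 0.0
    decay = math.exp(-w)
    out = []
    for kt, vt in zip(k, v):
        bonus = math.exp(u + kt)
        out.append((a + bonus * vt) / (b + bonus))
        # Decay the old state by one step, then fold in the current token.
        a = decay * a + math.exp(kt) * vt
        b = decay * b + math.exp(kt)
    return out


if __name__ == "__main__":
    k = [0.1, -0.5, 0.3, 0.8]
    v = [1.0, 2.0, -1.0, 0.5]
    print(wkv_parallel(0.7, 0.2, k, v))
    print(wkv_recurrent(0.7, 0.2, k, v))  # same numbers, up to rounding
```

Both functions compute the same outputs; the difference is only in how the work is organized, which is what lets the same weights be trained in parallel and run step by step at inference.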

LSTMs
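As a reference point for the LSTM walk-through in this segment, here is a minimal sketch of one step of a standard LSTM cell (forget, input and output gates plus a tanh candidate). To keep it runnable without NumPy it assumes a scalar input and a hidden size of 1; the parameter layout (a dict of hypothetical `(w_x, w_h, bias)` triples) and the function names are illustrative only, not taken from the paper or any library.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, h_prev, c_prev, params):
    """One step of a standard LSTM cell, scalar input, hidden size 1.

    params maps each gate name to a hypothetical (w_x, w_h, bias) triple.
    """
    def pre(name):
        w_x, w_h, bias = params[name]
        return w_x * x + w_h * h_prev + bias

    f = sigmoid(pre("forget"))  # how much of the old cell state to keep
    i = sigmoid(pre("input"))   # how much of the new candidate to write
    o = sigmoid(pre("output"))  # how much of the cell state to expose
    g = math.tanh(pre("cand"))  # candidate content; h_prev enters via these nonlinearities
    c = f * c_prev + i * g      # cell state: a gated linear combination
    h = o * math.tanh(c)        # hidden state: a second nonlinearity on top of c
    return h, c


def run_lstm(xs, params):
    h, c = 0.0, 0.0
    hs = []
    for x in xs:  # strictly sequential: step t needs (h, c) from step t-1
        h, c = lstm_step(x, h, c, params)
        hs.append(h)
    return hs
```

The loop in `run_lstm` is the point: because `h` passes through the gate nonlinearities and then through another `tanh` at every step, there is no way to combine several steps into one closed-form jump.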

Let's go into it. Maybe, yes, this is smart: let's first look at this here. If you've been in deep learning for a while and are old like me, you'll remember that this is an LSTM, a long short-term memory cell, from back before attention was a thing. This was one of the main ways people built recurrent neural networks that would actually somewhat avoid the vanishing gradient problem and could learn to remember things for a long time. The idea behind LSTMs is that you have two hidden states (am I correct? yes): this C, and this H, which is the real hidden state. And you have a lot of gating mechanisms, which are often represented like this; this is an element-wise product. You get in a hidden state and an input, you put the input through some sort of computation, neural networks, nonlinearities, yada yada, and you get some result, which is a vector. You can compute an output from that at this particular time step, but the question is what the next hidden state should look like. In a really basic RNN we would just replace the hidden state, or maybe add the two together and propagate that forward. What an LSTM does, very interestingly, is introduce gates. We'll have something like a forget gate, which is a vector of continuous values between 0 and 1, but you can imagine it as a binary mask. It is multiplied element-wise with the hidden state, so wherever it's zero, that part of the hidden state gets forgotten, and the new state gets updated at that particular point. So there is going to be a forget gate, and there's also going to be a gate that decides which parts of the new input to even write into the state, again something like a mask where maybe only these entries are kept. So the network itself can control which parts of the information it wants to remember and which parts it wants to forget, and the new hidden state is a sort of sum of these masked inputs, and that goes on.

In order to do that, a lot of computation is needed, as you see, and in particular I want to draw your attention to one fact. If you take this H, the next hidden state is always going to be some nonlinear function (a tanh, in the standard LSTM) of this C_t. C_t in itself is a linear combination, but a linear combination of terms that are in turn nonlinearities of linear combinations, and ultimately one of their inputs is the last hidden state. So from the last hidden state to the next hidden state we pass through at least two nonlinearities. There you see the sequential, stepwise nature of things: you always have to compute one step and then the next step based on it, and the outputs are nonlinearly related. There is no way to jump a few steps ahead or take five steps together, because they're nonlinear; it's not like linear functions, which compose really easily. Nonlinear functions that are stacked, where every next one needs the previous one as an input, are really not parallelizable or aggregatable in any way. So that's the problem with RNNs if they're formulated like this. They also

Evolution of RWKV's Linear Attention

go a little bit into the attention mechanism, which I guess by now I don't have to explain too much anymore. The data enters and three matrices are produced from it: queries, keys and values. We do an outer product between the queries and keys, which defines this quadratic interaction where every token can attend to every other token. Sometimes you then mask this with the causal mask, an upper or lower triangular matrix, and then apply a softmax across that. The softmax essentially converts some values, say positive, negative, positive, into a distribution. We can interpret it as a probability distribution or whatever you want, but it defines where to put the attention: which things I should look at, like this one, and this one a little bit. I guess that's why it's called attention: it defines the weights by which I aggregate information, so it dynamically allocates attention to some of these values rather than others. That's what attention means to me, in the sense of how it's described, and that's why I said the term is being stretched a bit far.

You can write attention in a different way and decompose it recurrently, almost like an RNN. This is done in parts to get around the memory bottleneck; however, you trade that off against computation time. If you compute these things in sequence, you have to compute them one after another, whereas you could have computed them in parallel with one big outer product, a matrix multiplication. So you trade off time against memory here. You can decompose the same attention computation as above like this: I take the dot product of a single query and key, I take the exponential of that, which is part of the softmax computation, and ultimately I divide by the sum of these values; that's the softmax operator. Then I multiply by the value at that particular location. So this part defines a weight and V is the value: it's a weighted sum of the values.

Now we come to Attention Free Transformers, a piece of work that this paper takes a lot of inspiration from. Attention Free Transformers go about this in the same way as attention, but they ask: can we reformulate the formula we saw up here and turn it into something that doesn't need that whole quadratic memory shebang? They say that if we don't have to do these outer products, then we no longer have the token-to-token interactions where every token can attend to every other token, which means we no longer have the quadratic cost, and we could save a lot of memory. So they replace the interaction between query and key. They say: let's not even compute a query, we'll just compute a key for each token, so we're now only producing matrices K and V, no queries. Instead we'll have these W's. W is a learned matrix, not computed from the data but learned, and it just learns how tokens interact with each other. What does that mean? It means I have a matrix of size T by T, with positions 1, 2, 3, 4, 5, and in this matrix there's going to be a number, like 7. That means the attention weight of token one with respect to token two (how much token one attends to token two, assuming it's not causally masked) is 7. It's the same for all sequences: you just say the first token always attends 7 to the second token. It's one set of learned parameters, so this is essentially just a multiplication by a constant.

That defines a fixed attention, which is a bit too little: fixed attention is not flexible enough. Before, we had completely dynamic attention, and now we've gone to completely fixed attention that is learned across the whole data set. The authors there said, wisely, that depending on the current data point we might need to look back further or change that attention pattern a little bit. So they say: how about we simply add the keys here? The keys are computed from the data, so now the data can essentially define an offset. Maybe for one data point it's 8, because it says token one should really look at token two, and for another data point the K is −1, so it's 6, which depresses it a little bit. So there is one pattern across the data set, which is fixed, and on top of that there's a modulation given by each individual data point. However, the interaction is not multiplicative, as it is in the Transformer attention up here; it's additive, which is in a sense a lot less powerful. The multiplicative interaction between keys and queries really defines when two things are close or far apart, whereas here we just modify that fixed pattern a little bit. And the key is computed from a single token on its own, so the modulation can't take into account the other token in the pair. Roughly, a token can only say something like "I'm the word cat, I'm important", or, through the learned position weights, "three words back is important". The original attention, in contrast, can decide: I'm the word "cat", and I should look at words that relate to fish or fur or sleeping; those kinds of words really interest me. That's how it would craft its query, and that's what it would retrieve from the keys of the other words. I hope you can see how this is less powerful than the original attention, but it is more scalable. Again, what this part defines is essentially a weight, so you have a weighted sum of the values.

Lastly we have this paper. It says: what we still need to do here is learn that interaction matrix, which crucially also means it's still limited by T, a fixed size; we can't go bigger than that. What they say now is: how about we don't learn this matrix, and all we learn is a vector W? The vector W is the same for all positions and has the same dimensionality as the hidden state, and it defines, for each dimension, how much the past matters. For one dimension the past might matter a lot, so that dimension's w is small and its memory decays slowly; for another dimension w is large, and the past doesn't matter much. What happens if we do that? Ultimately we get something like the formula up here: the exponential function of −(t − i)·w plus k_i. They multiply w by the term t − i, which is how far back I'm looking. Let's say we're token four: how much do we attend to token number two? We ask per dimension. Say in the first dimension w is small, so dimension one generally cares a lot about its past. Then we look two tokens into the past, because the current time step is four and i here is two. Because of the minus sign, the exponent gets more and more negative the further back you go: it's essentially a linear drop-off in steps of w, and w itself can be big or small. These two things together define how important a past token is in that particular dimension: it's a multiplication of how far back the token is and how much this dimension in general considers its history. That gives a base amount of attention for a token two steps back, which can then be modulated again by the key k_i, which depends on exactly what that past token is.

I hope that's understandable. The attention matrix is built on the fly in RWKV, and it's defined per dimension by three things. First, the vector W defines a general importance for each dimension: how relevant is the past. The second component is simply a linear decay into the past, so the further back, the less important. This is then put through the exponential function, so a linear decay inside the exponential becomes an exponential decay. The third thing is the modulation, and this is where the actual token plays a role: it shifts the point we just determined up or down a little bit, also inside the exponential function. So that is how RWKV considers the past: in general, it forgets the past in an exponential fashion, modulated by the global importance of each dimension and by a value that depends on the token being looked at. So it has the same trade-offs as these Attention Free Transformers. Now, what do we do with these things? Somewhere we have... sorry for scrolling around so

RWKV's Layer Structure

heavily so this is how the model is composed this is a recurrent application of the model so we see the same model applied to three tokens in succession so the my name is Bob okay you input my and you're trying to uh make it output name then you input name so you input my name you're trying to make it output is then you input my name is you're trying to make it output Bob so it's three applications of the same model so the model isn't composed of these three columns but the model is composed of one column and then we just apply that over and over again you see it has uh a beginning essentially which is a token embedding um it has an end which is a language modeling head which is a fully connected or a stack of fully connected layers that just Maps into the vocabulary uh but in the middle it's composed of a recurrent or sorry of a series of layers so there's in each layer there's always sorry a time mix module and a channel mix module so the time mix module being here and the channel those are repeated time mix Channel mix and so on and then on top there's this uh language modeling head so what are the two blocks this is a schematic of how the two blocks look like I know this is a bit small um but the time mixing is down here and the channel mix in is down here we're going to look at this in a bit of detail in the math but observe right here what you always have is you have the input signal you're going to compute R from it R is a a value that is going to be used as a for like a gate a forget gate so R always defines how much of whatever is incoming here or whatever is incoming from here how much of that do I want to retain and send up to the next layer so as you can also see R is computed over here as well so for every one of these blocks we're going to do a computation over here or over here and then we're going to have a decision made by this branch of the computation of how much of that we even want to accept and then send up to the next layer so that's the 
purpose of the left branches of these computation there is residual signal across all of these things right here so that kind of mimics the state of an lstm maybe but in a in an upwards way so in a layer to layer way so we always have a residual modu we also have a forget gate in front of adding it to the residual signal what does the these two modules look like so actually let's first go to the channel mixing block is very reminiscent of um kind of feet forward maybe layers so what we have is ignore this part for a moment right here as I said the r is computed from the input X and just a linear layer so X time a matrix that that's R so that's a linear layer that defines R then we have K also x * a matrix it's a very simple feed forward uh layers right here then W which is this part right here that's no sorry v you can see that's a nonlinearity and the nonlinearity here is the squared reu nonlinearity on top of K and again a linear layer and at the end we are doing that element wise multiplication by this signal right here so the r uh pushed through the sigmoid here is that forget gate that we talked about but you can see it's a essentially if you follow the signal through the actual path of signal it starts here x well that's a funky X it's multiplied by a matrix so a linear layer that becomes K that's put through a nonlinearity then multiplied by another linear layer that becomes the value V and then send through the forget gate so it's essentially a feat for neural network with one nonlinearity and at the end a forget gate as like another nonlinearity and that's it that's the channel mixing module I guess it's Channel mix mixing because this uh Matrix the linear layers they do in fact mix the channels which means that every Dimension sort of can get inputs from every other dimension which is just a B of A feet forward Network now I've crossed out all of this stuff in the back right here so we should talk about this as well what they do is something called time 
shift or token shift or um I believe that's one of them token or time shift and that is they always not only take the input to the current layer at this particular time step they also always take the input from the last time step and they linearly interpolate between the two so you can see here mu and 1 minus mu the are either hyperparameters or they're learned parameters but they are um per operation so mu R here is a parameter for the general computation of this R it's not dependent on the particular data point only this and this are dependent on the data point with XT being the current input to the layer and XT minus one being the last step input to the layer this is pretty interesting because it means that we not only always only take the current input and the hidden state from before like in a general RNN but we always take the current input the last input and for this layer that's it but in the time mixing module we'll then take the hidden State uh onto these that's why in this diagram you see these um these lines right here these diagonal lines are the token shift lines so you see this channel mix module is going to take the current input whatever it gets from the lower layers or the original signal and the input to the same layer at the last time step so current input and input to the last time step and that also goes in these two things are linearly interpolated and that then is the input quote unquote to the current layer so it's the interpolation of the last step and this step's input noce it's not like we don't mix like the internal States right here we mix the inputs before they into the layer now let's look at the time mix you can see there is also this token shift happening so this token shift that's just something you can do I guess if you have a one directional uh sequence that you need to predict um you can do this token shifting uh but what's also interesting is we have a second line which are these states so how does that work and that's going 
to be the actual recurrent part of this model. You can again see that we're always working with the token shift; we never just work with the raw input. So think of these things here always as x̃, where x̃ is a mix between the current input and the previous time step's input. We compute R, which is just x̃ times a weight matrix, a linear layer, and that again becomes the forget gate down here via an element-wise multiplication. We also have an output feed-forward layer, a kind of projection; I'm going to guess they put that in because it was advantageous, and it can also change the dimensionality and whatnot.

Then we compute two things, K and V. You'll notice that before, whatever was called V was produced from K, but now both K and V are produced from x̃. I don't know why they use the same names; probably to stay as much in line with the Transformer terminology as possible, but there's really no relation between the V here and the V before. So in this block, K and V are computed, modulo the time shift, as they are in the original Transformer architecture: the input times a linear layer.

Then something interesting happens: we go into this weighted sum. You'll see something familiar here: there's a weight, e^(-(t - 1 - i)w + k_i), and v_i, the values we computed here. So we're looking for a weighted sum of the values, and the v_i are in fact the past values, so it's a weighted sum across the entire past. Note that i here only goes up to t - 1, so you can see that we sum the v_i
here, and then at the end we also add v_t. The only reason there's a difference is that this u right here is a different parameter than those w's, but in essence it's again a weighted sum over all the values, and the values come from the entire sequence so far. Up to now we've only considered the current time step: the current input and, yes, the last step's input. But in this step right here we consider the entire past, and we want a weighted sum across the values of the entire past, like in the Attention Free Transformer and whatnot.

But now we're no longer limited by having a fixed-size attention matrix. Even if it's learned, and even if it's not really an attention matrix, in the Attention Free Transformer it was still a fixed size. Here all we say is how important each dimension is and how it decays with time into the past. That is not limited back in time; it just gets really small, but it isn't cut off, and therefore we can do this in perpetuity, in particular all the way down to i = 1, the very first token. So for every token we need to do inference for, this value right here is a weighted sum across the values of the entire past.

And you can easily see that you can do this in a recurrent fashion. It's essentially a softmax: there are exponentials up here, multiplied by the values, and down here we just sum over the exponentials. However, you can keep track of the numerator and the denominator separately, and those become your hidden state that you pass forward. You grab this sum up here and this one down here, and before dividing them you pass them on. In the next step you decay both sums by e^(-w), add the new token's terms on top, divide them to get the current step's output, and pass the two sums on separately. So that's what they mean by states when
they say it. They don't mean the United States, I'm sorry; they mean these values here that you need, so you can compute this in a recurrent fashion or in a parallel fashion. Just to finish here quickly: this weighted sum over all the values of the past is then fed into this forget gate, as you see here, and the output is computed from it.

Now, multiple things to note right here. The aggregation over the past contains essentially no nonlinearity, because V, the thing being aggregated, and in general these hidden states, are just produced as a linear function of a linear interpolation of the inputs. There is nowhere where the previous state goes through a nonlinearity in order to compute the next hidden state. You can essentially track this as one big sum: you can track it as a list and then do the sum, or you can just track the sum. That also means parallel training becomes feasible again: you can train these in parallel, because it's all one big sum that you can compute for an entire batch and an entire sequence at the same time. And yes, you can also use it in an RNN fashion, step by step, precisely because there are no nonlinearities in between; it's literally just a sum.

That's also why, and you'll see what I mean, this is essentially a convnet. It has two parts. The first part is this token shift. Look at the diagram: if you are this element right here, what do you have access to? This and this. But by extension, via the token shift, you also have access to this and this, and from the lower layer to this and this. So you have a receptive field that grows with depth; if we had another layer, the receptive field would grow again. So the token shift itself is already very directly a convnet, and you only have nonlinearities as you cross
these layer boundaries right here. Otherwise it's just a linear interpolation, which is exactly a convolution with the kernel being mu and 1 - mu. That's your convolutional kernel; it's of size two and you slide it over, so that defines a convolution. The same goes for these things right here, which is very reminiscent of S4, or state space models and so on, if you know about those. Essentially what they do is define a way to linearly aggregate the past, which is exactly what this big sum is: a weighted sum across the past with no nonlinearities in between, so you can just track it like this. And again, S4 is essentially a big convolution. So if you want to think about this model in another way than a Transformer or an RNN (not that there aren't already enough ways), it's essentially a big convnet, in particular in this part right here: a convnet with a sort of infinitely long convolution into the past, or until the beginning of the sequence, I guess. The way it's done is that there is a standard convolution kernel, and then that's modulated by these k values right here.

All right, so that is how that works. I hope I've made this a bit clear. As for why this one is called channel mixing and that one isn't: this one is just as much channel mixing as the other one. The only difference is that down here there is a nonlinearity within the layer, and here we have this aggregation over time, so I guess calling this time mixing is fair. But it's just as much channel mixing, because these feed-forward layers mix the channels. That's naming; it doesn't really matter.

Time-Parallel vs Sequence Mode

So they specify here that this can be used in time-parallel mode. The complexity of processing a batch of sequences in a single layer is O(BTd²), for B sequences, T tokens, and d channels, so you can process a batch at once. Meanwhile, they say, updating the attention scores requires a serial scan and has a complexity of O(BTd). They've implemented this in a custom CUDA kernel. You can actually go look at the code; I've done that, and it's one of the more understandable pieces of CUDA code. You just write this function, and CUDA takes care of parallelizing it across cores and workers and processes and so on. The element-wise computation is time-dependent, but it can be readily parallelized along the other two dimensions.

On the other hand, you can also use this in a time-sequential mode: it can be conveniently formulated recursively for decoding during inference. (Oh, my connection here is spazzing out; one second. And we're back.) Each output token depends only on the last state, which obviously brings all the advantages and disadvantages of RNNs with it. Again, we can only consider information that comes through the bottleneck of the hidden state. But because this big aggregation is essentially a weighted sum across the past, and not nonlinearity after nonlinearity, I feel it can much more easily look back into the past; it just can only do so in a linear fashion. So among Transformer, LSTM, and this one, this is probably the weakest form of being able to look into the past with nuance. You can, but only in a general fashion. A Transformer, by contrast, can look into the long past, as long as it's within the context, with exactly the same amount of detail and the same computation it applies to the recent past or the current token. An LSTM can't do that; it
can't look into the past at all. However, it can do a lot of considered computation in each step before it saves things to the hidden state. So it's weaker because it can't go back, but it can still do a lot of complex computation. This model right here also goes through the hidden state, but it can look into the past much more easily than an LSTM, because it's just this weighted sum instead of nonlinearity after nonlinearity. It does, however, have the weakest form of computation in the exact moment. I hope that makes sense; it's not a scientific statement, just me rambling, and maybe I'm totally wrong about this. Or maybe you can easily make up for it by stacking a lot of layers: this model can now be stacked and scaled really heavily, and that's probably enough to make up for the lack of computation in the individual cell. It's just, hey, let's stack the stuff.

Another property that I have mentioned, but that maybe hasn't entirely come through, is that they always compute from the inputs; they don't necessarily carry the hidden states over. All the functions are linear functions, or exponentials of linear functions, of the inputs, so there's no nonlinearity between the time aggregation and the inputs to the layer themselves. Sorry, enough rambling. Here you can see the scaling behavior, very beautiful: cumulative time during text generation. As the number of tokens goes up, this model obviously scales linearly, whereas everything else goes vroom.
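The token shift, the channel mixing block, and the WKV weighted sum discussed above can be sketched in plain Python. This is a minimal, non-authoritative sketch that assumes the formulas as written in the RWKV paper (per-channel decay w, current-token bonus u, keys k and values v already computed per token); all function names are hypothetical and are not the RWKV-LM API. It checks that the direct "parallel" form of the weighted sum and the recurrent form, which only carries a running numerator and denominator, give the same outputs. It deliberately leaves out layer norms, the receptance gate and output projection of the time-mixing block, batching, and the numerically stable max-tracking that real kernels use to keep the exponentials from overflowing.

```python
import math
import random
from typing import List

Vec = List[float]
Mat = List[Vec]


def matvec(W: Mat, x: Vec) -> Vec:
    """Plain matrix-vector product (a linear layer without bias)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def token_shift(x_t: Vec, x_prev: Vec, mu: Vec) -> Vec:
    """mu * x_t + (1 - mu) * x_{t-1}, per channel: a size-2 causal convolution."""
    return [m * a + (1.0 - m) * b for m, a, b in zip(mu, x_t, x_prev)]


def channel_mix(x_t: Vec, x_prev: Vec, mu_k: Vec, mu_r: Vec,
                W_k: Mat, W_v: Mat, W_r: Mat) -> Vec:
    """Feed-forward with squared ReLU, gated element-wise by sigmoid(r)."""
    k = [max(z, 0.0) ** 2 for z in matvec(W_k, token_shift(x_t, x_prev, mu_k))]
    r = [sigmoid(z) for z in matvec(W_r, token_shift(x_t, x_prev, mu_r))]
    return [g * z for g, z in zip(r, matvec(W_v, k))]


def wkv_parallel(k: Mat, v: Mat, w: Vec, u: Vec) -> Mat:
    """Direct formula: each output is a decayed, softmax-like average of v_0..v_t."""
    out = []
    for t in range(len(k)):
        row = []
        for c in range(len(w)):
            # Past token i < t gets weight exp(-(t - 1 - i) * w + k_i).
            weights = [math.exp(-(t - 1 - i) * w[c] + k[i][c]) for i in range(t)]
            num = sum(a * v[i][c] for i, a in enumerate(weights))
            # The current token gets the bonus u instead of a decay.
            cur = math.exp(u[c] + k[t][c])
            row.append((num + cur * v[t][c]) / (sum(weights) + cur))
        out.append(row)
    return out


def wkv_recurrent(k: Mat, v: Mat, w: Vec, u: Vec) -> Mat:
    """Same outputs, but the only state carried forward is two sums per channel."""
    num = [0.0] * len(w)  # running numerator
    den = [0.0] * len(w)  # running denominator
    out = []
    for k_t, v_t in zip(k, v):
        row = []
        for c in range(len(w)):
            cur = math.exp(u[c] + k_t[c])
            row.append((num[c] + cur * v_t[c]) / (den[c] + cur))
            # Decay the old sums by exp(-w), then add this token with weight exp(k_t).
            num[c] = math.exp(-w[c]) * num[c] + math.exp(k_t[c]) * v_t[c]
            den[c] = math.exp(-w[c]) * den[c] + math.exp(k_t[c])
        out.append(row)
    return out


if __name__ == "__main__":
    random.seed(0)
    T, d = 6, 4
    k = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(T)]
    v = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(T)]
    w = [0.05, 0.5, 1.0, 4.0]  # small w: long memory; large w: mostly local
    u = [0.3] * d
    par, rec = wkv_parallel(k, v, w, u), wkv_recurrent(k, v, w, u)
    gap = max(abs(p - q) for pr, qr in zip(par, rec) for p, q in zip(pr, qr))
    print(f"max |parallel - recurrent| = {gap:.2e}")
```

The parallel form recomputes the sum over the whole past for every position, while the recurrent form keeps only two running sums per channel (plus the previous input for the token shift), which is why per-token inference cost stays constant. A small w lets a channel remember far back; a large w makes it effectively local.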

Experimental Results & Limitations

The experimental evaluations are also really interesting. At least on the datasets they considered right here, it can hold its own: sometimes it's a bit better, sometimes a bit worse than other similarly sized Transformers, but it performs along the same lines. Now, I have heard people say that the model is qualitatively not as good as Transformers of the same size, and I've heard other people say it's better, or better for some things. I don't know; that still has to be shown. Also, these datasets don't really show me what I would want to find out. If I compare this to a Transformer, I'd want to find out where the actual difference is, and I'm going to guess it's this: consider something that is not in the recent past but a bit further back. If that's a very complex thing, the treatment of which depends on the current token (I can't really give you an example right now), then in a situation like that a Transformer would be way superior to any LSTM or to this model. So maybe when programming, in certain fashions, or if the context for something is only given much later. But for a lot of applications it's probably not that important.

They also show how increasing the context length, since they can do that now, nicely decreases the language modeling loss on the Pile dataset. And they give some suggestions for future work, for example improving computational efficiency by applying a parallel scan in this step to reduce the computational cost to O(B log(T) d), which would be another improvement that's theoretically possible, but I'm going to guess that's not done right now. They also discuss the limitations; they're very open about the limitations and all the comparisons right here. This
is it: the linear attention leads to significant efficiency gains, but it may also limit the model's performance on tasks that require recalling minutiae over very long contexts. Okay, that's essentially what I was trying to describe, but in much better words than I managed. The recurrent architecture inherently limits its ability to look back at previous tokens; yeah, that's like an RNN. They also discovered that prompt engineering matters more than for standard Transformer models. There are multiple hypotheses. The one they give right here is that the linear mechanism limits the information from the prompt that will be carried over to the model's continuation; as a result, carefully designed prompts may be even more crucial for the model to perform well on tasks. That could be the reason. There could also be other reasons: maybe this model overfits a bit more, or is less generalizable, and that's why changing the prompt matters more. Or maybe not. It's one of those things where there are probably a thousand possible explanations for why getting the prompt right matters a lot more with this model than with a Transformer, and I wouldn't put my bet on any one of them until we have really good experimental confirmation. There's usually a lot of, how should I say, Occam's razor that should be applied.

Visualizations

All right, the last thing I wanted to show is this experiment right here, which I found really cool. The first one is the time decay, sorted along the channel axis. Here is the channel, the dimension; we're looking at this vector w right here, i.e. how important the past is. In layer one, as far as I can tell, the past is not that important. The plot is sorted by that value, so for some of the channels the past is important, but for a lot of them it's really not. As you go up the layers of this network, they consider the past more and more. What's interesting is the drop-off right here: you have some channels really specializing in near-term information, but a lot of channels looking a long way back, where w defines almost no drop-off, whereas in the lower layers the model considers much more local information. I found it pretty cool to see that visualized, and to see that progression up the layers of a trained model.

The second visualization is an example. Here you have an input of tokens, "The Eiffel Tower is located in the city of", and they look at how likely the word "Paris" is. In each layer they swap out or disturb the weights and see how much influence that has on the probability of the word "Paris" appearing. We've previously looked at this technique in the paper on ROME, I believe; it's a way to figure out which information paths are important. In this case you can clearly see that after "Eiffel", layers one to about 20 light up, and what's cool is that after that only layers
whatever, 21, 22, and 23 light up until the end right here. Only these light up, which means you can disturb the lower layers, because the information that "Paris" is probably a very likely word has already passed from the lower layers all the way up to the higher layers, and is then stored and carried along in the hidden states of these higher layers across these tokens, such that the information in the lower layers is only mildly relevant to the output right here. Of course it is relevant, the values here aren't zero, but it's only mildly relevant. I thought that was a pretty cool visualization of what's going on in this model. Let me quickly scroll through and see if I forgot anything. I did not.

Conclusion

There are some examples at the end, which is pretty cool. The models are available; you can go check them out and test them. I also found the code base fairly understandable, so if you want to look at that, it's also available. And yes, my thing is spasming out again, so that's where I'll end it. Go check out Fully Connected; again, the code in the description gets you a discount if you're in San Francisco on June 7th. I unfortunately will not be there, but many, many cool people will be. That was it. Thank you, bye-bye.
