# Scaling Transformer to 1M tokens and beyond with RMT (Paper Explained)

## Metadata

- **Channel:** Yannic Kilcher
- **YouTube:** https://www.youtube.com/watch?v=4Cclp6yPDuw
- **Date:** 27.04.2023
- **Duration:** 24:34
- **Views:** 59,523
- **Source:** https://ekstraktznaniy.ru/video/12453

## Description

#ai #transformer #gpt4 

This paper promises to scale transformers to 1 million tokens and beyond. We take a look at the technique behind it, the Recurrent Memory Transformer, and at its strengths and weaknesses.

OUTLINE:
0:00 - Intro
2:15 - Transformers on long sequences
4:30 - Tasks considered
8:00 - Recurrent Memory Transformer
19:40 - Experiments on scaling and attention maps
24:00 - Conclusion

Paper: https://arxiv.org/abs/2304.11062

Abstract:
This technical report presents the application of a recurrent memory to extend the context length of BERT, one of the most effective Transformer-based models in natural language processing. By leveraging the Recurrent Memory Transformer architecture, we have successfully increased the model's effective context length to an unprecedented two million tokens, while maintaining high memory retrieval accuracy. Our method allows for the storage and processing of both local and global information and enables information flow between segment

## Transcript

### Intro [0:00]

Hello! Today I want to take a quick look at the paper "Scaling Transformer to 1M tokens and beyond with RMT" by Aydar Bulatov, Yuri Kuratov, and Mikhail S. Burtsev. The paper makes a big promise: scaling Transformer inference to a humongous 1 million, even 2 million, tokens. The plot on the first page shows some tasks they want the Transformer to do. There is a memorize task, a detect-and-memorize task, and even a reasoning task. The x-axis is input size in tokens and goes up to 2 million, and you can see the Transformer keeps performing these tasks across all of those input sizes. There are a few caveats, though. The main issue is the marketing; the paper itself is actually pretty good. It is a follow-up to an earlier paper on this RMT model, which we'll look at briefly, and it is essentially a small extension of that paper. The earlier paper did the same thing for autoregressive, causally masked, decoder-style Transformers. This one, which the authors themselves call a technical report rather than a paper, simply applies the same thing to encoder Transformers such as BERT. This variant turns out to be simpler, so conceptually you'd read this report first and the decoder-only paper second, because there you also have to handle the causal masking. We'll dive into some of this, but first let me lower expectations a tiny bit: this is not a way to scale Transformers to ginormous inputs in the sense you might think. We'll get into it. So there is a

### Transformers on long sequences [2:15]

diagram right here that shows how this is going to be achieved. We have text, blah blah blah, a really long sequence that doesn't fit into a Transformer. Transformers suffer from quadratic scaling: their attention mechanism needs memory that grows quadratically with input length. If you want your Transformer to consume twice as many tokens, you need four times as much memory. Four times as many tokens needs 16 times the memory, so you quickly run out. That's why this method splits the text into chunks, which they call segments. It processes each segment with a Transformer and then somehow connects the segments during inference. This is obviously not the first time we've seen something like this, and the authors say so themselves. There were earlier attempts such as Transformer-XL and various others, already around 2019 or 2020, that took an approach like this. Newer approaches to extending the sequence length of Transformers, such as Longformer and BigBird, do consider the whole sequence but restrict the attention. When processing a given token, they do full attention only locally. They add some kind of sparse attention to the parts outside the local window, and maybe some special global attention to, say, the very first token, because that one might be more important. So everyone has been finding ways around this quadratic attention. Again, this paper chunks the text into multiple segments and processes each segment with a Transformer. Think of BERT here, and think of a task where you input an entire sequence and then classify that entire sequence. In
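
Here is a back-of-the-envelope sketch of the scaling argument. It counts only entries of the attention score matrix and ignores heads, layers, and hidden size. The 512-token segment length mirrors the BERT context size mentioned in the transcript. The 10 memory tokens per segment are an arbitrary assumption, not a figure from the paper.

```python
from typing import Tuple


def full_attention_entries(n_tokens: int) -> int:
    """Entries of one n x n attention score matrix (per head, per layer)."""
    return n_tokens * n_tokens


def segmented_attention_entries(n_tokens: int, segment_len: int = 512,
                                n_mem: int = 10) -> Tuple[int, int]:
    """Split the input into segments of `segment_len` tokens, each processed
    together with `n_mem` memory tokens (an assumed value).

    Returns (total entries summed over all segments,
             peak entries held at once when segments run one after another).
    """
    n_segments = -(-n_tokens // segment_len)      # ceiling division
    per_segment = (segment_len + n_mem) ** 2      # attention inside one segment
    return n_segments * per_segment, per_segment


if __name__ == "__main__":
    for n in (512, 1024, 2048, 1_000_000):
        total, peak = segmented_attention_entries(n)
        print(f"{n:>9} tokens: full {full_attention_entries(n):>17,}  "
              f"segmented total {total:>13,}  peak {peak:,}")
```

Doubling the input quadruples the full attention matrix. The segmented total grows only linearly, and the peak per segment stays constant. That constant peak is what makes the chunked approach fit in memory.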

### Tasks considered [4:30]

fact, we can look at the tasks they consider right now. The easiest is the memorize task. It consists of text, and somewhere in the text there is a fact, something like "Daniel went back to the hallway." At the very end of the sequence there is a question. Everything in between is filler: random noise, say text grabbed from Wikipedia, inserted to confuse the model. The question at the end is "Where is Daniel?", and the model is supposed to answer "hallway." It can do that either by classification or by pointing into some vocabulary. There are various ways to implement that, and you can do it with a Transformer like BERT. In essence, this is sequence classification into a preset vocabulary. If the whole thing were shorter than, say, 512 tokens, you could just feed it into BERT, train it, and have it answer. If it's longer than BERT's context size, you need a few tricks. For the purposes of this paper, the memorize task considers text that is longer than one segment. There is also the detect-and-memorize task, which is slightly harder. In the memorize task the fact is guaranteed to be in the first segment. In the detect-and-memorize task the fact could be anywhere, so the model also has to detect which segment the fact is in, remember it, and answer at the end. Note that "remember" applies only to this technique. Something like Longformer would not use a memory at all. It would rely on attention from the question to the place where the fact was written, which is how a normal Transformer would do it. This paper instead writes something into memory and then consults that memory in the last segment. Then there is a reasoning task, where it's even
slightly more complicated. There is one fact, another fact somewhere else in the sequence, and you need both facts to answer the question at the end. They do give an example, but I think it's kind of wrong. Fact 1 is "The hallway is east of the bathroom," Fact 2 is "The bedroom is west of the bathroom," and the question is "What is the bathroom east of?" The answer is "bedroom." Looking at this, though, I only need Fact 2 to answer; I don't need Fact 1 at all. So maybe this is just a wonky example, because I don't see how Fact 1 is relevant. I trust that the actual dataset, this two-way-relation bAbI task, does contain data points that need both facts. Now that we know what we're dealing with, let's look at how this model does it.
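
Here is a minimal sketch of how samples for the memorize and detect-and-memorize tasks could be assembled. It assumes whitespace tokens and a made-up filler vocabulary. `FILLER`, `make_sample`, and the token counts are hypothetical, and the paper's actual data pipeline may differ. The only difference between the two tasks is where the fact is allowed to start.

```python
import random
from typing import List, Tuple

# Hypothetical filler vocabulary; the real tasks use natural-language background text.
FILLER = ["the", "weather", "was", "mild", "and", "the", "market", "opened", "late", "."]
FACT = "Daniel went back to the hallway .".split()
QUESTION = "Where is Daniel ?".split()
ANSWER = "hallway"


def make_sample(n_segments: int, segment_len: int, detect: bool,
                rng: random.Random) -> Tuple[List[str], str, int]:
    """Build one sequence of exactly n_segments * segment_len tokens:
    filler noise, one fact, and the question as the final tokens.

    detect=False -> memorize task: the fact lies entirely inside segment 0.
    detect=True  -> detect-and-memorize task: the fact may start anywhere.
    Returns (tokens, answer, index of the segment where the fact starts).
    """
    n_noise = n_segments * segment_len - len(QUESTION) - len(FACT)
    tokens = [rng.choice(FILLER) for _ in range(n_noise)]
    if detect:
        start = rng.randrange(0, n_noise + 1)
    else:
        start = rng.randrange(0, segment_len - len(FACT) + 1)
    tokens[start:start] = FACT            # splice the fact into the noise
    return tokens + QUESTION, ANSWER, start // segment_len


if __name__ == "__main__":
    rng = random.Random(0)
    tokens, answer, seg = make_sample(4, 32, detect=True, rng=rng)
    print(len(tokens), answer, seg)
```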

### Recurrent Memory Transformer [8:00]

What we need, as I said, is a way to carry information from one segment to the next, and they implement that with these memory tokens. In a regular Transformer you simply input your text segment as tokens, and in BERT's case you train it with masked language modeling or something like that. Here they add a fixed number of memory tokens, let's say four, and the output again has the same number of tokens. Whether you train the text outputs is up to you. I think they start from a pre-trained BERT and skip masked language modeling during this training, fine-tuning only on the task. The model still outputs those tokens, and it also outputs new memory tokens: four memory tokens go in, four come out. These aren't the same tokens; you just dedicate slots as the memory tokens of the input and of the output. You then feed the output memory back in as the input memory of the next segment, which again goes through the Transformer layers and produces output tokens. This is how you can store information. Say there is an important fact in the text. The Transformer, through attention across its layers, could store that fact somehow in these memory tokens. Now say the next segment is the final one and contains a question the model needs to answer. It could consult that memory together with the question. Technically that's not cross-attention but self-attention into the memory tokens, and it retrieves the stored fact. If this is not the last segment and the model recognizes there is no fact in it, it could simply pass through
the memory tokens to the output unmodified. So there are various possibilities. Notably, this differs from what Transformer-XL does. To show that, I'll bring up the previous paper, the Recurrent Memory Transformer paper, which was accepted at NeurIPS 2022. You'll immediately recognize things because, as I said, this one-million-token technical report is just an extension of that paper; it even uses the same diagram. The earlier paper starts with this design and then says it needs to make it more complicated, because it deals with decoder-only Transformers. It needs a read portion of the memory and a write portion, because attention can't go both ways. As we produce a segment, we consult the read tokens, and we also produce write tokens that are passed into the read tokens of the next segment. So it gets more complicated, and this technical report takes a step back from that. Transformer-XL works differently. These are the different layers of the Transformer; each one would be one yellow block in the other diagram. From each layer, Transformer-XL takes the hidden states of the previous segment, essentially its keys and values, and lets the current segment attend into them. It's a bit like cross-attention in an encoder-decoder, except that you attend into the previous segment's states. So each segment can look at the cached intermediate states of the previous segment, which gives it a sort of temporal notion. But
there is always a stop gradient. Backpropagating would require keeping all of those states in memory, and you would eventually run out. So they simply stop the gradient: you can attend, but the earlier segment has no chance of learning what to store for the future. Memory has two components. The first is looking back into the past, and that is trainable: in Transformer-XL you learn how to look back. The second is deciding what to store, and the question is whether that can be learned. Transformer-XL cannot learn it, because there is no backpropagation through time to teach an earlier segment what to put in its memory output for the next segment to use. The Recurrent Memory Transformer solves that by doing backpropagation through time. Let me simplify this a bit, in case you're confused about where the memory is and what these memory tokens are. Go text segment by text segment. We have segment 1, the input text, and a hidden state, call it M0 for memory zero, which we initialize to all zeros. Both go into a box, and that box could be anything at all. Out come two things. The first is an output, which is optional, since we only really need the output at the end. The second is M1. A neural network can have multiple output heads. Alternatively, we can output one vector, split it in the middle, and call the left part the output and the right part the new memory. The new memory then goes into the next application of the function together with segment 2, which produces output 2 and memory 2, and so on. The catch is that the weights of this box and this box are obviously shared.
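
The recurrence just described can be sketched in plain Python. This is a toy under stated assumptions, not the authors' code. `toy_block` is a hypothetical stand-in for a pre-trained BERT: a single self-attention head with identity projections. Plain vectors replace token embeddings. In the layout assumed here, memory slots are prepended to each segment and the new memory is read from the same positions of the output.

```python
import math
import random
from typing import Callable, List, Tuple

Vec = List[float]


def toy_block(xs: List[Vec]) -> List[Vec]:
    """Hypothetical stand-in for a Transformer: one self-attention head with
    identity query/key/value projections, so every position (memory slots
    included) attends to every position."""
    d = len(xs[0])
    out = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]   # numerically stable softmax
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, xs)) / total for j in range(d)])
    return out


def rmt_step(block: Callable[[List[Vec]], List[Vec]], memory: List[Vec],
             segment: List[Vec]) -> Tuple[List[Vec], List[Vec]]:
    """Prepend the memory slots, run the shared block once, then split the
    output: the first len(memory) positions become the new memory."""
    hidden = block(memory + segment)
    return hidden[:len(memory)], hidden[len(memory):]


def run_rmt(block: Callable[[List[Vec]], List[Vec]], segments: List[List[Vec]],
            n_mem: int, d: int) -> Tuple[List[Vec], List[Vec]]:
    """Apply the same block to every segment, threading the memory through
    like an RNN hidden state."""
    memory = [[0.0] * d for _ in range(n_mem)]   # M0 = zeros, as in the simplified picture
    outputs: List[Vec] = []
    for segment in segments:                     # same function, same weights, every step
        memory, outputs = rmt_step(block, memory, segment)
    return memory, outputs                       # only the last outputs are needed


if __name__ == "__main__":
    rng = random.Random(0)
    d, n_mem = 4, 2
    segments = [[[rng.uniform(-1, 1) for _ in range(d)] for _ in range(6)] for _ in range(3)]
    memory, outputs = run_rmt(toy_block, segments, n_mem, d)
    print(len(memory), len(outputs))   # 2 memory slots, 6 token outputs
```

Written this way, `run_rmt` is an ordinary RNN loop whose cell happens to be attention over `memory + segment`.
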
Right, this is always the same function. In effect it's always the same box: the i-th segment goes in with the fed-back memory, and the new memory and the output of the i-th segment come out. And there you have it. We know what this is: it's an RNN, just a recurrent network. They don't make a secret of it, since the thing is called a recurrent memory mechanism. It's an RNN whose recurrent cell is a Transformer, and that's it. How they arrange the memory, especially the read and write tokens in the previous paper, is a technical detail of how the Transformer is wired. In essence, this is a recurrent neural network whose building block is a Transformer. So this is not scaling Transformers to 1 million tokens; it's an RNN with a Transformer as its building block. I'm not trying to put this paper down, and the authors are very clear about what they're doing. I just feel the community has seen the title and gotten carried away. So is this new? I highly doubt it's an entirely new idea; maybe it's just written up differently. And why hasn't something like Transformer-XL done this, given that it could also do backpropagation through time? The older paper does backprop through time, but obviously it can't unroll over thousands and thousands of segments. You can't do that in an RNN either. What you usually do is backprop three or four steps, because backprop through time requires keeping all of those steps in memory. It also gets very noisy, which is evident in the older paper's experiments as they increase the number of backprop-through-time steps. You can see this here in the comparison to Transformer-XL: look at how Transformer-XL
outperforms here on test perplexity as you increase the visible context during training, that is, how much context the model sees. The RMT gets better as you increase the number of backprop-through-time steps, and at some point it even beats Transformer-XL. With more visible context, though, it kind of degrades. I don't know exactly how to interpret this. My guess is that it's a property of backprop through time, which is known to become unstable over too many steps. They have also combined Transformer-XL with backprop through time, and that turns out to work pretty well too. As for why Transformer-XL and similar models didn't use backprop through time: they come from a time when a single Transformer-XL step barely fit into memory, so multiple steps were out of reach. Now we can fit multiple steps into memory, so we can do backprop through time, whose cost scales linearly with the number of steps. Once you've trained it, say on seven segments, which is the maximum they use here, running inference over thousands and thousands of segments is no longer a problem. Since you don't have to backprop, you can move from segment to segment, discard the old memories and hidden states, and essentially go on forever.
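
To see why the stop gradient matters, consider a scalar, linear toy recurrence. This is an assumption for illustration, not the RMT model. `w` decides how much of each segment's input is written into memory, and `a` decides how much memory is kept. The fact appears only in the first segment, and the loss is computed at the end. With full backpropagation through time, the write weight receives a gradient. With a Transformer-XL-style stop gradient on the incoming memory, it receives none, so the first segment cannot learn what to store.

```python
from typing import List


def unroll(w: float, a: float, xs: List[float]) -> float:
    """Toy recurrence m_t = a * m_{t-1} + w * x_t with m_0 = 0.
    `w` plays the role of "what to write", `a` of "what to keep"."""
    m = 0.0
    for x in xs:
        m = a * m + w * x
    return m


def grad_w_bptt(w: float, a: float, xs: List[float], y: float) -> float:
    """dL/dw for L = (m_T - y)^2 with full backprop through time:
    the write at step t contributes a^(T-1-t) * x_t (0-based t)."""
    err = 2.0 * (unroll(w, a, xs) - y)
    T = len(xs)
    return err * sum(a ** (T - 1 - t) * x for t, x in enumerate(xs))


def grad_w_stop_gradient(w: float, a: float, xs: List[float], y: float) -> float:
    """Same loss, but the incoming memory m_{T-1} is treated as a constant
    (stop gradient), so only the last segment's write gets any credit."""
    err = 2.0 * (unroll(w, a, xs) - y)
    return err * xs[-1]


if __name__ == "__main__":
    xs = [1.0, 0.0, 0.0]   # the "fact" is only in the first segment
    # BPTT yields a nonzero gradient for w; the stop-gradient version yields zero.
    print(grad_w_bptt(0.5, 1.0, xs, 1.0), grad_w_stop_gradient(0.5, 1.0, xs, 1.0))
```

Backprop through time has to keep every unrolled step, which is why training truncates the unroll. At inference, `unroll` only ever holds the current `m`, so it can run over arbitrarily many segments.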

### Experiments on scaling and attention maps [19:40]

What's the catch? They compare against scaling OPT: if we go bigger, it scales quadratically, while we only scale linearly. Yes, of course, but this is distinctly not the same thing. It's an RNN, not a Transformer. It takes the power of the Transformer, its language understanding, and combines that with an RNN. Here you can see what happens as you train on more segments; take this memorization plot as an example. If you train on one segment, performance drops as soon as you evaluate on two segments. The model answers the question using attention, and attention only works within a segment. It never even learned to use the memory, so it can't handle two segments. If you train on two segments, it generalizes to three and then starts degrading, so it learns a little about using the memory. Training on seven segments, which I'd guess is the most that fits into their GPU memory for backprop, gives you a pretty stable algorithm. All you did was teach it an algorithm: detect the fact, store it, and leave the memory unchanged until you need to consult it. That's very different from reading an entire book and then drawing different things from it and answering questions about it. A Transformer could do that if it really had the full context size. So this approach suits very specific tasks. In those tasks you need to grab a single fact or a few single facts, maybe two as you see here. The facts are dispersed in a very long text, 99% of which you otherwise don't need. You have to pick the facts out somehow and carry them over to the end. The good thing about this approach is that you can bring the Transformer's language
understanding to bear within one segment, and that can be really powerful, as BERT and its derivatives are. You can use it to detect the facts, which can be quite complicated as long as they fit into one segment. Then you use the RNN part to carry them over time. What you can't easily do is handle complex interdependencies and interrelations, where you need to consider many things dispersed throughout the text. I hope it's a bit clearer now what works and what doesn't. They also have nice investigations into how the model uses the memory. If there is no fact in the input, you simply see diagonal attention from the tokens to themselves. If there is a fact in the input, you see lots of attention to the memory tokens: wherever the fact is, it gets stored into the memory. Once the model sees the question, it consults the memory, so you again see a lot of attention on these memory tokens as it reads from them. That's fairly cool for the kinds of tasks these authors investigate. As for this plot scaling to a million tokens, it simply means the model has learned that algorithm. Once learned with about seven segments, the algorithm is independent of the length, because the model has robustly learned to ignore the noise.
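
The algorithm the model is argued to have learned can be written out by hand. This is neither a neural network nor the paper's method. The helpers `is_fact`, `answer`, and `make_segments` are hypothetical, and a string match stands in for learned fact detection. The sketch shows why input length stops mattering once this rule is in place: only `memory` survives between segments.

```python
from typing import Iterable, Iterator, List, Optional

Segment = List[str]


def is_fact(sentence: str) -> bool:
    # Hypothetical detector for bAbI-style facts such as "Daniel went back to the hallway."
    return " went back to the " in sentence


def answer(memory: Optional[str], question: str) -> Optional[str]:
    # Hypothetical reader: if the stored fact is about the person asked about,
    # return the location (last word of the fact, without the period).
    if memory is None:
        return None
    person = question.split()[-1].rstrip("?")
    if memory.split()[0] != person:
        return None
    return memory.rstrip(".").split()[-1]


def run(segments: Iterable[Segment]) -> Optional[str]:
    memory: Optional[str] = None
    last: Segment = []
    for segment in segments:
        for sentence in segment:
            if is_fact(sentence):
                memory = sentence       # write: store the detected fact
        last = segment                  # no fact: memory passes through unchanged
    return answer(memory, last[-1])     # read: the question is the final sentence


def make_segments(n_segments: int, fact_at: int) -> Iterator[Segment]:
    # Hypothetical noise generator; fact_at = -1 means no fact at all.
    noise = ["The weather was mild.", "The market opened late."]
    for i in range(n_segments):
        seg = list(noise)
        if i == fact_at:
            seg.insert(1, "Daniel went back to the hallway.")
        if i == n_segments - 1:
            seg.append("Where is Daniel?")
        yield seg


if __name__ == "__main__":
    print(run(make_segments(2, 0)))              # hallway
    print(run(make_segments(100_000, 57_123)))   # hallway
```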

### Conclusion [24:00]

I'm fairly sure you could achieve that with a bunch of other techniques, but this is one of them. So that was it from me, briefly. Again, it is a cool paper, and both the original paper and the technical report are very clear about what they're doing, but it has been a bit overhyped in terms of what it means. That's it. Let me know in the comments, and please overhype this video, obviously. See you, bye-bye!
