Energy-Based Transformers are Scalable Learners and Thinkers (Paper Review)
Video length: 47:51

Yannic Kilcher · 19.07.2025 · 30,812 views · 1,107 likes

Video description
Paper: https://arxiv.org/abs/2507.02092
Code: https://github.com/alexiglad/EBT
Website: https://energy-based-transformers.github.io/

Abstract: Inference-time computation techniques, analogous to human System 2 Thinking, have recently become popular for improving model performances. However, most existing approaches suffer from several limitations: they are modality-specific (e.g., working only in text), problem-specific (e.g., verifiable domains like math and coding), or require additional supervision/training on top of unsupervised pretraining (e.g., verifiers or verifiable rewards). In this paper, we ask the question "Is it possible to generalize these System 2 Thinking approaches, and develop models that learn to think solely from unsupervised learning?" Interestingly, we find the answer is yes, by learning to explicitly verify the compatibility between inputs and candidate-predictions, and then re-framing prediction problems as optimization with respect to this verifier. Specifically, we train Energy-Based Transformers (EBTs) -- a new class of Energy-Based Models (EBMs) -- to assign an energy value to every input and candidate-prediction pair, enabling predictions through gradient descent-based energy minimization until convergence. Across both discrete (text) and continuous (visual) modalities, we find EBTs scale faster than the dominant Transformer++ approach during training, achieving an up to 35% higher scaling rate with respect to data, batch size, parameters, FLOPs, and depth. During inference, EBTs improve performance with System 2 Thinking by 29% more than the Transformer++ on language tasks, and EBTs outperform Diffusion Transformers on image denoising while using fewer forward passes. Further, we find that EBTs achieve better results than existing models on most downstream tasks given the same or worse pretraining performance, suggesting that EBTs generalize better than existing approaches. Consequently, EBTs are a promising new paradigm for scaling both the learning and thinking capabilities of models.

Authors: Alexi Gladstone, Ganesh Nanduru, Md Mofijul Islam, Peixuan Han, Hyeonjeong Ha, Aman Chadha, Yilun Du, Heng Ji, Jundong Li, Tariq Iqbal

Links:
Homepage: https://ykilcher.com
Merch: https://ykilcher.com/merch
YouTube: https://www.youtube.com/c/yannickilcher
Twitter: https://twitter.com/ykilcher
Discord: https://ykilcher.com/discord
LinkedIn: https://www.linkedin.com/in/ykilcher

If you want to support me, the best thing to do is to share out the content :)

If you want to support me financially (completely optional and voluntary, but a lot of people have asked for this):
SubscribeStar: https://www.subscribestar.com/yannickilcher
Patreon: https://www.patreon.com/yannickilcher
Bitcoin (BTC): bc1q49lsw3q325tr58ygf8sudx2dqfguclvngvy2cq
Ethereum (ETH): 0x7ad3513E3B8f66799f507Aa7874b1B0eBC7F85e2
Litecoin (LTC): LQW2TRyKYetVC8WjFkhpPhtpbDM4Vw7r9m
Monero (XMR): 4ACL8AGrEo5hAir8A9CeVrW8pEauWvnp1WnSDZxW7tziCDLhZAGsgzhRQABDnFy8yuM9fWJDviJPHKRjV4FWt19CJZN9D4n

Table of contents (10 segments)

Segment 1 (00:00 - 05:00)

Hello, how's it going? Today we're going to look at this paper right here: "Energy-Based Transformers are Scalable Learners and Thinkers." This paper combines the general concept of energy-based models with transformers, proposes a learning paradigm for them, and shows how that paradigm can be applied and made scalable. With it they achieve some pretty promising results: the experiments are at a smaller data scale, but in language modeling, video modeling and similar tasks, the scaling trends suggest that this might be a more scalable and more promising architecture once you go to large scale. So I'd say this is a pretty interesting and promising direction of research, although obviously not yet verified at scale, but I think that's what makes it interesting. It's a bit different from what you're used to, but it also goes back to concepts that have been around for a while, and I'm very happy to see more energy-based modeling work. So let's dive in. The general question the authors ask is: is it possible to generalize System 2 thinking approaches and develop models that learn to think solely from unsupervised learning? A very big and bold question. We're not just talking about building better models; we're talking about System 2 thinking, the more logical, slow thinking that humans do, and doing that in machines. Here they largely refer to inference-time computation techniques, and we'll get to that in a bit. So they're asking: can we teach computers to do System 2 thinking, can we do so without any supervision, purely from unsupervised learning, and in a way that generalizes? Again, a very big question. As for what System 2 thinking is: if you know it from cognitive science, System 1 thinking broadly refers to quick and intuitive thinking.
This is how you go about most of your day. If you, I don't know, want to eat a banana, almost all of your actions are going to be System 1 thinking. You subconsciously walk to where you store fruit. You subconsciously grab one. You don't have to think about walking or about extending your muscles. You may not even have to think about how to peel and eat it. All of this is subconscious; it's ingrained, and you don't have to think about it. System 2 thinking, however, kicks in whenever System 1 thinking is at its limits. I'd guess it's debatable exactly what the difference is: whether System 2 thinking is an extension of System 1 thinking, whether there's a spectrum, or whether it's something completely separate and the two work together. I have no idea. But in general, System 2 thinking is characterized by being slow and explicit. Whenever you talk to yourself in your mind and work through a logical series of steps to arrive at a conclusion, that would be System 2 thinking. And, reasonably, the authors say that most of machine learning so far has been in the domain of System 1 thinking. There is a bit of debate about that. Some people argue that if you have a transformer with many layers, you put a piece of language in at the bottom, and out comes the next-token prediction at the top, then you have multiple steps of computation in there, layer by layer, so it's entirely conceivable that one form or another of "thinking" is happening there, however you want to define that.
Largely, though, you could say that building a model, doing a forward pass and taking its output is analogous to System 1 thinking, whereas System 2 thinking is much more

Segment 2 (05:00 - 10:00)

explicit. So let's say you do have that transformer, but you use it for chain-of-thought prompting. You input something, and the transformer outputs thinking tokens, which you autoregressively feed back into the transformer, and so on. So you produce an entire sequence, you use the transformer over and over again, you have some explicit thinking going on, and eventually your output relies on the fact that you've done that thinking. The authors largely take the view that the moment you use a trained model more than once at inference time, that is, the moment you put more computation into getting an answer from the model than a single forward pass, that is part of what they call thinking. Again, thinking is a big word, but to them (and this only really becomes clear at the end of the paper) thinking is actually not even an act. It's a metric: how much better can we get by doing multiple forward passes through a model rather than a single one? You might already know a couple of models that do this, apart from autoregressive transformers. Recurrent neural networks, for example, intuitively do multiple forward passes, although you could argue that processing an entire sequence is simply one forward pass. Diffusion models also do multiple forward passes through a model to arrive at their output. Again, you might argue: "No, that's just one overarching forward pass, one inference through the model in which you do all these computations." So all of this is debatable, but I'll leave it at that. Make your own decisions about what thinking is, what it means, and which models are doing it and which aren't.
This paper looks at energy-based models, and the way you do inference in energy-based models is by doing multiple forward passes. That's why they call their energy-based transformer variants "thinkers": you can put more compute in and get a "better" answer. This is illustrated here. In an autoregressive transformer you simply take your sequence and predict the next token. In an RNN you do multiple computations, but again you take your sequence and predict the next token. A diffusion transformer runs a reverse (denoising) process, and the energy-based transformer we'll get to in a second. They also talk about the current trend of reasoning models, which are largely trained with chain of thought and reinforcement learning. They say this usually only works in domains where rule-based rewards can easily verify answers, such as math and coding, so it's only applicable to a small range of problem types. Yes, you can do reinforcement learning, but you somehow have to provide a reward, and that reward is usually derived algorithmically: you give the model a math puzzle whose answer you already know, or a coding puzzle where you can verify that the unit tests pass, that it compiles, or that it produces the correct output. In those cases, reinforcement learning is fantastic. But outside of that, if you don't have that kind of supervision, you're out of luck. So we need to train these models in a way that lets them use inference-time compute without necessarily relying on reinforcement learning. Although, as you'll see, even though they contrast their approach with all of these models, there's actually nothing stopping us from combining the two.
To me, reasoning the way current models do it, with chain of thought and all of that, is completely

Segment 3 (10:00 - 15:00)

orthogonal to what these energy-based models do, and it's quite conceivable that energy-based models could also be used to do chain of thought. Again, their central question: can we rely entirely on unsupervised learning to develop System 2 thinking? Such a capability would enable generalization of current System 2 thinking approaches to any problem and any modality, and avoid the reliance on external human, reward, or model supervision. For that, they say there are three key facets of System 2 thinking. From now on, when you hear "System 2 thinking", just think of it as their goals. I have no proof of this, but I'm pretty sure they started from the models: they first developed energy-based transformers (energy-based models are one idea, transformers are another, and you combine them), and then they worked backwards, asked what the properties of these models are, and, oh, those just magically happen to be the three key facets of System 2 thinking. I don't really buy that; I think there was some reverse engineering happening here. So just take these as three nice properties of energy-based models rather than of System 2 thinking. Facet one: dynamic allocation of computation. In an energy-based model, as you'll see, you can choose how much computation to invest at inference time. If you have an energy-based transformer on language and you want to predict the next token, you're not limited to one forward pass; you can put in more computation and get a more accurate distribution over what that next token should be. Facet two: modeling uncertainty in continuous state spaces. Again, this is a property of energy-based models: they can give you a notion of uncertainty in their predictions.
EBMs can naturally model uncertainty without having to model exact likelihoods, by modeling the relative unnormalized likelihoods of predictions. As the real world often contains many inherently unpredictable elements, for instance a pedestrian who might emerge from behind a parked vehicle, the ability to express uncertainty in predictions is essential to being cautious and is a natural capability of humans. Then facet three: verification of predictions. In an energy-based model we're doing a bit of what you might know from GANs, generative adversarial networks, where a verifier or discriminator is involved: a model that can assess how good something is. So energy-based models have an inherent ability not just to make predictions but also to judge them. So how do we do this? Here again is inference with an energy-based model, for next-token prediction. You have your context, "The dog caught the", and the question is how you get a prediction for the next token out of an energy-based model. We're not just going to predict a token; we're going to output an entire distribution over tokens. That's the same as a classic transformer, which also outputs a distribution over tokens, but the way we do it is different. We start with a completely random distribution and then refine it step by step, shaping that distribution until we end up with our final one. And again, we can choose to do this for longer and get a more accurate output, or for shorter and get a less accurate one.
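The "relative unnormalized likelihoods" idea can be made concrete with the Gibbs (Boltzmann) form through which energy-based models are usually read. This is a general EBM identity rather than anything specific to the EBT paper: the normalizer Z(x) is intractable in general, but it cancels in likelihood ratios, so energies alone are enough to rank candidate predictions.

$$
p_\theta(y \mid x) = \frac{\exp\big(-E_\theta(x, y)\big)}{Z_\theta(x)},
\qquad
Z_\theta(x) = \int \exp\big(-E_\theta(x, y')\big)\,dy'
$$

$$
\frac{p_\theta(y_1 \mid x)}{p_\theta(y_2 \mid x)} = \exp\big(E_\theta(x, y_2) - E_\theta(x, y_1)\big)
$$

So the lower-energy candidate is the more likely one, by a factor that depends only on the energy gap.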
This might remind you of diffusion models, but how we do it here is a bit different. Just be mindful that we don't produce outputs through a single forward pass; we start with something random and then run a process, which obviously involves the model, that shapes the output step by step. Again, this paper is laced with philosophy.

Segment 4 (15:00 - 20:00)

"We propose viewing thinking as an optimization procedure with respect to a learned verifier which evaluates the compatibility between an input and candidate prediction." So they propose that thinking, this giant word and concept, happens to align exactly with what energy-based models are doing, and again I strongly feel there is some reverse engineering happening here. All right, I promised you no more philosophy; let's dive into the models themselves. What is an energy-based model, and how do you do inference in one? An energy-based model works together with what we call an energy function. An energy function, typically called E, has two inputs; let's call them x and y. The energy function is supposed to give you a low number if x and y are compatible in some way (I'm drawing a heart here): low energy means they go together, they're close, they just fit. And it gives you a high number if x and y don't go together (a little lightning bolt). Now you might complain that this is still an incredibly abstract concept, and that's true; it depends on what you want to do. Say you want to do next-token prediction. Then x would be the context, the prompt, the partial text, and y would be a distribution over the vocabulary that represents what the next token should be. The energy would be low if that distribution points only to tokens that are valid continuations of x in the language, and high if it doesn't. If you were doing image denoising, x would be the noised image and y the denoised version of that same image.
However, if x is the noised image but y is the denoised version of a different image, the energy needs to be high. The energy function is what we train. It's a parameterized function, and in this case it's a transformer, so the whole model is this energy function. We don't need any additional model; this is the entirety of the learned parameters. That's different from a GAN: there, this would be the discriminator, and you'd still need to train a generator to work against it. In an energy-based model, at least in this formulation, the energy function is all you need, and it's what you train. So you train a parameterized function and feed it data: here is a sentence, "the dog ...", and here is a distribution over next tokens, and you train it to output a low value if they go together and a high value if they don't. You can do this in various ways, and we'll get to that later. One easy way is contrastive training. You take the same context, the one-hot distribution of the correct next token and a one-hot distribution of an incorrect next token, and you train so that the output for the correct next token is lower than the output for the incorrect one. However, that is not a scalable way to go about it, as we'll see later. So the energy function is the only thing we need.
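To make the contrastive idea concrete, here is a minimal pure-Python sketch under toy assumptions: the "energy model" is a hypothetical lookup table `W` of per-context scores instead of a transformer, contexts are integer ids, and the loss is a margin (hinge) version of "correct energy lower than incorrect energy". It shows the mechanics only; as said above, pushing up one negative at a time is exactly what does not scale, since high-dimensional output spaces contain far too many incorrect candidates.

```python
def one_hot(index: int, size: int) -> list[float]:
    return [1.0 if i == index else 0.0 for i in range(size)]


def energy(W: list[list[float]], context: int, y: list[float]) -> float:
    """Toy energy: negative score of candidate distribution y under `context`.

    Lower energy = more compatible. `W` is a hypothetical table standing in
    for a trained transformer.
    """
    return -sum(w * p for w, p in zip(W[context], y))


def contrastive_step(W, context, pos, neg, lr=0.1, margin=1.0):
    """One hinge-loss update pushing E(x, y_pos) below E(x, y_neg) by `margin`."""
    vocab = len(W[context])
    y_pos, y_neg = one_hot(pos, vocab), one_hot(neg, vocab)
    loss = max(0.0, margin + energy(W, context, y_pos) - energy(W, context, y_neg))
    if loss > 0.0:
        # dE/dW[context][i] = -y[i], so dLoss/dW[context][i] = -y_pos[i] + y_neg[i]
        for i in range(vocab):
            W[context][i] -= lr * (-y_pos[i] + y_neg[i])
    return loss


# Usage: context 0 ("the dog caught the"), correct token 2, incorrect token 1.
W = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
for _ in range(20):
    last_loss = contrastive_step(W, context=0, pos=2, neg=1)
print(energy(W, 0, one_hot(2, 3)) < energy(W, 0, one_hot(1, 3)))  # True
```

Once the margin is satisfied the hinge loss is zero and the table stops changing, which is the usual reason for using a margin instead of the raw energy difference.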

Segment 5 (20:00 - 25:00)

You might object: hey, what's the difference from a loss? This seems very much like a loss function, which also looks exactly like this. The difference is when you use it. An energy function is meant to be used at inference time, whereas a loss function is used at training time. In fact, training the energy function's parameters itself has a loss function associated with it. In contrastive training, for example, the loss might be the energy of (x, y1) minus the energy of (x, y2). Minimizing this pushes the first term down and the second up, which is what we want, because y1 is the correct one and y2 the incorrect one. So a loss function is a training objective and an energy function is an inference objective, and that also hints at what you actually do with the energy function. We only train the model to output a single number. This thing here is a real number, not a prediction of anything. So what do you do if all I give you is a function where you put in the current half sentence and a distribution over the next token, and it gives you a number that is higher if the pair is bad and lower if it's good? You take the same half sentence, try a whole bunch of these distributions, and see which one gives the lowest value. You might one-hot encode your whole vocabulary, slap every one in, and see which is lowest. But that's not the whole space of distributions, because distributions obviously don't have to be one-hot.
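Written out, the distinction looks like this. The first line is a contrastive training loss over the parameters θ for a correct candidate y+ and an incorrect candidate y-; the naive difference can be driven to minus infinity, so a hinge with a margin m (a common EBM choice, assumed here rather than taken from the paper) is shown next to it. The second line is the inference objective over the prediction y, with θ and x held fixed.

$$
\mathcal{L}_{\text{naive}}(\theta) = E_\theta(x, y^{+}) - E_\theta(x, y^{-}),
\qquad
\mathcal{L}_{\text{hinge}}(\theta) = \max\big(0,\; m + E_\theta(x, y^{+}) - E_\theta(x, y^{-})\big)
$$

$$
\hat{y} = \arg\min_{y} E_\theta(x, y) \qquad (\theta,\ x \text{ fixed})
$$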
So you could think of slapping in every possible distribution and finding the lowest value. Now you might recognize: hey, this sounds a lot like optimization. And that's exactly how we go about it. If you have a well-trained energy function, one that actually gives low energy when things go together, then you can simply run an optimization procedure at inference time to get a good output. What do I mean? Look at this: it's a 2D representation of an energy landscape. Again, this is not a loss; this is at inference time. Imagine that this axis is your context, all possible contexts. They're discrete, I get it, but let's say they're continuous. So that's our x, and here are all the possible next-token distributions: this distribution might be here, a very spiky distribution might be there, and so on. If the energy function is well trained, what we're trying to find are the minima of the energy function given that one of its inputs is our x. So we fix x, vary y, and run an optimization procedure over y. Actually, I've done the wrong drawing here. Both of the axes here should obviously

Segment 6 (25:00 - 30:00)

be y1 and y2, like this: individual dimensions of your distribution space. The distribution is obviously way more high-dimensional; it has the dimensionality of the whole vocabulary. But imagine your vocabulary has just three entries. Because the distribution has to be normalized, that's a two-dimensional space, and that's what you optimize over. You're trying to find the minima, and how do you do that? By gradient descent. We start from a random output, a random distribution over the next token, and then do gradient descent on the energy function, backpropagating to the input. We have an x and a ŷ, an estimate of y: an initial guess or an intermediate step. We put both through a multi-layer transformer that gives us an energy. Then we calculate the gradient of the energy and backpropagate it through the transformer to ŷ. So this is the gradient of the energy with respect to ŷ, with x fixed: we want to know how we need to change ŷ. We take a little step in that direction, re-evaluate, take another little step, re-evaluate, and so on. Cool. So we're doing gradient descent at inference time; that's at least one way of doing inference in energy-based models. We optimize against the energy function. You can see this has some nice properties: if the landscape is well behaved, then by wiggling around we can get the variance, the uncertainty. Is the minimum very wide or very narrow? Is it bumpy? And so on. We can also more easily recognize out-of-distribution data and things like that.
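The inference loop just described (random initial guess, repeated gradient steps on ŷ with x held fixed, keep the lowest-energy candidate seen) can be sketched in a few lines. Loud assumptions: the energy is a hypothetical quadratic toy with a hand-written gradient, standing in for the transformer whose gradient the paper obtains by backpropagation; `TARGETS` is an invented lookup of "compatible" distributions; and ŷ is not projected back onto the probability simplex. The point is only the mechanics, and that more steps ("more thinking") reach lower energy.

```python
import random

# Hypothetical "compatible" next-token distribution per context (3-token vocabulary).
TARGETS = {"the dog caught the": [0.1, 0.8, 0.1]}


def toy_energy(x: str, y: list[float]) -> float:
    """Stand-in for E(x, y): low when y is close to the compatible distribution for x."""
    return 0.5 * sum((yi - ti) ** 2 for yi, ti in zip(y, TARGETS[x]))


def toy_energy_grad_y(x: str, y: list[float]) -> list[float]:
    """Gradient of toy_energy with respect to y (x held fixed)."""
    return [yi - ti for yi, ti in zip(y, TARGETS[x])]


def minimize_energy(x, dim=3, steps=10, step_size=0.5, seed=0):
    rng = random.Random(seed)
    y = [rng.random() for _ in range(dim)]                    # random initial guess y_0
    best_y, best_e = list(y), toy_energy(x, y)
    for _ in range(steps):
        grad = toy_energy_grad_y(x, y)
        y = [yi - step_size * gi for yi, gi in zip(y, grad)]  # y <- y - alpha * dE/dy
        e = toy_energy(x, y)
        if e < best_e:                                        # keep the minimum seen so far
            best_y, best_e = list(y), e
    return best_y, best_e


ctx = "the dog caught the"
_, e_short = minimize_energy(ctx, steps=1)
y_long, e_long = minimize_energy(ctx, steps=20)
print(e_long < e_short)  # True: more optimization steps reach lower energy
```

Both calls use the same seed, so they start from the same initial guess and differ only in how much compute (number of steps) is spent.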
So there are a lot of excellent properties, but the downside is that we can't just get an output, poof, in a single step; we have to run this optimization procedure. Oh, and by the way, energy is not a normalized quantity. Energies are always unnormalized, for scalability reasons, so all you really know is whether one energy is lower or higher than another. So this is the inference procedure: we sample an initial guess, run gradient descent on the energy function with some step size, and return the minimum we found along the way. Again, that's how we do inference. So how do we train such a model? There are some challenges. If you just naively train the energy function to be high on incompatible inputs and low on compatible inputs, you'll end up with a very jagged, non-smooth energy landscape: low wherever your data is and arbitrary everywhere else, especially in high dimensions. Therefore some energy landscape regularization techniques are really important, and they use three of them. One is a replay buffer, which is also often used in reinforcement learning: you keep your trajectories around and sample from them for training. You usually do that to bring some variety into the system and get away from your very local state and current data. Another is that they add noise here, and this is a good point to mention something I forgot: how do you actually train? The trick is that you train with the way you do inference in mind. This paper says there are two ways to train these things. One is

Segment 7 (30:00 - 35:00)

contrastive, which we already looked at. What is the other one? Well, the other one says: my final output ŷ is going to be y0 minus α times the gradient of the energy E(x, y0). That's a single step of gradient descent, but you can see that this is my output. So I can define a loss function on it that compares ŷ with the correct y from my dataset: ŷ is the predicted distribution over the next token, and y is the actual next token, one-hot encoded. I can define a cross-entropy loss, and I can treat this whole thing as effectively ŷ = f(x). If you didn't know how it came about, what would you do? You'd simply take the gradient of the loss L(f(x), y) with respect to the parameters of f and backpropagate through f to its parameters. You can do that here: here are the parameters, the subtraction is a linear operation, and the gradient term is a differentiable operation as well. So what you end up doing is backpropagating through your optimization steps, through an operation that already has a gradient computation inside it, and that means you need second-order derivatives. However, the second-order derivatives aren't too bad. As the paper puts it, this loss is backpropagated through the entire optimization process, requiring second-order derivatives, i.e. gradients of gradients, and these are computed efficiently via Hessian-vector products, which scale linearly with model size. So it's not the most FLOP-efficient thing in the world, but it doesn't explode quadratically as you scale up.
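A tiny scalar example shows why training through an inference step needs second-order terms. Assumptions: a hypothetical two-parameter energy E_θ(x, y) = 0.5·a·(y - b·x)² replaces the transformer, a squared error replaces the cross-entropy loss, and only one gradient step is unrolled. The derivative of ŷ with respect to θ is -α times the mixed second derivative of E (the derivative of dE/dy with respect to θ); in a real network, autodiff frameworks compute its product with the upstream gradient as a Hessian-vector-style product without forming the full matrix. The analytic gradient is checked against central finite differences.

```python
def grad_y(theta, x, y):
    """dE/dy for the toy energy E = 0.5 * a * (y - b * x) ** 2."""
    a, b = theta
    return a * (y - b * x)


def unrolled_loss(theta, x, y0, y_true, alpha):
    """Loss after one inference-time gradient step: y_hat = y0 - alpha * dE/dy."""
    y_hat = y0 - alpha * grad_y(theta, x, y0)
    return 0.5 * (y_hat - y_true) ** 2


def unrolled_loss_grad(theta, x, y0, y_true, alpha):
    """Analytic dLoss/dtheta, backpropagating through the gradient step."""
    a, b = theta
    y_hat = y0 - alpha * grad_y(theta, x, y0)
    dloss_dyhat = y_hat - y_true
    # Mixed second derivatives d(dE/dy)/da and d(dE/dy)/db: "gradients of gradients".
    d2e_dy_da = y0 - b * x
    d2e_dy_db = -a * x
    return (dloss_dyhat * -alpha * d2e_dy_da, dloss_dyhat * -alpha * d2e_dy_db)


def finite_diff_grad(theta, x, y0, y_true, alpha, h=1e-6):
    """Central finite differences of unrolled_loss, one parameter at a time."""
    grads = []
    for i in range(len(theta)):
        plus, minus = list(theta), list(theta)
        plus[i] += h
        minus[i] -= h
        grads.append((unrolled_loss(plus, x, y0, y_true, alpha)
                      - unrolled_loss(minus, x, y0, y_true, alpha)) / (2 * h))
    return tuple(grads)


args = ((1.5, 0.7), 2.0, 0.3, 1.0, 0.1)   # theta, x, y0, y_true, alpha
print(unrolled_loss_grad(*args))
print(finite_diff_grad(*args))
```

Unrolling more steps chains further factors of the form (1 - α·d²E/dy²) between them, which in the vector case become Hessian-vector products in y.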
Again, this might be a bit weird to people who are used to training single-forward-pass transformer models, but we train such that the inference process itself is taken into account during training. The loss measures how good the output is after that inference-time gradient descent process, so we train with inference in mind. Now, to make that scalable, we need to regularize, and one part of the regularization is adding noise. At training time we also do this gradient descent: we have some energy landscape, we take our gradient steps, boop, boop, land here, calculate the loss and backpropagate through this inference. And we also add noise to every one of those steps. The reason is that this helps generalization and smoothness. Say this is a top view and your optimization path looks like this. By adding noise, you wash it out a little. Instead of treading a really thin path where everything to the left and right is undefined, you make that path wider. You broaden the region of the landscape you reach during training, and by that you make the landscape smoother. You obviously sacrifice a bit of accuracy, but you make the landscape a lot smoother by adding noise during training. And we're looking for generalization here, to data we haven't seen during training. If at

Segment 8 (35:00 - 40:00)

inference time the data is close to what you've seen during training, you're not hopelessly lost, because you'll still be inside this wider band you've seen, and you'll be able to follow it and make something sensible of that data. Adding noise is obviously a very old trick, but a very effective one, and they use it here as well. The third technique: randomizing the gradient step size and the number of optimization steps significantly improved generalization. You don't always want to do exactly five steps with exactly the same step size; if you vary things up a bit, you can gain a lot. Additionally, because we're putting in compute at inference time by doing multiple forward passes, if I already train with sometimes fewer and sometimes more optimization steps, I end up with a model that is accustomed to giving good answers in all of these situations. Hopefully that generalizes, so that at inference time I could extrapolate and put in many more steps than I used during training, because I've trained the model to be flexible about how many steps I do. So those are the training techniques, and then they introduce their model, which is a transformer. They combine energy-based modeling with transformers, saying that energy-based models have traditionally encountered difficulties with three characteristics: parallelizability, stability and scalability. Energy-based models are really bad at these, and transformers are really good at them, so it seems natural to combine the two. So they present EBTs, energy-based transformers, which are transformer implementations designed for EBMs.
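The noise on every step and the randomized step size and step count can be sketched as one training-time inner loop. Assumptions: the gradient function is passed in (here a hypothetical quadratic toy), the noise is Gaussian and added after each gradient step in the style of Langevin dynamics, and the step size is jittered uniformly around a base value; the paper's actual schedules and magnitudes may differ, and the replay buffer is left out.

```python
import random


def noisy_randomized_descent(grad_fn, y0, rng, base_step=0.5, step_jitter=0.5,
                             min_steps=2, max_steps=6, noise_std=0.05):
    """Training-time inner loop: randomized step size/count plus per-step Gaussian noise.

    Returns the whole trajectory; a training loss would be computed on the final
    prediction (and, in an autodiff framework, backpropagated through the loop).
    """
    step_size = base_step * rng.uniform(1.0 - step_jitter, 1.0 + step_jitter)
    n_steps = rng.randint(min_steps, max_steps)          # inclusive bounds
    y = list(y0)
    trajectory = [list(y)]
    for _ in range(n_steps):
        grad = grad_fn(y)
        y = [yi - step_size * gi + rng.gauss(0.0, noise_std)
             for yi, gi in zip(y, grad)]
        trajectory.append(list(y))
    return trajectory


target = [0.1, 0.8, 0.1]   # hypothetical compatible output


def quad_grad(y):
    """Gradient of the toy energy 0.5 * ||y - target||^2."""
    return [yi - ti for yi, ti in zip(y, target)]


traj = noisy_randomized_descent(quad_grad, [1.0, 0.0, 0.0], random.Random(0))
print(len(traj) - 1, "steps; final:", [round(v, 3) for v in traj[-1]])
```

Because each call draws its own step size and step count, the energy function sees many different path lengths during training, which is the property the transcript hopes will let inference use more steps than training did.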
This is a challenge from an engineering perspective. Especially the decoder-only, triangular-attention transformers need a lot of consideration so that you don't get information leakage across the multiple inference steps that you do in an EBM. You no longer just do one forward pass; you do multiple. And if you want to benefit from parallelizable training, doing parallel computation with this triangular attention, you have to pay very close attention to how your data flows. They've implemented all of this and their code is available, which is very cool. They research two different things in the experimental section. One is learning scalability, which is the traditional question of how quickly models can fit the pre-training data. The other is what they call thinking scalability: can we determine whether model performance improves with increased thinking, by which they mean an increased number of forward passes at inference time? If we put in more compute, can we get better, and can we get better in a more scalable, more rapid way than other models? For the first, they compare with Transformer++, which is a training recipe for next-token-prediction, single-forward-pass transformers, and you can see from these graphs that while the energy-based transformer starts out at a bit of a disadvantage, it quickly gains on the classic transformer as you, for example, scale up training data, batch size, and depth. Again, what they're effectively showing is: look, the trends are really good. Some of these trends haven't really materialized yet; you would need to extrapolate somewhere down here to actually see.
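To make the information-leakage concern concrete, here is one hypothetical way to lay out a parallel training pass: the n observed tokens are followed by n candidate-prediction slots, where candidate i holds the current guess for token i + 1. This layout and the function `candidate_attention_mask` are assumptions for illustration, not necessarily the masking scheme in the authors' repository. The point is only that each candidate may see its own context and itself, but never the ground-truth token it is predicting or any other candidate.

```python
from typing import List


def candidate_attention_mask(n: int) -> List[List[bool]]:
    """Boolean mask over [observed_0..observed_{n-1}, candidate_0..candidate_{n-1}].

    mask[q][k] is True when query position q may attend to key position k.
    Candidate i is the prediction for token i + 1 (hypothetical layout).
    """
    size = 2 * n
    mask = [[False] * size for _ in range(size)]
    for i in range(n):
        # Observed token i: ordinary causal attention over observed tokens 0..i.
        for j in range(i + 1):
            mask[i][j] = True
        # Candidate i: sees the observed context 0..i and itself,
        # but not observed token i + 1 (its target) or other candidates.
        q = n + i
        for j in range(i + 1):
            mask[q][j] = True
        mask[q][q] = True
    return mask


for row in candidate_attention_mask(3):
    print("".join("x" if allowed else "." for allowed in row))
```

For n = 3 the printed rows are `x.....`, `xx....`, `xxx...` for the observed tokens and `x..x..`, `xx..x.`, `xxx..x` for the candidates: the top-left block is the usual causal triangle, and each candidate row adds only its own diagonal entry.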
And there is still the real possibility that at large scale none of these trends actually go the way they seem to go. But still, it's quite promising. So this is training scalability, where the energy-based transformers already sort

Segment 9 (40:00 - 45:00)

of scale better. Now, keep in mind the x-axis right here. The x-axis represents very particular quantities. The fact of the matter is still that in a regular transformer, one training step is one forward pass. In an energy-based model, one training step means you first have to do the inference procedure during the forward pass, and then you have to backpropagate through that inference procedure, which all in all is quite a bit more work for your GPUs than a single-forward-pass transformer. So x-axes like batch size, number of tokens, and so on are all fine. But in the compute domain you'll see this is quite different, and that's what we have right here. In terms of training FLOPs (and this is a log scale, right), the energy-based transformers are significantly behind the classic transformer. However, they scale faster. What they mean is that this slope right here is ever so slightly steeper than this one. (You can also see that right here, but this one is embedding dimension; let's stay with FLOPs.) And therefore, if this trend continues, there's actually a future where, because energy-based models achieve better perplexity, the additional FLOPs sort of cancel out, and you would need to invest a lot more training FLOPs into classic transformers than into energy-based transformers, because the energy-based transformers are just so much better at making use of those FLOPs. Not at this scale, but conceivably, if you believe the trends and you extrapolate, the curves will at some point cross. The second part is thinking: at inference time, can we put in some more work? Their answer here is yes, indeed.
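The "if you extrapolate, the curves will cross" argument is simple arithmetic on two power laws. Assuming each curve is fit as loss = a * C^(-b) in training compute C, a curve with a worse offset but a steeper slope catches the baseline at C* = (a_new / a_base)^(1 / (b_new - b_base)). The coefficients below are made up for illustration and are not fits from the paper; they show how even a small slope difference can put the crossing point many orders of magnitude away.

```python
import math


def power_law_loss(compute: float, a: float, b: float) -> float:
    """Loss predicted by a power-law fit: a * compute ** (-b)."""
    return a * compute ** (-b)


def crossover_compute(a_base: float, b_base: float, a_new: float, b_new: float) -> float:
    """Compute at which the 'new' curve's loss equals the baseline's.

    Returns math.inf if the new curve is not steeper, since it never catches up.
    """
    if b_new <= b_base:
        return math.inf
    return (a_new / a_base) ** (1.0 / (b_new - b_base))


# Hypothetical coefficients: the "new" model starts at twice the loss but has a steeper slope.
A_BASE, B_BASE = 10.0, 0.05
A_NEW, B_NEW = 20.0, 0.07

c_star = crossover_compute(A_BASE, B_BASE, A_NEW, B_NEW)
print(f"curves cross at about {c_star:.3g} FLOPs")
```

With these invented numbers the crossing lands at 2^50, roughly 1.1e15, which is why a slope that is only "ever so slightly steeper" needs a lot of extrapolation before it pays off.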
You can see that the classic transformer obviously does not scale with the number of forward passes: for the same input, it's going to give you the same output no matter how many forward passes you do. The energy-based transformer starts out weaker, but as you increase the forward passes, it obviously gets stronger. That's not a surprise, because you start out with something completely random, right? After one forward pass, you've done one inference step, one gradient descent step in the energy landscape, so you do need to do a couple to get ahead. And they do end up ahead right here, with a gap to the classic transformer. Another thing they can do is look at the energies across thinking steps: how do the energies evolve? One thing they see is that different tokens have different energies. Here the light colors represent lower energies, and you can see that throughout the inference of a sentence you get significantly lower energies at tokens that are a lot more clear and easy. Also here you can see the energy being lower at easy words, so to say, and that represents a degree of self-assessment by these models, and maybe also an opportunity for us to spend less compute on these steps. This ability to put different amounts of compute, different amounts of FLOPs, into the inference procedure, combined with the fact that the energy function itself tells you something about the current state of things, about the uncertainty and about the easiness, could potentially give rise in the future to a very dynamic inference procedure, where you don't always have to say: oh, we always do 100 steps or something like this. It's a bit like speculative decoding and similar ideas, where because you know something more, you can maybe save some computation.
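The dynamic inference procedure speculated about above could look something like the following: keep taking gradient steps on the energy only while they still lower it meaningfully, so easy predictions stop early and hard ones get more steps. This is a hypothetical sketch on a one-dimensional quadratic energy; the function `adaptive_think`, the tolerance `tol`, and the step cap are illustrative choices, not something the paper specifies.

```python
from typing import Callable, Tuple


def adaptive_think(
    energy: Callable[[float], float],
    grad: Callable[[float], float],
    y0: float,
    step_size: float = 0.1,
    max_steps: int = 100,
    tol: float = 1e-6,
) -> Tuple[float, float, int]:
    """Gradient descent on the energy that stops once a step gains less than tol.

    Also stops if a step fails to lower the energy at all.
    Returns the final prediction, its energy, and the number of steps used.
    """
    y = y0
    e = energy(y)
    for step in range(1, max_steps + 1):
        y -= step_size * grad(y)
        new_e = energy(y)
        if e - new_e < tol:  # the last step barely lowered the energy
            return y, new_e, step
        e = new_e
    return y, e, max_steps


def energy(y: float) -> float:
    # Toy energy with its minimum at y = 1.
    return (y - 1.0) ** 2


def grad(y: float) -> float:
    return 2.0 * (y - 1.0)


# Toy stand-ins: the "easy" prediction starts near the minimum, the "hard" one far away.
_, _, easy_steps = adaptive_think(energy, grad, y0=1.01)
y_hard, _, hard_steps = adaptive_think(energy, grad, y0=10.0)
print(easy_steps, hard_steps)
```

A stopping rule based on the energy decrease (rather than a fixed step count) is what would let per-token compute vary; a threshold on the energy value itself would be an equally plausible alternative.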
What I find remarkable is that there seems to be not a whole lot of difference beyond iteration one. Obviously, at iteration zero the energies are something

Segment 10 (45:00 - 47:00)

very high, right? But then after iteration one you seem to be in the minimum already, and the further iterations don't seem to do that much anymore. I think this is confirmed by this plot right here, where you make the most gain at the beginning. Then again, this is very common in optimization. There's not much more here other than more of the same, and I don't want to go too much into it. They do video prediction as well, and they compare against diffusion transformers. I hope you get the idea of what this is. The scaling trends look promising, I can say that. But obviously, because of resource constraints, they have not tried larger scales, and the base case itself is such that your fixed cost to work with energy-based models is a lot higher. However, it could in fact be that at large scales that fixed cost is amortized by the gains that you make, and it could actually be more beneficial to go with energy-based models than with classic models. All in all, I find the paper very cool, but I do feel like they bring a lot of philosophy into it, and they compare with models that are not necessarily comparable. I don't think chain-of-thought or reasoning models or anything like that have anything to do with this, unless you say, oh well, they also do multiple steps and so on, but that's very abstract to me. In an energy-based model, this multi-forward-pass optimization is just the way you do inference, and you can view it as one inference step. Once you have that, you might as well do chain of thought with it, or reasoning, or train it with reinforcement learning, right? So to me these things have not much to do with each other. The energy-based models have nice properties no matter what. Okay, I don't want to keep you here for longer than necessary.
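On the observation that iteration one already lands near the minimum: for a toy quadratic energy E(y) = (k/2)(y - y*)^2, one gradient step with step size η multiplies the distance to y* by (1 - ηk), so the energy shrinks by a factor (1 - ηk)^2 per step. If the step size is large relative to the curvature (ηk close to 1), almost all of the reduction happens in the first step. The sketch below uses made-up values (k = 1, η = 0.8) and is only one plausible explanation; the paper's learned energy landscapes are not quadratic.

```python
def quadratic_energy_trace(k, step_size, y0, y_star, num_steps):
    """Energy after each gradient step on E(y) = 0.5 * k * (y - y_star) ** 2."""
    y = y0
    trace = [0.5 * k * (y - y_star) ** 2]
    for _ in range(num_steps):
        y -= step_size * k * (y - y_star)  # gradient of E is k * (y - y_star)
        trace.append(0.5 * k * (y - y_star) ** 2)
    return trace


trace = quadratic_energy_trace(k=1.0, step_size=0.8, y0=5.0, y_star=0.0, num_steps=5)
total_drop = trace[0] - trace[-1]
first_step_share = (trace[0] - trace[1]) / total_drop
print([round(e, 4) for e in trace], f"first step: {first_step_share:.1%} of the total drop")
```

Here the energy goes from 12.5 to 0.5 in the first step, so about 96% of the total reduction over five steps comes from that single step, matching the "most gain at the beginning" shape of the plot.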
Please give the paper a read if you are interested. Many thanks to the authors. We discussed this in our Discord paper discussions, and actually the lead author was part of those discussions, and we're obviously super thankful for that. That is very cool. If you are interested, come join our Discord. We have a lot of paper discussions all the time, and even if not, I'll see you around. Bye-bye.
