Hello everyone, welcome back to this AI Coffee Break. Today, the topic might appear technical, but if you’ve ever wondered what energy-based models are and how they differ from standard neural networks, this video is for you. We’ll break down how a new paper combined energy-based models with transformers — which are the backbone of large language models, vision models, and even video generators — to create Energy-Based Transformers. The cool part is that unlike standard transformers, which always spend the same amount of compute per token, Energy-Based Transformers can think longer before they produce each token, stop early on easy tokens, and even tell us when they’re uncertain. Grab a cup of something and let’s dive in!

Energy-Based Transformers were introduced in this paper by Alexi Gladstone and colleagues. The idea is to make transformers part of so-called energy-based models, a class of models that don’t directly output probabilities but instead assign an energy score to how well a guess fits the input. Energy-based models, or EBMs for short, flip the usual way neural networks work. A standard neural net takes an input x — say, an image or a question — and directly outputs a probability distribution over possible answers y. An EBM, in contrast, takes both x and a candidate y as input. The model’s output is then an energy score for this pair: low energy corresponds to well-fitting answers, high energy to unfitting ones. The key thing to keep in mind is this: in EBMs, the answer you want to evaluate isn’t the output — it’s part of the input. The model’s job is not to generate probabilities directly, but to judge how compatible an input–answer pair is. One can train such an EBM via contrastive pairs: show it a question with the correct continuation from the dataset, and the same question with a random wrong continuation, then push the energy down for the correct one and up for the wrong one.
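To make this concrete, here is a minimal toy sketch of contrastive EBM training in plain Python. This is my own illustration, not the paper’s setup: the quadratic energy form, the parameter w, and the rule that the “correct” answer for an input x is simply y = x are all hand-made assumptions.

```python
# Toy energy-based model: E(x, y) = w * (y - x)^2.
# Low energy should mean "y fits x". Here the "correct" answer
# for an input x is simply y = x, so the model should learn w > 0.
import random

random.seed(0)
w = -1.0     # deliberately bad init: prefers wrong answers at first
lr = 0.05

def energy(x, y):
    return w * (y - x) ** 2

for _ in range(200):
    x = random.uniform(-1, 1)
    y_pos = x                            # correct answer from the "dataset"
    y_neg = x + random.choice([-1, 1])   # a random wrong answer
    # Contrastive update: push energy down for the positive pair,
    # up for the negative pair. dE/dw = (y - x)^2 for each pair.
    grad_w = (y_pos - x) ** 2 - (y_neg - x) ** 2
    w -= lr * grad_w

print(energy(0.3, 0.3) < energy(0.3, 1.3))  # True: correct answer has lower energy
```

After training, the learned w is positive, so the energy valley sits exactly on the correct answers. The next paragraph explains why this negative-sampling recipe breaks down in high dimensions.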
But this approach is inefficient, because in high-dimensional spaces like language and images, there are essentially infinitely many possible wrong answers. You’d need an enormous number of negative samples to really teach the model where the good valleys of the energy landscape are. If you only sample a few negatives, you risk wasting computation on trivial cases that the model could already reject easily, while still failing to cover the important regions of the space. The alternative training procedure — and the one used in Energy-Based Transformers — is more complicated but much more effective. Here is the idea in a nutshell, and we’ll go into more detail in just a second: instead of relying on random negatives, you treat the prediction itself as an optimization problem. You basically teach the model to shape its entire energy landscape so that gradient descent naturally flows toward the correct answer. You start with a random guess for the answer distribution, which goes into the energy-based model, implemented just as a normal neural network; it could be a transformer. This EBM outputs a scalar, which is the energy, let’s say 122 in this case. Then the guess is refined step by step through gradient descent on the energy, and after several steps you compare the refined guess to the true answer using a standard loss like cross-entropy. So, in this way, the model learns to shape its entire energy landscape so that gradient descent on the energy naturally flows toward the correct answer.

This was a bit quick, so going slower and in more detail, the training looks like this: via training, you want to teach the model to assign high energy to incorrect token probability distributions and low energy to correct ones. So at step t=0 you feed the EBM the context x and you also initialize a guess for the probability vector, call it ŷ₀. And now, you want to find the probability vector which minimises the energy under the current EBM parameters.
For this you go step by step, updating the guess ŷ via gradient descent on the energy: ŷₜ₊₁ = ŷₜ − α·∇_ŷ E(x, ŷₜ), i.e. subtracting a scaled gradient of the energy from the current guess. If you look closely, you notice that this is the standard gradient descent formula, only that usually you update the weights θ and minimise the loss L, while this time you update ŷ and minimise the energy E. So we are doing backprop through the frozen parameters and we update only the input ŷ (like one would do for adversarial examples).
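Here is a minimal sketch of this inner refinement loop. It is my own toy illustration: a hand-picked quadratic energy stands in for the frozen transformer, the step size is arbitrary, and a numerical gradient stands in for autograd.

```python
# Inner loop: refine the guess y_hat by gradient descent on the energy,
# with the model parameters frozen. The "model" here is a fixed quadratic
# energy whose minimum happens to sit at the correct answer y = 2.0.

def energy(x, y_hat):
    return (y_hat - 2.0) ** 2          # stands in for the frozen EBT

def grad_y(x, y_hat, eps=1e-5):
    # numerical gradient of E w.r.t. the guess (a real EBT uses autograd)
    return (energy(x, y_hat + eps) - energy(x, y_hat - eps)) / (2 * eps)

x = "some context"                      # the context is fixed throughout
y_hat = -3.0                            # random initial guess y_hat_0
alpha = 0.1                             # step size
for step in range(25):
    y_hat = y_hat - alpha * grad_y(x, y_hat)  # y_{t+1} = y_t - a * dE/dy

print(round(y_hat, 3), round(energy(x, y_hat), 6))
```

The guess slides down the energy landscape and ends up near the minimum at 2.0, with the energy close to zero. In the real model, ŷ is a whole probability vector over the vocabulary rather than a single number.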
After N refinement steps, you end up with ŷᴺ — a refined probability vector. So, you do a fixed number of steps, or stop when the energy does not decrease anymore. And now, after having minimised the energy under the current EBM parameters, we are ready to see whether the energy assignments make sense and update them accordingly. For this, you use the training data to derive a training loss. We compare ŷᴺ against y, the true one-hot target vector, using a standard supervised loss — typically cross-entropy for language modeling. This loss is then used in normal backpropagation to update the model parameters via gradient descent.

This formula looks easy enough on paper, seemingly being the normal gradient descent that we do when training neural networks, but the implementation is quite tricky, because this loss depends on ŷᴺ — the refined guess after N steps of energy minimization. That means the refinement process itself, which was gradient descent on the energy, now sits inside the computation graph. So when we backpropagate through the loss, we’re differentiating through those gradient updates. In other words, we need second-order gradients — gradients of gradients. The authors compute these efficiently using Hessian–vector products, which scale linearly with model size, so the cost stays manageable.

During inference, things change. You no longer have the ground truth y. Instead, you only have the context x, and the model starts with one or more random guesses for the next-token distribution. Each guess gets refined step by step, again by minimizing the energy, until the energy can’t be lowered anymore or for a pre-determined number of steps. Doing it flexibly has a very human-like flavor: when a problem is easy, the model needs only a couple of refinement steps. When the problem is hard, it can spend more time — more computation — to get it right. So, this was the way in which EBM training and inference happen.
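To see why the outer update involves gradients-of-gradients, here is a toy sketch in plain Python. Everything here is my own illustration, not the paper’s implementation: the quadratic energy, the finite-difference gradients, and all constants are hand-picked stand-ins for real autograd machinery and Hessian–vector products.

```python
# The supervised loss is computed on the guess AFTER inner refinement,
# so the outer gradient must flow through the inner descent steps.
# Toy energy: E(y; theta) = (y - theta)^2, where theta is the model
# "parameter" (the location of the energy minimum).

def refine(theta, y0, alpha=0.1, n_steps=10):
    y = y0
    for _ in range(n_steps):
        y = y - alpha * 2.0 * (y - theta)   # inner step: dE/dy = 2(y - theta)
    return y

def loss(theta, y0, y_true):
    return (refine(theta, y0) - y_true) ** 2  # loss on the refined guess

theta, y0, y_true = 0.0, 5.0, 3.0
eps, lr = 1e-5, 0.05

# Outer training loop: move theta so that refinement lands on y_true.
# The finite-difference gradient differentiates THROUGH the unrolled
# inner loop; a real implementation uses autograd + Hessian-vector products.
for _ in range(100):
    g = (loss(theta + eps, y0, y_true) - loss(theta - eps, y0, y_true)) / (2 * eps)
    theta -= lr * g

print(round(refine(theta, y0), 3))  # refined guess now lands near y_true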
Now, what’s the big deal about EBMs and why would we like to have them as part of Transformers? Well, three things. First, EBMs can dynamically allocate more computation to harder problems because hard tokens to generate would need more steps to minimise the energy. nd, they can self-verify: the energy score itself tells the model whether its prediction is good or not. This can be useful for comparing the quality of different generations. third, they can model uncertainty — if the energy stays high, the model basically knows it’s unsure and the user might find this information useful. Now, back to energy-based transformers, or short EBTs. For making an energy-based LLM, the authors trained an autoregressive EBT from scratch on RedPajamaV2. As explained before, instead of outputting probabilities directly like an autoregressive LLM, the EBT took both the input sequence and a probability vector as input, and the transformer inside learned to assign low energy if the guess matched the real next token. During training, the guess probability vector was refined through gradient descent on the energy for a fixed number of steps, and after that, the model weights were updated so that refined probability guesses moved closer to the true target. So, unfortunately, in this paper, energy refinement is run for a fixed number of steps. That’s a bit of a missed opportunity, because the real promise of this approach lies in being adaptive: spending just a few steps on easy tokens and many more on harder ones. This flexibility would mirror human reasoning, where we go through simple problems quickly but take extra time to work through the difficult ones. The reason they chose fixed number of steps is because EBMs are not known for stable training behaviour and a fixed number of steps helps with stabilisation. To further stabilise EBT training, the authors had to use three well-established tricks. Let’s go through them one by one. First: adding noise to the refinement steps. 
If you always update the guess deterministically, you risk the model getting stuck in some local valley of the energy landscape. By adding a bit of random Gaussian noise at every gradient step — a technique called Langevin dynamics — the model is forced to explore slightly different directions and go into territories around the usually explored data. So, the noise encourages it to explore more of the surroundings of the training data instead of just overfitting. Instead, it learns to handle more diverse cases around the training data, which are cases that it might encounter at test time. Second: the authors use a replay buffer to stabilise training. Normally, every
Segment 3 (10:00 - 14:00)
refinement starts from a completely random guess. But if the model only ever sees that, it won’t learn how to handle partially optimized states — which is exactly what happens during inference when guesses are gradually refined. To fix this, they store past predictions and intermediate guesses in a buffer. During training, the model sometimes starts refinement from one of those saved states instead of pure noise. This exposes it to a richer variety of “in-progress” situations and helps it generalize better. Third: randomizing the step sizes. If every gradient descent update has the same fixed step size, the model could learn one very rigid optimization path. That’s brittle: if you change the number of steps or the difficulty of the problem, the model might fail. So instead, the authors randomize the step size and the number of steps. This forces the model to cope with different trajectories of optimization, making it robust to variations and better at generalizing. Together, these three techniques regularize the energy landscape: the noise keeps it from learning only narrow basins around training data that do not generalise at test time, the replay buffer ensures the valleys are well-shaped even mid-way down, and random step sizes prevent the model from memorizing one descent path. So what about the results? Honestly, the paper reports quite a confusing mix of different figures using different model sizes, training setups, and data budgets. But the overall rough picture is this: On validation perplexity, a vanilla transformer wins at first, but then the EBT catches up. If this trend continues for more than just 6 billion training tokens to something like trillions, it could look amazing for EBTs, but right now, we don’t know. And sure, right now EBTs are about ten times more expensive in FLOPs than a transformer at the same perplexity. 
But their scaling rate is, at least on this graph, slightly steeper: so, the more you train them, the more efficiently they improve. That means at trillion-token scale, they might actually surpass transformers both in quality and efficiency, but who knows; these lines are a fit over just 5 data points. Also, EBTs can benefit from more energy refinement steps before committing to the next tokens, while vanilla transformers always have a constant budget and performance per token. On GSM8K, on BigBench math and syntax tasks, EBTs results look good, though I am confused by the experimental settings of this paper and would like to see such results at larger data and model scales to be actually convinced. And EBTs aren’t just for text. The authors also tested autoregressive video models and bidirectional image transformers, and on these preliminary experiments, EBTs again appear to scale better and generalize better than diffusion transformers or standard ViTs. But as before, I’d need larger scale expriments to be convinced. So, all in all, EBTs could be a recipe for models that don’t just guess, but actually think — dynamically adjusting effort, self-checking predictions, and admitting uncertainty. In other words: closer to human-like System 2 thinking – that’s what the authors motivate the entire paper with and to me, it is a bit too philosophical to be honest, especially since they have no experiment with dynamic on the fly allocation of compute, which is a bummer. Of course, there are still challenges: training and inference are more expensive, hyperparameters need careful tuning, and large-scale experiments are still missing. The direction looks promising, but I am not yet fully convinced at this point from the experiments currently presented in the paper. What do you think of energy-based models? Did they convince you as them being the next do-it-all architecture replacing autoregressive transformers? Let me know in the comments. 
If you liked this paper breakdown, give the video a like and hit subscribe for more AI Coffee Breaks. See you next time. Okay, bye!
Ctrl+V
Экстракт Знаний в Telegram
Экстракты и дистилляты из лучших YouTube-каналов — сразу после публикации.