# Characterizing Test Time Compute on Graph Structur… | Kudzo Ahegbebu | OpenAI Scholars Demo Day 2021

## Metadata

- **Channel:** OpenAI
- **YouTube:** https://www.youtube.com/watch?v=8iz5v3Q0g9I
- **Date:** 2021-05-10
- **Duration:** 17:11
- **Views:** 7,345
- **Source:** https://ekstraktznaniy.ru/video/11581

## Description

Learn more: https://openai.com/blog/openai-scholars-2021-final-projects#kudzo

## Transcript

### Intro [0:00]

Hi, my talk is going to be about characterizing test time compute on graph-structured problems. Most of my Scholars project has been spent thinking about whether we can create models that continuously improve their outputs the more compute we give them at test time. This is something I'll call the test time compute dream, and I think there's a lot of anthropomorphic motivation here: after all, when we humans are being evaluated, our answers tend to get better the longer we're given to think. Machine learning models, for the most part, don't exhibit this ability, which seems a little weird.

I tend to bucket this test time compute work into two general categories. The first is generalization improvement mechanisms, which deal with the question of how we can create models that use test time compute to learn more general algorithms instead of simple statistical associations in the data. Ideally, we'd like these models to use the extra compute to resolve ambiguity and to correct and refine their own answers. The other side of this coin is efficiency, which deals with the question of how we can decouple the number of parameters a model has from the amount of time it takes to run the model at inference. The motivation is that if we can construct models that are larger but don't incur a larger computational cost for those extra parameters, then we would... Okay, so how did we tackle this question?

### Shortest Path Task [1:39]

The overwhelming majority of this project was actually spent on something I'll likely only spend a single slide on, in the interest of time, and that's the shortest path task. Shortest path is a sequence-to-sequence modeling task in which I give the model a pair of tokens representing two U.S. cities, and I expect it to output a sequence of target tokens representing the shortest path between them. The material I'll mostly be presenting only really took shape in the past three or four weeks, and it involved investigating some of these test time properties on graph neural networks operating over the game of Sudoku.
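To make the task format concrete, here is a minimal sketch of how one training example could be produced: the input is the (source, target) token pair and the target sequence is the shortest path. Using Dijkstra's algorithm to generate targets, the toy graph, its letter "city" names, and its distances are all assumptions for illustration; the talk does not say how the project's targets were computed.

```python
import heapq


def shortest_path(graph, source, target):
    """Dijkstra's algorithm; returns the list of nodes from source to target, or None."""
    dist = {source: 0.0}
    prev = {}
    heap = [(0.0, source)]
    visited = set()
    while heap:
        d, node = heapq.heappop(heap)
        if node in visited:
            continue
        visited.add(node)
        if node == target:
            break
        for neighbor, weight in graph.get(node, {}).items():
            nd = d + weight
            if nd < dist.get(neighbor, float("inf")):
                dist[neighbor] = nd
                prev[neighbor] = node
                heapq.heappush(heap, (nd, neighbor))
    if target not in dist:
        return None
    # Walk the predecessor links back from the target, then reverse.
    path = [target]
    while path[-1] != source:
        path.append(prev[path[-1]])
    return path[::-1]


def make_example(graph, source, target):
    """Input tokens: the city pair. Output tokens: the cities along the shortest path."""
    return [source, target], shortest_path(graph, source, target)


# Hypothetical undirected road graph with made-up distances.
edges = [("A", "B", 4), ("A", "C", 1), ("C", "B", 2), ("B", "D", 5), ("C", "D", 8)]
graph = {}
for u, v, w in edges:
    graph.setdefault(u, {})[v] = w
    graph.setdefault(v, {})[u] = w

print(make_example(graph, "A", "D"))
```

The model only ever sees the two endpoint tokens, so to succeed it has to internalize the graph and the search procedure itself.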

### Control Flop Budget [2:20]

Like I said, most of my project was spent on the shortest path work, in which we were trying to answer this question: if we control the total FLOP budget of our models, is there ever a point where the test time performance of models like the one on the left, which use this sort of top-layer recurrence, begins to approach or match the performance of models that don't have this recurrence but are larger or have been trained for longer? We did this by keeping the training compute budget, in terms of FLOPs, fixed for all the models, and then training these recurrent models with a fixed number of time steps, with the loss evaluated at every single time step. At test time, we evaluated them with more steps of recurrence to see whether the extra compute ever allows them to, in some sense, catch up to the larger models that were trained without this recurrence. Long story short, it largely doesn't seem to work. We never really see this sort of phase transition; recurrence alone doesn't seem to be enough.
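The training recipe described here (a weight-tied block unrolled for a fixed number of steps, a loss after every step, then more unrolled steps at test time) can be sketched with a toy scalar recurrence. The `tanh` update, the squared-error loss, the target, and the function names are hypothetical stand-ins, not the project's model; the point is only the shape of the objective.

```python
import math


def step(h, x, w):
    """One application of a weight-tied recurrent block (toy scalar update)."""
    return math.tanh(w * h + x)


def unrolled_losses(x, target, w, num_steps):
    """Apply the same block num_steps times, computing a loss after every step."""
    h = 0.0
    losses = []
    for _ in range(num_steps):
        h = step(h, x, w)
        losses.append((h - target) ** 2)  # per-step readout and squared error
    return losses


def training_objective(x, target, w, train_steps):
    """Average the per-step losses so every intermediate step is supervised."""
    losses = unrolled_losses(x, target, w, train_steps)
    return sum(losses) / len(losses)


x, w = 0.5, 0.5
# Toy target: the value this recurrence converges to (found by iterating long enough).
target = 0.0
for _ in range(200):
    target = step(target, x, w)

train_losses = unrolled_losses(x, target, w, num_steps=5)   # steps seen in training
test_losses = unrolled_losses(x, target, w, num_steps=20)   # extra test-time steps
print(training_objective(x, target, w, train_steps=5))
print(train_losses[-1], test_losses[-1])
```

In this toy case extra steps always help because the target is the recurrence's own fixed point; the finding reported in the talk is that on shortest path, recurrence alone did not deliver that kind of improvement.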

### Linear Probe [3:26]

To be clear, if you run a linear probe on the embedding space of these models, they actually seem to learn the locations of the cities, or at least something isometric to those locations, fairly quickly. That indicates the problem isn't learning where the cities are. Rather, even with the extra recurrence and the extra compute, learning a general shortest path algorithm seems to be difficult. Recurrence alone doesn't seem to be enough; we need additional structure on top of it, which is where the graph neural network work comes in.
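A linear probe is just a linear map (plus bias) fitted from frozen embeddings to some target, here city coordinates; if it fits well, the information is linearly decodable. The sketch below is plain Python with an ordinary least-squares fit. The toy locations and the synthetic "embeddings" (a rotated and shifted copy of the locations, i.e. an isometry) are assumptions for illustration. Since an isometry is affine, the probe recovers the locations exactly.

```python
import math


def solve(a, b):
    """Solve the square system a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                factor = m[r][col] / m[col][col]
                for c in range(col, n + 1):
                    m[r][c] -= factor * m[col][c]
    return [m[i][n] / m[i][i] for i in range(n)]


def fit_linear_probe(embeddings, targets):
    """Least squares via the normal equations, one output dimension at a time."""
    feats = [e + [1.0] for e in embeddings]  # append a constant feature for the bias
    d = len(feats[0])
    gram = [[sum(f[i] * f[j] for f in feats) for j in range(d)] for i in range(d)]
    weights = []
    for k in range(len(targets[0])):
        rhs = [sum(f[i] * t[k] for f, t in zip(feats, targets)) for i in range(d)]
        weights.append(solve(gram, rhs))
    return weights


def predict(weights, embedding):
    f = embedding + [1.0]
    return [sum(wi * fi for wi, fi in zip(w, f)) for w in weights]


angle = math.radians(30)
rot = [[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]]
shift = [5.0, -2.0]
locations = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 3.0], [-1.0, 2.0]]
# Toy "embeddings": an isometric (rotated + shifted) copy of the locations.
embeddings = [
    [rot[0][0] * x + rot[0][1] * y + shift[0], rot[1][0] * x + rot[1][1] * y + shift[1]]
    for x, y in locations
]

probe = fit_linear_probe(embeddings, locations)
errors = [
    max(abs(p - t) for p, t in zip(predict(probe, e), loc))
    for e, loc in zip(embeddings, locations)
]
print(max(errors))
```

With real model embeddings one would fit on some cities and measure error on held-out ones, so that a good fit is not just memorization.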

### Input Representation Phase [4:01]

Graph neural networks are networks that operate on graph-structured data, and they have a few main parts. The first is the input representation phase, where you pass in your graph-structured data. X represents the nodes in your graph, which contain the features you care about; these could be the locations of U.S. cities or the values of cells on a Sudoku board. A represents the adjacencies, which encode the edges of the graph, in other words, what relationships nodes have with each other. The GNN processes this graph by iteratively performing a learned message passing operation between the nodes, through which it refines its internal representation of those nodes. At the end of this refinement phase, we can run classification on the individual nodes or, if we aggregate the nodes, on the entire graph.
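Here is a toy sketch of those pieces: node features `X`, adjacency `A`, one message passing step, and the two kinds of readout (per node and whole graph). The sum aggregation, the fixed 0.5/0.5 mixing weights, the argmax node readout, and the mean-pool graph readout are simplifying assumptions; in a real GNN the update and readout functions are learned.

```python
def message_passing_step(X, A, self_weight=0.5, neighbor_weight=0.5):
    """New node state = weighted own state + weighted sum of neighbor states."""
    n, d = len(X), len(X[0])
    new_X = []
    for i in range(n):
        agg = [0.0] * d
        for j in range(n):
            if A[i][j]:  # j is a neighbor of i
                for k in range(d):
                    agg[k] += X[j][k]
        new_X.append([self_weight * X[i][k] + neighbor_weight * agg[k] for k in range(d)])
    return new_X


def node_readout(X):
    """Per-node 'classification': index of each node's largest feature."""
    return [max(range(len(x)), key=lambda k: x[k]) for x in X]


def graph_readout(X):
    """Whole-graph summary: mean-pool the node features."""
    n = len(X)
    return [sum(x[k] for x in X) / n for k in range(len(X[0]))]


# A 3-node path graph 0 - 1 - 2 with 2-dimensional node features.
X = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]

H = message_passing_step(X, A)
print(H)
print(node_readout(H), graph_readout(H))
```

After one step, node 1 (which started with no signal) has picked up information from both of its neighbors; running the step repeatedly spreads information further across the graph.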

### Graph refinement equation [4:57]

A key feature of these GNNs is the graph refinement equation, which I'll come back to at least twice in this presentation. It looks wild in its general form, but it really has just three parts: the hidden state for node i is updated by a function that takes in the embedding of that node together with all of the pairs formed by that node and each of its neighbors, passed through some function and then aggregated using your favorite aggregation function.
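Written out in one common notation (the symbols are our own choice, not copied from the slide), the update for node $i$ at refinement step $t$ is:

```latex
h_i^{(t+1)} = \phi\Big( h_i^{(t)},\ \bigoplus_{j \in \mathcal{N}(i)} \psi\big( h_i^{(t)}, h_j^{(t)} \big) \Big)
```

Here $\mathcal{N}(i)$ is the set of neighbors of node $i$ read off from $A$, $\psi$ is the learned message function applied to each (node, neighbor) pair, $\bigoplus$ is a permutation-invariant aggregation such as sum, mean, or max, and $\phi$ is the learned update function. Some formulations also pass the raw input features $x_i$ to $\phi$ at every step. Because $\phi$, $\psi$, and $\bigoplus$ are shared across steps, the whole update has the form $h^{(t+1)} = F(h^{(t)})$.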

### Graph refinement for sudoku [5:25]

Okay, so how do we do this for Sudoku? Well, every cell on the Sudoku board corresponds to a node in the graph. The nodes refine their representations by passing messages to themselves or their neighbors, using the graph refinement equation we just saw. What's typically done is that you run this graph refinement phase for a fixed number of steps, say 10, and then at the very end you apply a linear projection and make a prediction. What we do a little differently is make a prediction at every point along the graph refinement phase and evaluate the loss at every point as well. This makes the model more robust to being evaluated with more graph refinement iterations at test time than it saw during training.
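The Sudoku graph itself is easy to build: each of the 81 cells is a node, and two cells are adjacent when they share a row, a column, or a 3×3 box, which gives every cell exactly 20 neighbors. A minimal sketch, assuming cells are indexed 0 to 80 in row-major order and that self-messages are handled separately from the adjacency (both are our conventions, not necessarily the project's):

```python
def sudoku_neighbors():
    """Map each cell index (row-major, 0..80) to the set of cells it must differ from."""
    neighbors = []
    for cell in range(81):
        r, c = divmod(cell, 9)
        box_r, box_c = 3 * (r // 3), 3 * (c // 3)
        peers = set()
        for k in range(9):
            peers.add(r * 9 + k)  # same row
            peers.add(k * 9 + c)  # same column
        for dr in range(3):
            for dc in range(3):
                peers.add((box_r + dr) * 9 + (box_c + dc))  # same 3x3 box
        peers.discard(cell)  # the cell itself is excluded from its neighbor set
        neighbors.append(peers)
    return neighbors


def adjacency_matrix(neighbors):
    """Dense 0/1 adjacency matrix A built from the neighbor sets."""
    n = len(neighbors)
    return [[1 if j in neighbors[i] else 0 for j in range(n)] for i in range(n)]


nbrs = sudoku_neighbors()
A = adjacency_matrix(nbrs)
print(len(nbrs[0]), sum(map(sum, A)) // 2)  # neighbors per cell, undirected edge count
```

The 20 neighbors come from 8 in the same row, 8 in the same column, and 4 more in the same box that are in neither, for 81 × 20 / 2 = 810 undirected edges in total.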

### Graph refinement in practice [6:19]

So what does this look like in practice? Here's one solving Sudoku; this is real data, by the way. What's cool is that it appears to prioritize spending the extra compute on attending to and refining the tokens it assigned a low probability (high uncertainty) in previous time steps: the red cells become green, and the green cells stay green.

### Graph generalization [6:43]

This is cool because it's a sign that the test time compute dream is at least in principle possible. This plot shows the GNN operating over two datasets, one normal and one hard, and we see generalization in two different senses. First, as we increase the number of iterations, that is, the test time compute, the accuracy of the network improves in an almost monotonic way. Note that accuracy here is measured at the sequence level, which means I only count a board if the network gets the entire board correct. Second, if we give the network problems that are harder than the ones it was trained on, it still performs well.
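Sequence-level (board-level) accuracy is much stricter than per-cell accuracy, because a single wrong cell makes the whole board count as wrong. A small sketch of both metrics, assuming boards are encoded as 81-character digit strings (a hypothetical encoding); the example grid is a valid completed Sudoku built from a standard shifting pattern.

```python
def cell_accuracy(predictions, solutions):
    """Fraction of individual cells predicted correctly across all boards."""
    correct = total = 0
    for pred, sol in zip(predictions, solutions):
        correct += sum(p == s for p, s in zip(pred, sol))
        total += len(sol)
    return correct / total


def board_accuracy(predictions, solutions):
    """Fraction of boards where every single cell is correct (sequence-level metric)."""
    return sum(pred == sol for pred, sol in zip(predictions, solutions)) / len(solutions)


# A valid solved grid: row r is the digits 1..9 cyclically shifted by 3r + r // 3.
solution = "".join(str((r * 3 + r // 3 + c) % 9 + 1) for r in range(9) for c in range(9))
one_wrong = "9" + solution[1:]  # the first cell is "1" in this grid, so this is one error

preds = [solution, one_wrong]
sols = [solution, solution]
print(cell_accuracy(preds, sols), board_accuracy(preds, sols))
```

Here 161 of 162 cells are right (about 99.4% cell accuracy), yet only one of the two boards counts, so board accuracy is 50%.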

### Deep equilibrium models [7:24]

So if the argument is that more test time compute, more iterations, is good, what would happen if we could evaluate this model at infinite depth? In other words, could we do better? To answer this question, we need to borrow the machinery of deep equilibrium models. I don't have a lot of time to go into the details, but I suggest you check out the paper by Shaojie Bai or the NeurIPS workshop from this past year. The gist is that deep equilibrium models are inspired by the observation that we can often rewrite a standard neural network layer as an implicit function: instead of specifying explicitly how to compute the layer's output as a function of its input, we specify the conditions we would like the layer's output to satisfy. After rewriting these layers as implicit functions, it turns out that most of them converge to a fixed point. This means that instead of keeping track of the intermediate steps of the graph refinement phase in our autograd library, we can use an arbitrary black-box root-finding algorithm to find this convergence point. This is equivalent to running an infinite-depth, weight-tied feedforward network, but it has the notable advantage that we can analytically backpropagate through the equilibrium point using the implicit function theorem. How is this relevant to GNNs? Well, if you look at the graph refinement equation from earlier, it looks exactly like a fixed point equation, which means we can apply the machinery of deep equilibrium nets here.
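The core deep equilibrium idea can be seen in one dimension. The sketch below assumes a toy scalar layer f(z) = tanh(w·z + x) with |w| < 1, so the map is a contraction. The forward pass finds the equilibrium z* by plain fixed-point iteration (a stand-in for a black-box root finder). The backward pass uses the implicit function theorem: differentiating z* = f(z*, w) gives dz*/dw = (∂f/∂w) / (1 - ∂f/∂z), evaluated at z*, so no intermediate iterates need to be stored. The layer and all names are illustrative, and a finite-difference check confirms the gradient.

```python
import math


def f(z, w, x):
    """Toy weight-tied layer whose equilibrium we want."""
    return math.tanh(w * z + x)


def fixed_point(w, x, tol=1e-14, max_iter=10_000):
    """Forward pass: iterate z <- f(z) until the update is negligible."""
    z = 0.0
    for _ in range(max_iter):
        z_next = f(z, w, x)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    raise RuntimeError("fixed-point iteration did not converge")


def implicit_grad_w(w, x):
    """Backward pass via the implicit function theorem; only z* itself is needed."""
    z = fixed_point(w, x)
    sech2 = 1.0 - math.tanh(w * z + x) ** 2  # tanh'(u) evaluated at the equilibrium
    df_dz = w * sech2
    df_dw = z * sech2
    return df_dw / (1.0 - df_dz)


w, x = 0.5, 0.3
grad = implicit_grad_w(w, x)

# Sanity check with a central finite difference through the whole solver.
eps = 1e-6
finite_diff = (fixed_point(w + eps, x) - fixed_point(w - eps, x)) / (2 * eps)
print(grad, finite_diff)
```

The memory saving follows directly: the backward pass touches only z*, whereas backpropagating through the unrolled loop would require storing every iterate.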

### Deep equilibrium [9:04]

If you try this out, it actually works really well, with a big caveat related to early stopping that I'll get to on the next slide. These early training curves are preliminary but kind of dramatic: the deep equilibrium GNN trains a lot faster than the traditional GNN. Further, because the deep equilibrium machinery saves us from keeping track of the intermediate steps of the graph refinement phase in our autograd library, the deep equilibrium GNN also uses less memory than the regular one. Okay, so what's the caveat?

### Caveat [9:42]

Well, as far as I can tell, every single time I've run this, I've run into a weird collapse: it starts training, it's doing really well, and then it dies. I haven't quite been able to figure out why this happens. I suspect it has to do with the growth of the spectral norm of the operators inside the GNN as it's being evaluated by the fixed point iterator, but it could also just be a bug in my code. Stopping training early, once this degeneracy begins, has proven to be fine, and I'm still investigating the problem, but I wanted to point this out for completeness.
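One way to monitor the suspected cause is to track the spectral norm (largest singular value) of the weight matrix that the fixed-point iterator applies. For an update of the form z ← tanh(Wz + b), tanh is 1-Lipschitz, so a spectral norm of W below 1 makes the update a contraction, which guarantees a unique fixed point and convergent iteration; once it grows past 1, that guarantee is gone. A minimal power-iteration sketch in plain Python; the matrix, the bias, and the function names are hypothetical, and this checks only a sufficient condition, not a diagnosis of the collapse.

```python
import math


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def spectral_norm(m, num_iters=100):
    """Estimate the largest singular value of m by power iteration on m^T m."""
    mt = transpose(m)
    v = [1.0] * len(m[0])
    for _ in range(num_iters):
        u = matvec(mt, matvec(m, v))
        norm = math.sqrt(sum(x * x for x in u))
        v = [x / norm for x in u]
    mv = matvec(m, v)
    return math.sqrt(sum(x * x for x in mv))


def fixed_point_residual(m, b, num_iters=200):
    """Iterate z <- tanh(m z + b) and return the size of the last update."""
    z = [0.0] * len(b)
    last_step = float("inf")
    for _ in range(num_iters):
        z_next = [math.tanh(s + bi) for s, bi in zip(matvec(m, z), b)]
        last_step = max(abs(a - c) for a, c in zip(z_next, z))
        z = z_next
    return last_step


W = [[0.6, 0.2], [0.1, 0.5]]  # hypothetical weight matrix
b = [0.3, -0.1]               # hypothetical bias / input injection
sigma = spectral_norm(W)
print(sigma, "contraction guaranteed" if sigma < 1 else "no contraction guarantee")
print(fixed_point_residual(W, b))
```

Logging `sigma` during training would show whether the collapse coincides with the norm crossing 1.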

### Network Structure [10:22]

Okay, shifting gears a little: can we do better still, in another way? GNNs are fine; as we've seen, they seem to do well on these relational reasoning tasks. But one potential oddity is that we must be explicit about the structure of the graph, that is, we must explicitly tell the network which nodes are connected to which other nodes. For Sudoku, for instance, we must explicitly say that cells in the same row, the same column, or the same 3×3 box are connected. Could we instead learn the adjacencies from scratch, from raw unstructured data?

Here's the idea. Transformers seem to be pretty good at learning how relevant pairs of tokens are to each other; GNNs, on the other hand, are good at operating over structured data. What if we used the attention head of a standard transformer to extract an adjacency matrix, which we then feed into the GNN? Here's how it works. We first feed our input to a small transformer, with one modification: at the top layer, we use the attention probability scores to categorically sample the top k indices most relevant to each token. That extracts a k-neighborhood for each token, which we can then feed into our GNN. Sampling indices is a non-differentiable operation, but we can compensate for this using the surrogate loss outlined on the slide, which is taken from a paper by John Schulman and provides a general framework for gradient estimation through stochastic computation graphs. The formalism gives us a way to convert stochastic computation graphs into deterministic ones and evaluate a surrogate loss with standard backpropagation, which yields an unbiased estimator of the gradient through the stochastic node.

So if you try this out, it works, kind of. The reality is that it trains much slower than the standard GNN (vanilla policy gradients are high variance and kind of messy), and the performance is actually worse than the standard GNN. But it does show that, in principle, we could train a GNN from scratch that learns the adjacencies as well, which is kind of cool.
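Both ingredients can be sketched in plain Python. `sample_k_neighbors` draws k distinct indices from attention-style probabilities by sequential sampling without replacement, which is one reasonable reading of "categorically sample the top k". `score_function_grad` is the score-function (REINFORCE) estimator that the surrogate-loss formalism reduces to for a single categorical node: the downstream loss is treated as a constant and multiplied by the gradient of log p(a), which for a softmax over logits is onehot(a) - p. The logits, the toy loss, and the function names are assumptions; in the actual model the loss would come from the GNN. `exact_grad` gives the closed form for comparison.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_k_neighbors(probs, k, rng):
    """Sample k distinct indices, each draw proportional to the remaining probabilities."""
    remaining = list(range(len(probs)))
    chosen = []
    for _ in range(k):
        idx = rng.choices(remaining, weights=[probs[i] for i in remaining], k=1)[0]
        chosen.append(idx)
        remaining.remove(idx)
    return chosen


def score_function_grad(logits, loss_fn, num_samples, rng):
    """Monte Carlo estimate of d E_{a ~ softmax(logits)}[loss(a)] / d logits."""
    probs = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for _ in range(num_samples):
        a = rng.choices(range(n), weights=probs, k=1)[0]
        loss = loss_fn(a)  # treated as a constant (stop-gradient) in the surrogate
        for i in range(n):
            grad[i] += loss * ((1.0 if i == a else 0.0) - probs[i])
    return [g / num_samples for g in grad]


def exact_grad(logits, loss_fn):
    """Closed form for comparison: p_i * (loss(i) - E[loss])."""
    probs = softmax(logits)
    mean_loss = sum(p * loss_fn(i) for i, p in enumerate(probs))
    return [p * (loss_fn(i) - mean_loss) for i, p in enumerate(probs)]


def toy_loss(a):
    """Hypothetical downstream loss for choosing neighbor a."""
    return float(a)


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0]  # e.g. one token's attention scores over 4 other tokens
print(sample_k_neighbors(softmax(logits), k=2, rng=rng))
est = score_function_grad(logits, toy_loss, num_samples=100_000, rng=rng)
print(est)
print(exact_grad(logits, toy_loss))
```

The estimate matches the exact gradient only on average, and it needs many samples to get close, which is the high variance mentioned in the talk.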

### Conclusion [12:50]

So, in conclusion: I think test time compute mechanisms are largely underexplored but hold a lot of promise. They have the potential for improved generalization and improved sample efficiency, and I think recurrence plus message passing is a really interesting combination. If the methods in this presentation seem contrived, that's because they are. But while the specific methods are kind of crude, I'm bullish on the idea of test time compute in general, and I think the next few years will see critical breakthroughs built on ideas that have test time compute at their core.

### Questions [13:32]

That's it. I'd like to thank my mentor, Will Guss, as well as the program organizers, my cohort, and all the people who gave me early feedback on some of this work. Thank you, and now I'll take questions. Let me stop sharing.

Okay, the first question is: how would I extend this GNN setting to sequence modeling, like the language modeling loss in GPT? You could do this in a couple of ways. One is an autoregressive setup: at each point you evaluate the state of the entire graph, output a sequence, feed that output sequence back into the beginning of the model, and run it again. Doing this autoregressively is one way, but I'm sure there are others I'm just not familiar with.

What types of problems do I expect test time compute to really shine in? This is a great question (sorry, my dog gets really excited here). I think test time compute will ultimately shine in problems with this relational reasoning character, where we need to relate our previous outputs to things we're currently processing, or problems where we need to condition the amount of compute we do on the complexity of our inputs.

Does the stochastic computation graph mean that GNNs can be applied to settings without inductive biases? That's the ultimate hope. This is very crude, early work showing that you could potentially learn the adjacencies without hand-baking the inductive bias. That said, part of the appeal of graph neural networks is that it's so easy to bake in inductive biases: you just feed in the graph as it is, and that is the inductive bias for your data. So there's definitely a trade-off here, and it's not super clear that doing this is always the right thing to do.

Okay, last question: what does it look like if you threshold the learned adjacency weights to produce a discrete graph structure? Right, that's a good point: these are already discrete. That's the whole point of the sampling; the adjacencies are, for each token, the indices of the other tokens that are near it. So this isn't like attention, where we take a softmax over the other tokens and use the soft weights directly; we use the transformer's attention scores as the weights for our discrete sampling, if that makes sense. Cool, I think I'm over time, so I'll hand it back over to Francis.
