Eh. There is nothing diffusion about this. Nothing to do with denoising. This setup is still purely causal, which makes it quite a dishonest framing IMO. There is no more introspection here than in the usual MTP + SD (multi-token prediction + speculative decoding) setups.

Let me explain what is going on here. This is basically a form of multi-token prediction during training, plus speculative decoding at inference. See my earlier post [1] to understand what that is. TL;DR: in multi-token prediction you train separate LM heads to predict the next token, the next-to-next token, and so on, up to a chosen k-th next token. Training multiple LM heads is expensive and can be unnecessary, so what people typically do is have a common base for all k heads, explained further in [1]. These guys do another variant.

Here is what they do mechanically, given a sequence p of five tokens, PE([p1, p2, p3, p4, p5]), where PE(.) adds relative position info to each token.

1. Create an augmented sequence PE([p1 MASK MASK MASK MASK]) and do a training pass on it with the ground truth sequence p1..p5. Here it is trained, for example, to predict p4 given p1+pos=-2, MASK+pos=-1, MASK+pos=0, loosely notating.

2. Then, separately [2], train it *as usual* on PE([p1 p2 p3 p4 p5]).

Step (1) teaches it to do multi-token prediction: essentially, the single LM head will (very very loosely speaking) condition on the position k of the special MASK token and "route" it to the "implicit" k-th LM head.

Step (2) teaches it to be a usual LLM and predict the next token. No MASK tokens involved.

So far, you have trained a multi-token predictor.
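To make this concrete, here is a toy sketch of the two objectives. All the names (`model`, `MASK_ID`, `two_losses`) are mine, not the paper's, and the stub model is only there so the snippet runs:

    import torch
    import torch.nn.functional as F

    MASK_ID, V = 50257, 50258                       # hypothetical ids
    model = lambda ids: torch.randn(*ids.shape, V)  # stub causal LM: ids [B, T] -> logits [B, T, V]

    def two_losses(seq):  # seq: [B, T] = [p1 .. pT]
        # Step 1: multi-token objective on [p1, MASK, ..., MASK].
        # The logits at the j-th MASK are trained against p_{j+2},
        # so the single head acts like an implicit (j+1)-ahead head,
        # keyed off the MASK's position.
        masked = seq.clone()
        masked[:, 1:] = MASK_ID
        mtp = F.cross_entropy(model(masked)[:, :-1].flatten(0, 1),
                              seq[:, 1:].flatten())
        # Step 2: the usual next-token objective on the clean sequence.
        ntp = F.cross_entropy(model(seq)[:, :-1].flatten(0, 1),
                              seq[:, 1:].flatten())
        # Two separate loss values, as footnote [2] says; the two
        # passes can be fused into one with an attention mask.
        return mtp, ntp

    print(two_losses(torch.tensor([[11, 12, 13, 14, 15]])))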
Now, during inference, you use this for speculative decoding. You generate 5 tokens ahead at once with MASK tokens, and then you run that sequence through the LLM again. This has the same benefits as usual speculative decoding, namely that you get to do matrix-matrix multiplication as opposed to matrix-vector; the former is more memory-bandwidth efficient due to higher arithmetic intensity.

Here is an example:

    query = ["what", "is", "2+2"]
    prompt = PE(query + [MASK] * 5)
    output = LLM(prompt)

Say output is ["what", "is", "2+2", "it", "is", "4"]. Note that the NN is trained to predict the k-th next token when faced with positionally encoded MASK tokens, so you get all 5 in one go. To be precise, it learns to predict "4" given ["what", "is", "2+2", MASK, MASK]. Since it does not need the "it" and the "is" explicitly, this can happen in parallel with generating them. "is", for example, is predicted given ["what", "is", "2+2", MASK]; that also doesn't depend on the explicit "it" being there, and so can be done in parallel with generating "it", which is just normal next-token prediction given the query. You then use this as a draft in your speculative decoding setup.
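Here is the whole draft-then-verify dance as a minimal greedy sketch. Again, `model`, `mask_id` and `spec_step` are my stand-ins (same conventions as the sketch above), and I am assuming greedy decoding; the paper's actual acceptance rule may differ:

    import torch

    def spec_step(model, prefix, mask_id, k=4):
        # Draft: one pass over [prefix, MASK * k]. The logits at the
        # last prefix slot and at each MASK slot yield the next k+1
        # tokens in one go (matrix-matrix work, not k+1 matrix-vector
        # passes).
        inp = torch.cat([prefix, torch.full((1, k), mask_id)], dim=1)
        draft = model(inp)[0, prefix.size(1) - 1:].argmax(-1)  # k+1 tokens

        # Verify: one ordinary causal pass over [prefix, draft].
        cand = torch.cat([prefix, draft.unsqueeze(0)], dim=1)
        check = model(cand)[0, prefix.size(1) - 1:-1].argmax(-1)

        # Accept the longest prefix of the draft that the clean pass
        # agrees with; at the first mismatch, the verify pass already
        # gives us the corrected token for free.
        n = int((draft == check).long().cumprod(0).sum())
        accepted = draft if n == draft.numel() else check[:n + 1]
        return torch.cat([prefix, accepted.unsqueeze(0)], dim=1)

So per round you pay two forward passes and bank between 1 and k+1 tokens, which is where the throughput win comes from.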
Their claim is that using a multi-token predictor this way as a draft model works really well. To be clear, this is still causal; the reason diffusion models have hype is that they are capable of global refinement, and this is not. In the same thread as [1], I explain how increasing the number of MASK tokens, i.e. increasing k, the number of tokens you predict at once, quickly leads to poor quality. This paper agrees with that: they try k=2,3,4,8 and already see a drop in quality at 8. So finally, this is 4-token prediction with self-speculative decoding, removing seemingly no existing limitation of such setups.

[1] https://news.ycombinator.com/item?id=45221692

[2] Note that it is computationally a single forward pass: attention masks let you fuse steps 1 and 2 into a single operation. However, you still have 2 separate loss values.
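For the curious, the fusion in [2] boils down to a block-diagonal attention mask, something like the below (my reconstruction, assuming the simplest version with one clean copy and one masked copy per sequence; the real implementation may interleave several mask groups):

    import torch

    T = 5
    # Fused input:  [p1 p2 p3 p4 p5 | p1 M  M  M  M ]
    # position ids: [0  1  2  3  4  | 0  1  2  3  4 ]
    pos = torch.cat([torch.arange(T), torch.arange(T)])

    # Each copy is causal within itself; the masked copy must not see
    # the clean copy, or it could just read off the tokens it is
    # supposed to predict.
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    allowed = torch.zeros(2 * T, 2 * T, dtype=torch.bool)
    allowed[:T, :T] = causal   # step 2: usual next-token loss
    allowed[T:, T:] = causal   # step 1: multi-token (MASK) loss

One forward pass over the 2T tokens, a next-token loss on the first block and the multi-token loss on the second: two losses, one op.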