Arrows of Time for Large Language Models
Vassilis Papadopoulos, Jérémie Wenger, Clément Hongler
Large Language Models (LLMs) typically model the probability of observing a token given the previous tokens, generating text autoregressively. The product of these conditional probabilities is the joint probability of the entire token sequence.
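In symbols, for a token sequence $x_1, \dots, x_T$ this left-to-right factorisation reads

$$p(x_1, \dots, x_T) \;=\; \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1}).$$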
The authors of the paper observe that this distribution can equally well be learned in a right-to-left fashion, by predicting the previous token instead of the next one. This raises the question: does modelling left-to-right yield a better estimate of the joint distribution than modelling right-to-left, and if so, why?
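By the chain rule, the same joint distribution also admits a right-to-left factorisation,

$$p(x_1, \dots, x_T) \;=\; \prod_{t=1}^{T} p(x_t \mid x_{t+1}, \dots, x_T),$$

so the two directions are equally expressive in principle; any consistent performance gap must come from how easily each factorisation can be learned.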
The authors speak of a “forward arrow of time” when the left-to-right model consistently outperforms the right-to-left model. Their main result is the identification of a consistent forward arrow of time in language models across various architectures, model sizes and languages, with the performance gap widening as model size and capability increase.
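As a rough illustration of how such a comparison can be framed (this is a minimal sketch, not the authors' code), the snippet below assumes two causal language models of identical architecture: a hypothetical `fw_model` trained on text in its original token order and a `bw_model` trained on the same text with every sequence reversed, each mapping a `(batch, seq_len)` tensor of token ids to `(batch, seq_len, vocab)` logits.

```python
# Sketch only: assumes `fw_model` (trained left-to-right) and `bw_model`
# (trained on reversed sequences) with the interface described above.
import torch
import torch.nn.functional as F

def mean_cross_entropy(model, token_ids: torch.Tensor) -> float:
    """Average next-token cross-entropy (in nats) of `model` on `token_ids`."""
    with torch.no_grad():
        logits = model(token_ids[:, :-1])          # predict token t from tokens < t
        targets = token_ids[:, 1:]
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
        ).item()

def arrow_of_time_gap(fw_model, bw_model, token_ids: torch.Tensor) -> float:
    """Backward minus forward loss; a positive gap is a 'forward arrow of time'."""
    fw_loss = mean_cross_entropy(fw_model, token_ids)
    bw_loss = mean_cross_entropy(bw_model, token_ids.flip(dims=[1]))  # same text, reversed
    return bw_loss - fw_loss
```

Perplexity is simply the exponential of this average cross-entropy, so a positive loss gap corresponds directly to higher perplexity in the backward direction.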
They argue that natural language inherently possesses a certain degree of sparsity, and that this sparsity is asymmetric. In the forward direction it is less pronounced, because the natural progression of language is more predictable and structured; this makes it easier for language models to predict the next token, yielding better performance (lower perplexity). In the backward direction, predicting previous tokens runs into sparser, less predictable structure, resulting in higher perplexity.
To illustrate this, the authors construct artificial datasets where one direction is easier to model. For instance, they create a dataset of random prime numbers followed by their product. Modelling this sequence in reverse, starting with the product, would require the model to perform prime factorization, a computationally complex task. This example demonstrates a strong forward arrow of time, emphasising the computational complexity aspect.
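A minimal sketch of how such a synthetic dataset might be generated is shown below; the exact format used in the paper may differ, but the asymmetry is the same: completing a line left to right only requires a multiplication, while predicting it right to left from the product amounts to integer factorization.

```python
# Sketch only: each example lists two random primes followed by their product.
import random

def is_prime(n: int) -> bool:
    """Deterministic trial-division primality test, adequate for small n."""
    if n < 2:
        return False
    for d in range(2, int(n ** 0.5) + 1):
        if n % d == 0:
            return False
    return True

def random_prime(lo: int, hi: int) -> int:
    """Rejection-sample a random prime from [lo, hi)."""
    while True:
        candidate = random.randrange(lo, hi)
        if is_prime(candidate):
            return candidate

def make_example(lo: int = 1000, hi: int = 10000) -> str:
    p, q = random_prime(lo, hi), random_prime(lo, hi)
    return f"{p} * {q} = {p * q}"

if __name__ == "__main__":
    # e.g. "7919 * 1009 = 7990271": easy to continue forward, hard to invert.
    for _ in range(5):
        print(make_example())
```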
The authors hope these insights can inform improved training procedures that explicitly account for the arrow of time present in the dataset. This could lead to more efficient and effective language models.