RecT: A Recursive Transformer Architecture for Generalizable Mathematical Reasoning

In collaboration with Rohan Deshpande and Jerry Chen.

This project began in Stanford University’s CS 224N (Natural Language Processing with Deep Learning) class in Winter 2021. It was accepted to and presented at the Neural-Symbolic Learning and Reasoning (NeSy) Workshop at the International Joint Conference on Learning & Reasoning (IJCLR) in October 2021.

Read the full paper here.

Abstract

In recent years, there has been increasing interest in whether neural models can learn mathematical reasoning. Previous approaches have attempted to train models to derive the final answer directly from the question in a single step. However, they have typically failed to generalize to problems more complex than those in the training set.

In this paper, we posit that these failures can be circumvented by introducing a strongly supervised recursive framework to the traditional transformer architecture. Rather than having the model output the answer in a single shot, we reduce each problem to a sequence of intermediate steps that are teacher-forced during training. During inference, the autoregressive model recursively generates each intermediate step until it arrives at the final solution.
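To make the recursive decomposition concrete, here is a minimal sketch of the inference loop for parenthesized addition and subtraction. The single-step rewriter below is a hand-written stand-in for the trained transformer that actually produces each intermediate step in RecT; the function names (`reduce_one_step`, `eval_flat`, `solve_recursively`) and the space-separated expression format are illustrative assumptions, not the paper's code.

```python
import re


def eval_flat(expr: str) -> int:
    """Evaluate a parenthesis-free, space-separated chain of + and - left to right."""
    tokens = expr.split()
    result = int(tokens[0])
    for op, num in zip(tokens[1::2], tokens[2::2]):
        result = result + int(num) if op == "+" else result - int(num)
    return result


def reduce_one_step(expr: str) -> str:
    """One intermediate step: collapse the innermost parenthesized sub-expression
    to its value. In RecT, this rewrite would be generated by the transformer;
    here a rule-based stand-in illustrates the intermediate-step format."""
    inner = re.search(r"\(([^()]+)\)", expr)
    if inner is None:
        return str(eval_flat(expr))  # no parentheses left: produce the final answer
    value = eval_flat(inner.group(1))
    return expr[:inner.start()] + str(value) + expr[inner.end():]


def solve_recursively(expr: str, max_steps: int = 20) -> list[str]:
    """Recursively apply single-step reductions until the expression stops changing."""
    steps = [expr]
    for _ in range(max_steps):
        nxt = reduce_one_step(expr)
        if nxt == expr:
            return steps
        steps.append(nxt)
        expr = nxt
    raise RuntimeError("did not converge within max_steps")


print(" -> ".join(solve_recursively("(3 - (4 + 5)) + 2")))
# (3 - (4 + 5)) + 2 -> (3 - 9) + 2 -> -6 + 2 -> -4
```

Each printed rewrite corresponds to one intermediate step that the model is supervised on during training and generates autoregressively, one at a time, during inference.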

We validate our method by training models to solve a popular mathematical reasoning task: complex addition and subtraction with parentheses. Our model not only attains near-perfect accuracy on problems of difficulty similar to the training set but also showcases generalization: while current state-of-the-art neural architectures completely fail to extrapolate to more complex arithmetic problems, our model achieves 66.26% accuracy.