Using fewer bits of precision to train machine learning models limits training accuracy—or does it? This post describes cases in which we can get high-accuracy solutions from low-precision computation via a technique called bit centering, and our theory that explains what's going on.
Low-precision computation has been gaining a lot of traction in machine learning. Companies have even started developing new hardware architectures, such as Microsoft's Project Brainwave and Google's TPU, that natively support and accelerate low-precision operations. Even though using low precision can have a lot of systems benefits, low-precision methods have been used primarily for inference—not for training. Previous low-precision training algorithms suffered from a fundamental tradeoff: when calculations use fewer bits, more round-off error is added, which limits training accuracy. According to conventional wisdom, this tradeoff limits practitioners' ability to deploy low-precision training algorithms in their systems. But is this tradeoff really fundamental? Is it possible to design algorithms that use low precision without it limiting their accuracy?
It turns out that yes, it is sometimes possible to get high-accuracy solutions from low-precision training—and here we'll describe a new variant of stochastic gradient descent (SGD) called high-accuracy low-precision (HALP) that can do it. HALP can do better than previous algorithms because it reduces the two sources of noise that limit the accuracy of low-precision SGD: gradient variance and round-off error.
When we are running SGD, in some sense what we are actually doing is averaging (or summing up) a bunch of gradient samples. The key idea behind bit centering is that as the gradients become smaller, we can average them with less error using the same number of bits. To see why, think about averaging a bunch of numbers in \([-100, 100]\) and compare this to averaging a bunch of numbers in \([-1, 1]\). In the former case, we'd need to choose a fixed-point representation that can cover the entire range \([-100, 100]\) (for example, \( \{ -128, -127, \ldots, 126, 127 \} \)), while in the latter case, we can choose one that covers only \([-1, 1]\) (for example, \( \{ -\frac{128}{127}, -\frac{127}{127}, \ldots, \frac{126}{127}, \frac{127}{127} \} \)). This means that with a fixed number of bits, the delta, the difference between adjacent representable numbers, is smaller in the latter case than in the former; as a consequence, the round-off error will also be lower.
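To make this concrete, here is a minimal NumPy sketch (not from the paper; the `quantize` helper and the 8-bit format are illustrative assumptions) that averages samples from each range using a fixed-point representation sized to that range and compares the round-off error of the two averages:

```python
import numpy as np

def quantize(x, scale, bits=8):
    """Round x to the nearest value of a signed fixed-point format covering [-scale, scale]."""
    delta = scale / (2 ** (bits - 1) - 1)            # gap between adjacent representable numbers
    levels = np.clip(np.round(x / delta), -2 ** (bits - 1), 2 ** (bits - 1) - 1)
    return levels * delta

rng = np.random.default_rng(0)
wide = rng.uniform(-100, 100, size=10_000)           # numbers in [-100, 100]
narrow = rng.uniform(-1, 1, size=10_000)             # numbers in [-1, 1]

# Round-off error of the averages: the wide-range format has a delta 100x larger,
# so its quantized average is roughly 100x further from the true average.
err_wide = abs(quantize(wide, scale=100).mean() - wide.mean())
err_narrow = abs(quantize(narrow, scale=1).mean() - narrow.mean())
print(err_wide, err_narrow)
```

With eight bits, the wide-range delta is about \( 100/127 \approx 0.79 \) while the narrow-range delta is about \( 1/127 \approx 0.0079 \), so the achievable absolute accuracy differs by roughly two orders of magnitude.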
This example points to a key insight: to average the numbers in \([-1, 1]\) with less error than the ones in \([-100, 100]\), we needed to use a different fixed-point representation. This suggests that we should dynamically update the low-precision representation as the algorithm runs: as the gradients get smaller, we should switch to fixed-point numbers that have a smaller delta and cover a smaller range.
But how do we know how to update our representation? What range do we need to cover? Well, if our objective is strongly convex with parameter \( \mu \), then whenever we take a full gradient at some point \( w \), we can bound the distance to the optimum \( w^* \) with \[ \| w - w^* \| \le \frac{1}{\mu} \| \nabla f(w) \|. \] This inequality gives us a range of values in which the solution must lie, so whenever we compute a full gradient, we can re-center and re-scale the low-precision representation to cover this range. This process is illustrated in the following figure.
We call this operation bit centering. Note that even if our objective is not strongly convex, we can still perform bit centering: the parameter \( \mu \) simply becomes a hyperparameter of the algorithm. With periodic bit centering, as the algorithm converges, the quantization error decreases—and it turns out that this can let it converge to arbitrarily accurate solutions.
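As a rough sketch of what the re-centering step might look like in code (the helper names `recenter`, `to_low_precision`, and `to_full_precision` are ours, not the paper's), each full gradient determines a new center and spacing for the fixed-point format, and the iterate is then stored as an integer offset from that center:

```python
import numpy as np

def recenter(w, full_grad, mu, bits=8):
    """After a full gradient at w, cover the ball of radius ||full_grad|| / mu around w."""
    scale = np.linalg.norm(full_grad) / mu           # radius that must contain the optimum
    delta = scale / (2 ** (bits - 1) - 1)            # spacing of representable offsets
    return w.copy(), delta                           # (center, delta) define the new format

def to_low_precision(w, center, delta, bits=8):
    """Store w as an integer offset from the center of the current format."""
    q = np.round((w - center) / delta)
    return np.clip(q, -2 ** (bits - 1), 2 ** (bits - 1) - 1).astype(np.int32)  # int8 on real 8-bit hardware

def to_full_precision(q, center, delta):
    """Recover the (quantized) full-precision value of the iterate."""
    return center + q * delta
```

The point is that `delta` shrinks in proportion to the gradient norm, so the grid of representable values becomes finer exactly as the algorithm approaches the optimum.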
HALP is our algorithm that runs SVRG and uses the full gradient computed at each epoch to perform bit centering, updating the low-precision representation. The full details and algorithm statement are in the paper; here, we'll just present an overview of those results, along with a simplified sketch of the loop in code below. First, we showed that for strongly convex, Lipschitz-smooth functions (the standard setting under which the convergence rate of SVRG was originally analyzed), as long as the number of bits \( b \) we use satisfies \[ 2^b > O\left(\kappa \sqrt{d} \right), \] where \( \kappa \) is the condition number of the problem and \( d \) is its dimension, then for an appropriate setting of the step size and epoch length (details for how to set these are in the paper), HALP will converge at a linear rate to arbitrarily accurate solutions. More explicitly, for some \( 0 < \gamma < 1 \), \[ \mathbf{E}\left[ f(\tilde w_{K+1}) - f(w^*) \right] \le \gamma^K \left( f(\tilde w_1) - f(w^*) \right), \] where \( \tilde w_{K+1} \) denotes the value of the iterate after the \( K \)-th epoch. We can see this happening in the following figure.
This figure evaluates HALP on linear regression on a synthetic dataset with 100 features and 1000 examples, comparing it with full-precision SGD and SVRG, low-precision SGD (LP-SGD), and a low-precision version of SVRG without bit centering (LP-SVRG). Notice that HALP converges to very high-accuracy solutions even with only 8 bits (although it is eventually limited by floating-point error). In this case, HALP converges to an even higher-accuracy solution than full-precision SVRG because HALP uses less floating-point arithmetic and is therefore less sensitive to floating-point inaccuracy.
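To make the structure of the HALP loop described above concrete, here is a simplified, NumPy-only sketch for intuition. The exact step-size and epoch-length settings, as well as the real low-precision arithmetic, are in the paper; the function signature (`grad_fn`, `full_grad_fn`, and so on) and the simulated integer storage are our own assumptions.

```python
import numpy as np

def halp(grad_fn, full_grad_fn, w0, mu, lr, n_samples,
         epochs=50, iters_per_epoch=1000, bits=8, seed=0):
    """SVRG whose inner iterate is stored in a re-centered fixed-point format."""
    rng = np.random.default_rng(seed)
    w_tilde = np.asarray(w0, dtype=np.float64)
    for _ in range(epochs):
        g_tilde = full_grad_fn(w_tilde)                        # full gradient, once per epoch
        # Bit centering: the optimum lies within ||g_tilde|| / mu of w_tilde,
        # so size the fixed-point format to cover exactly that range.
        delta = (np.linalg.norm(g_tilde) / mu) / (2 ** (bits - 1) - 1)
        z = np.zeros_like(w_tilde, dtype=np.int64)             # low-precision offset from w_tilde
        for _ in range(iters_per_epoch):
            i = rng.integers(n_samples)                        # pick a random example
            w = w_tilde + z * delta                            # dequantize the current iterate
            v = grad_fn(w, i) - grad_fn(w_tilde, i) + g_tilde  # SVRG variance-reduced gradient
            z = np.clip(np.round(z - lr * v / delta),          # take the step in offset space
                        -2 ** (bits - 1), 2 ** (bits - 1) - 1).astype(np.int64)
        w_tilde = w_tilde + z * delta                          # re-center at the last iterate
    return w_tilde
```

For a least-squares objective like the one in the figure, `grad_fn(w, i)` would return \( (x_i^\top w - y_i) x_i \) and `full_grad_fn` the average of those gradients over all examples.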
This was only a selection of results: there's a lot more in the paper.