KOALA: A Kalman Optimization Algorithm with Loss Adaptivity

Aram Davtyan, Sepehr Sameni, Llukman Cerkezi, Givi Meishvili, Adam Bielski, Paolo Favaro. In AAAI, 2022.

Project repository on GitHub: Araachie/koala

Paper:
[arXiv]

Code:
[GitHub]

Abstract

Optimization is often cast as a deterministic problem, where the solution is found through some iterative procedure such as gradient descent. However, when training neural networks, the loss function changes over (iteration) time due to the randomized selection of a subset of the samples. This randomization turns the optimization problem into a stochastic one. We propose to consider the loss as a noisy observation with respect to some reference optimum. This interpretation of the loss allows us to adopt Kalman filtering as an optimizer, as its recursive formulation is designed to estimate unknown parameters from noisy measurements. Moreover, we show that the Kalman filter dynamical model for the evolution of the unknown parameters can be used to capture the gradient dynamics of advanced methods such as Momentum and Adam. We call this stochastic optimization method KOALA, which is short for Kalman Optimization Algorithm with Loss Adaptivity. KOALA is an easy-to-implement, scalable, and efficient method to train neural networks. We provide a convergence analysis and show experimentally that it yields parameter estimates that are on par with or better than those of existing state-of-the-art optimization algorithms across several neural network architectures and machine learning tasks, such as computer vision and language modeling.

Method

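In machine learning, given a dataset $\mathcal{D} = \{\xi_i\}_{i=1}^{m}$ and a loss function $\ell(\xi, x)$, we are interested in minimizing the empirical risk $L(x) = \frac{1}{m}\sum_{i=1}^{m} \ell(\xi_i, x)$ with respect to the network parameters $x \in \mathbb{R}^n$, i.e., we want to find $\hat{x}$ such that

$$\hat{x} = \arg\min_{x \in \mathbb{R}^n} L(x).$$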

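Because datasets are large, SGD-like algorithms instead use minibatch risks

$$L_{C_k}(x) = \frac{1}{|C_k|}\sum_{i \in C_k} \ell(\xi_i, x),$$

where $C_k \subset \{1, \dots, m\}$ is the set of sample indices in the minibatch drawn at iteration $k$.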
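To make the two risks concrete, here is a minimal sketch on a toy linear-regression problem. The data, the squared loss, and the helper names `per_sample_loss`, `empirical_risk` and `minibatch_risk` are hypothetical choices for illustration only. The sketch checks one property that follows directly from the definitions: when an epoch is split into equal-size disjoint minibatches, the average of the minibatch risks equals the empirical risk, even though individual minibatch risks generally differ from it.

```python
import numpy as np

# Hypothetical toy problem: linear regression with squared loss
# l(xi_i, x) = 0.5 * (a_i . x - b_i)^2, where xi_i = (a_i, b_i).
rng = np.random.default_rng(0)
m, n = 1000, 5
A = rng.normal(size=(m, n))
b = A @ rng.normal(size=n) + 0.1 * rng.normal(size=m)


def per_sample_loss(x, idx):
    """Squared losses of the samples whose indices are in idx."""
    r = A[idx] @ x - b[idx]
    return 0.5 * r ** 2


def empirical_risk(x):
    """L(x): average loss over the whole dataset."""
    return per_sample_loss(x, np.arange(m)).mean()


def minibatch_risk(x, batch):
    """L_C(x): average loss over the minibatch C (an array of indices)."""
    return per_sample_loss(x, batch).mean()


x = np.zeros(n)
# One shuffled epoch split into 10 disjoint minibatches of 100 samples each.
batches = np.split(rng.permutation(m), 10)
epoch_avg = np.mean([minibatch_risk(x, c) for c in batches])
print(empirical_risk(x), epoch_avg)  # equal up to floating-point rounding
```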

By the central limit theorem, the minibatch risk is approximately Gaussian, with mean equal to the empirical risk: $L_{C_k}(x) \sim \mathcal{N}\big(L(x), R_k\big)$, where $R_k$ denotes its variance.
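The Gaussian approximation also suggests how large the observation noise is. The sketch below assumes minibatches drawn uniformly with replacement from a fixed set of hypothetical per-sample losses (deliberately skewed, exponentially distributed values, so the per-sample losses themselves are far from Gaussian). The minibatch risks average to the empirical risk, and their variance is close to the per-sample loss variance divided by the batch size. This is only an illustration of the approximation, not how KOALA sets $R_k$.

```python
import numpy as np

# Hypothetical per-sample losses l(xi_i, x) at a fixed x (skewed, non-Gaussian).
rng = np.random.default_rng(1)
losses = rng.exponential(scale=2.0, size=50_000)
L_full = losses.mean()   # empirical risk L(x)
sigma2 = losses.var()    # variance of a single per-sample loss

batch_size = 128
num_batches = 20_000
# Draw minibatches uniformly with replacement and record each minibatch risk.
idx = rng.integers(0, losses.size, size=(num_batches, batch_size))
batch_risks = losses[idx].mean(axis=1)

# CLT: batch_risks are approximately N(L_full, R) with R = sigma2 / batch_size.
R_hat = batch_risks.var()
print(L_full, batch_risks.mean())
print(sigma2 / batch_size, R_hat)
```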

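We define training as the task of finding $x$ given the noisy minibatch risks $L_{C_k}(x)$:

$$L^{\text{target}} = L_{C_k}(x_k) + v_k,$$

where $v_k \sim \mathcal{N}(0, R_k)$ and $L^{\text{target}}$ is a feasible loss value that we aim for.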

By modeling the evolution of the state (i.e., the network parameters) as a dynamical system, we can use the Extended Kalman Filter equations to estimate the parameters, which yields an optimization framework that we call KOALA.
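As an illustration of the idea, here is a generic Extended Kalman Filter step for the observation model above, assuming identity state dynamics with process noise $Q$ and a full covariance matrix $P$. The function name `ekf_loss_step` and the toy quadratic loss are hypothetical, and this is a textbook EKF sketch rather than the authors' implementation. A practical optimizer for neural networks cannot store a full $n \times n$ covariance; see the paper for how KOALA handles this.

```python
import numpy as np


def ekf_loss_step(x, P, loss_fn, grad_fn, L_target, Q, R):
    """One EKF step for the (assumed) model
        state:       x_k = x_{k-1} + w_{k-1},   w ~ N(0, Q)
        observation: L_target = L_C(x_k) + v_k, v ~ N(0, R)
    where the loss is linearized around the predicted state."""
    # Predict: identity dynamics, so only the covariance changes.
    x_pred = x
    P_pred = P + Q
    # Observation Jacobian: the gradient of the loss as a 1 x n row.
    H = grad_fn(x_pred)[None, :]
    S = (H @ P_pred @ H.T).item() + R        # innovation variance (scalar)
    K = (P_pred @ H.T) / S                   # Kalman gain, n x 1
    innovation = L_target - loss_fn(x_pred)  # target minus predicted loss
    x_new = x_pred + K[:, 0] * innovation
    P_new = (np.eye(x.size) - K @ H) @ P_pred
    return x_new, P_new


# Usage on a hypothetical quadratic loss with minimum value 0 at c.
c = np.array([1.0, -1.0])


def loss(x):
    return 0.5 * np.sum((x - c) ** 2)


def grad(x):
    return x - c


x, P = np.array([3.0, 2.0]), np.eye(2)
history = [loss(x)]
for _ in range(20):
    x, P = ekf_loss_step(x, P, loss, grad, L_target=0.0, Q=1e-2 * np.eye(2), R=1.0)
    history.append(loss(x))
print(history[0], history[-1])  # the loss decreases at every step here
```

Note that there is no fixed learning rate: the update equals $-P_{\text{pred}} H^\top (L_{C_k}(x) - L^{\text{target}}) / S$, so the step size is set by $P$, $R$ and the current loss value through the Kalman gain.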

With different state dynamics we derive two algorithms: KOALA-V (Vanilla) and KOALA-M (Momentum).
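Changing the state dynamics only changes the prediction step of the filter. As a sketch of how a momentum-like variant can arise, assume (for illustration only; this is not necessarily the exact KOALA-M model) an augmented state $s = [x; v]$ with $x_k = x_{k-1} + v_{k-1}$ and $v_k = \kappa v_{k-1}$, plus process noise. The loss depends only on $x$, so the observation Jacobian is $[\nabla L_{C_k}(x)^\top, 0]$, and the velocity is corrected only through its covariance with $x$. The names `momentum_ekf_step` and `kappa` are hypothetical.

```python
import numpy as np


def momentum_ekf_step(s, P, loss_fn, grad_fn, L_target, kappa, Q, R):
    """EKF step with an augmented state s = [x; v] (parameters and velocity).
    Assumed dynamics (illustrative): x_k = x_{k-1} + v_{k-1}, v_k = kappa * v_{k-1},
    plus process noise with covariance Q; observation L_target = L_C(x_k) + noise(R)."""
    n = s.size // 2
    I, Z = np.eye(n), np.zeros((n, n))
    F = np.block([[I, I], [Z, kappa * I]])  # linear state transition
    s_pred = F @ s
    P_pred = F @ P @ F.T + Q
    # The loss depends only on x, so the observation Jacobian is [grad^T, 0].
    H = np.concatenate([grad_fn(s_pred[:n]), np.zeros(n)])[None, :]
    S = (H @ P_pred @ H.T).item() + R
    K = (P_pred @ H.T) / S
    innovation = L_target - loss_fn(s_pred[:n])
    s_new = s_pred + K[:, 0] * innovation
    P_new = (np.eye(2 * n) - K @ H) @ P_pred
    return s_new, P_new


# Usage on a hypothetical quadratic loss with minimum value 0 at c.
c = np.array([1.0, -1.0])


def loss(x):
    return 0.5 * np.sum((x - c) ** 2)


def grad(x):
    return x - c


s0 = np.array([3.0, 2.0, 0.0, 0.0])  # x = [3, 2], v = [0, 0]
s1, P1 = momentum_ekf_step(s0, np.eye(4), loss, grad, L_target=0.0,
                           kappa=0.9, Q=1e-2 * np.eye(4), R=1.0)
print(loss(s0[:2]), loss(s1[:2]))  # the loss drops after one step
print(s1[2:])                      # the velocity now points along -grad
```

Starting from zero velocity, the prediction step correlates $v$ with $x$, so the first correction already pushes the velocity in the descent direction; later steps then carry that velocity forward, as momentum does.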

For more details, please refer to the paper.

Results

We compare our algorithm against SGD and Adam on CIFAR-10/100 and ImageNet32. The results are shown in the table below. For more quantitative results, please refer to the full paper.

[Table: comparison of KOALA with SGD and Adam on CIFAR-10/100 and ImageNet32]

Citation

The paper is to appear in the Proceedings of the 36th AAAI Conference on Artificial Intelligence. In the meantime, we suggest citing the arXiv preprint:

Davtyan, A., Sameni, S., Cerkezi, L., Meishvili, G., Bielski, A., & Favaro, P. (2021). KOALA: A Kalman Optimization Algorithm with Loss Adaptivity. arXiv preprint arXiv:2107.03331.

@article{davtyan2021koala,
  title    = {KOALA: A Kalman Optimization Algorithm with Loss Adaptivity},
  author   = {Davtyan, Aram and Sameni, Sepehr and Cerkezi, Llukman and Meishvili, Givi and Bielski, Adam and Favaro, Paolo},
  journal  = {arXiv preprint arXiv:2107.03331},
  year     = {2021}
}