FISTA-Net: Learning a Fast Iterative Shrinkage Thresholding Network for Inverse Problems in Imaging

Figure: The overall architecture of the proposed FISTA-Net.

Abstract

Inverse problems are essential to imaging applications. In this paper, we propose a model-based deep learning network, named FISTA-Net, that combines the interpretability and generality of the model-based Fast Iterative Shrinkage/Thresholding Algorithm (FISTA) with the strong regularization and tuning-free advantages of data-driven neural networks. By unfolding FISTA into a deep network, FISTA-Net is built as a cascade of gradient descent, proximal mapping, and momentum modules. Unlike in FISTA, the gradient matrix in FISTA-Net can be updated during iteration, and a proximal operator network that can be learned through end-to-end training is developed for nonlinear thresholding. Key parameters of FISTA-Net, including the gradient step size, thresholding value, and momentum scalar, are tuning-free: they are learned from training data rather than hand-crafted. We further impose positivity and monotonicity constraints on these parameters to ensure that they converge properly. Experimental results, evaluated both visually and quantitatively, show that FISTA-Net can optimize its parameters for different imaging tasks, namely Electromagnetic Tomography (EMT) and X-ray Computational Tomography (X-ray CT). It outperforms state-of-the-art model-based and deep learning methods and exhibits better generalization than other competitive learning-based approaches under different noise levels.
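To make the unfolding idea concrete, here is a minimal pure-Python sketch, not the authors' implementation. It writes each FISTA iteration as one stage made of a gradient descent step, a proximal mapping, and a momentum update, each stage carrying its own (step, threshold, momentum) triple. Several simplifications are assumptions of this sketch: the forward matrix `A` is fixed (FISTA-Net can update its gradient matrix), plain soft-thresholding stands in for the learned proximal operator network, and the helper `layer_params` is a hypothetical parametrization that keeps every parameter positive via softplus and monotone in the layer index through a linear slope. The abstract does not state which direction each parameter is constrained to move, so the example simply uses negative slopes (non-increasing schedules). Classic FISTA is recovered by feeding the hand-set threshold `step * lam` and momentum `(t_k - 1) / t_{k+1}` into the same stages.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(A: Matrix, x: Vector) -> Vector:
    # A @ x for a dense row-major matrix stored as nested lists.
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def rmatvec(A: Matrix, r: Vector) -> Vector:
    # A^T @ r, computed column by column.
    return [sum(A[i][j] * r[i] for i in range(len(A))) for j in range(len(A[0]))]


def soft_threshold(v: Vector, theta: float) -> Vector:
    # Proximal operator of theta * ||x||_1: shrink each entry toward zero by theta.
    return [math.copysign(max(abs(x) - theta, 0.0), x) for x in v]


def softplus(s: float) -> float:
    # Numerically stable log(1 + exp(s)); always positive.
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def layer_params(k: int, coeffs: Sequence[float]) -> Tuple[float, float, float]:
    # Hypothetical parametrization (assumption): each parameter is softplus(a * k + c),
    # positive for every layer k and non-increasing in k whenever its slope a <= 0.
    a_mu, c_mu, a_th, c_th, a_rho, c_rho = coeffs
    return (softplus(a_mu * k + c_mu),
            softplus(a_th * k + c_th),
            softplus(a_rho * k + c_rho))


def unrolled_fista(A: Matrix, b: Vector,
                   layers: Sequence[Tuple[float, float, float]]) -> Vector:
    """Forward pass through len(layers) unrolled stages for 0.5*||Ax - b||^2.

    Each stage (step, theta, rho) applies a gradient descent step, then
    soft-thresholding (a stand-in for a learned proximal network), then the
    momentum update y = x_k + rho * (x_k - x_{k-1}).
    """
    x = [0.0] * len(A[0])
    y = x[:]
    for step, theta, rho in layers:
        residual = [ai - bi for ai, bi in zip(matvec(A, y), b)]
        z = [yi - step * gi for yi, gi in zip(y, rmatvec(A, residual))]
        x_new = soft_threshold(z, theta)
        y = [xn + rho * (xn - xo) for xn, xo in zip(x_new, x)]
        x = x_new
    return x


def fista(A: Matrix, b: Vector, lam: float, step: float, iters: int) -> Vector:
    # Classic FISTA for min_x 0.5*||Ax - b||^2 + lam*||x||_1 (requires step <= 1/L,
    # with L the largest eigenvalue of A^T A): the same stages with hand-set
    # threshold step*lam and momentum (t_k - 1) / t_{k+1}.
    layers, t = [], 1.0
    for _ in range(iters):
        t_next = (1.0 + math.sqrt(1.0 + 4.0 * t * t)) / 2.0
        layers.append((step, step * lam, (t - 1.0) / t_next))
        t = t_next
    return unrolled_fista(A, b, layers)


if __name__ == "__main__":
    A = [[2.0, 0.0], [0.0, 1.0]]
    b = [4.0, 3.0]
    # With lam = 0 this is plain least squares; the iterates approach [2, 3].
    print(fista(A, b, lam=0.0, step=0.25, iters=500))
    # Ten "learned" stages with illustrative, made-up coefficients.
    learned = [layer_params(k, (-0.1, -1.5, -0.2, -3.0, -0.1, -2.0)) for k in range(10)]
    print(unrolled_fista(A, b, learned))
```

In a trained network the slope and offset coefficients would be optimized end to end together with the proximal network; here they are fixed numbers chosen only to show that the positivity and monotonicity constraints hold by construction rather than through clipping.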

Publication
IEEE Transactions on Medical Imaging, 2021