Deconvolution via ADMM in Python - Part 2: Code
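In the previous part of this post, I introduced deconvolution and built up the core of a total-variation (TV) deconvolution algorithm based on the alternating direction method of multipliers (ADMM). This part gives the complete Python implementation.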
def deconvolve_ADMM( K, B, num_iters=40, rho=1.0, gamma=0.1, verbose=True ):
    """Deconvolve image ``B`` by kernel ``K`` with total-variation regularization via ADMM.

    The updates below minimize

        1/2 * || K * I - B ||^2  +  gamma * ( ||Dx I||_1 + ||Dy I||_1 )

    over the latent image ``I``, where ``*`` is circular (FFT-based)
    convolution and Dx, Dy are the finite-difference filters defined below.
    The gradients are split off as Zx = Dx I and Zy = Dy I, with scaled
    Lagrange multipliers Ux and Uy.

    Parameters
    ----------
    K : array_like
        Blur kernel with the same shape as ``B``; it is normalized to unit
        sum internally. Assumed to be centered at index (rows//2, cols//2),
        i.e. the ``numpy.fft.fftshift`` convention -- confirm against the
        code that builds the kernel.
    B : array_like
        Observed (blurred) image.
    num_iters : int
        Number of ADMM iterations to run.
    rho : float
        ADMM penalty parameter.
    gamma : float
        Weight of the total-variation term.
    verbose : bool
        If True (the default), print a progress line every iteration.

    Returns
    -------
    numpy.ndarray
        Real-valued reconstruction with the same shape as ``B``.

    Raises
    ------
    ValueError
        If ``K`` and ``B`` differ in shape, or if ``K`` sums to zero and so
        cannot be normalized.
    """
    K = numpy.asarray( K, dtype=float )
    B = numpy.asarray( B )
    # store the image size
    dim = B.shape
    # the kernel is not padded here; the caller must supply it at image size
    if K.shape != dim:
        raise ValueError('B and K must have same shape')
    K_sum = numpy.sum( K )
    if K_sum == 0:
        # normalizing by a zero sum would silently fill the result with inf/nan
        raise ValueError('K must have a nonzero sum')

    # finite-difference operators. Kernel2D is defined outside this file;
    # its arguments are presumably (row offsets, column offsets, weights),
    # which would make Dx a horizontal and Dy a vertical difference -- TODO confirm
    Dx = Kernel2D( [ 0, 0], [-1, 0], [-1.0, 1.0] )
    Dy = Kernel2D( [-1, 0], [ 0, 0], [-1.0, 1.0] )

    # initial solution estimate
    I = numpy.zeros( dim )
    # splitting variables and (scaled) Lagrange multipliers
    Zx = Dx.mul( I )
    Zy = Dy.mul( I )
    Ux = numpy.zeros( Zx.shape )
    Uy = numpy.zeros( Zy.shape )

    # Cache the terms of the I-update. The kernel's center tap must sit at
    # index (0, 0) for the FFT. ifftshift moves index n//2 to 0 for both odd
    # and even n. A roll by +n//2 (fftshift) only does so for even n and
    # shifts the result by one pixel otherwise.
    fK = numpy.fft.fft2( numpy.fft.ifftshift( K/K_sum, axes=(0, 1) ) )
    fB = numpy.fft.fft2( B )
    fDx = Dx.spectrum( dim )
    fDy = Dy.spectrum( dim )
    # Normal equations (K^T K + rho D^T D) I = K^T B + rho D^T (Z - U) are
    # diagonal in the Fourier domain. The denominator and K^T B do not change
    # across iterations.
    num_init = numpy.conj( fK )*fB
    den = numpy.conj( fK )*fK + rho*( numpy.conj(fDx)*fDx + numpy.conj(fDy)*fDy )

    def soft_threshold( q ):
        # proximal operator of (gamma/rho) * ||.||_1, applied elementwise
        return numpy.sign(q)*numpy.maximum( numpy.abs( q ) - gamma/rho, 0.0 )

    # main solver loop
    for it in range( num_iters ):
        if verbose:
            print('ADMM iteration [%d/%d]'%(it,num_iters))
        # I-update. Assumes Dx.mul(..., trans=True) is the adjoint of Dx.mul
        # and matches Dx.spectrum (circular boundaries) -- TODO confirm in Kernel2D.
        V = rho*( Dx.mul( Zx - Ux, trans=True) + Dy.mul( Zy - Uy, trans=True ) )
        # numpy.real drops the round-off imaginary part of the inverse FFT
        I = numpy.real( numpy.fft.ifft2( (num_init + numpy.fft.fft2(V))/den ) )
        # Z-updates; the gradients are cached because the U-update reuses them
        tmp_x = Dx.mul( I )
        tmp_y = Dy.mul( I )
        Zx = soft_threshold( tmp_x + Ux )
        Zy = soft_threshold( tmp_y + Uy )
        # scaled dual (multiplier) update
        Ux += tmp_x - Zx
        Uy += tmp_y - Zy
    # return reconstructed result
    return I
That is it.