## Dictionary Learning

Dictionary learning jointly solves for a dictionary \(D\), whose columns store features (atoms), and weights \(X\), which store the coefficients applied to \(D\) to best approximate a collection of input datapoints stored as the columns of \(Y\). The weights \(X\) are required to be sparse, i.e. to have relatively few non-zero entries:

\[ \min_{D,X} \; \frac{1}{2}\| Y - D X \|_F^2 + \lambda \| X \|_1 \]

An additional constraint must be placed on \(D\) because, without one, the 1-norm in the objective can be made arbitrarily small by letting \(D\) grow arbitrarily large while \(X\) shrinks to compensate. Typically this constraint is that the columns of \(D\) have unit 2-norm.
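To see why the constraint is needed, a small numerical check helps. The sketch below uses arbitrary random matrices (the sizes and the scale factor `c` are illustrative assumptions, not taken from the test problem) and shows that scaling \(D\) up by `c` while scaling \(X\) down by `c` leaves the product \(DX\) unchanged but shrinks the 1-norm of \(X\) by the same factor.

```python
import numpy

rng = numpy.random.default_rng(0)
D = rng.standard_normal((8, 16))   # illustrative dictionary
X = rng.standard_normal((16, 32))  # illustrative weights

c = 1000.0                         # arbitrary scale factor
D_big, X_small = c * D, X / c

# the reconstruction is unchanged ...
same_fit = numpy.allclose(D @ X, D_big @ X_small)
# ... but the 1-norm penalty shrinks by a factor of c
ratio = numpy.abs(X).sum() / numpy.abs(X_small).sum()

# normalizing the columns of D to unit 2-norm removes this freedom
D_unit = D / numpy.linalg.norm(D, axis=0, keepdims=True)
```

With the columns pinned to unit length, any rescaling of an atom has to be absorbed by the weights, so the 1-norm penalty becomes meaningful.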

This is a non-convex problem, but it can be solved relatively well by alternating weight and dictionary updates, holding the other fixed during each. Ignoring the norm constraint, the dictionary update is a linear least-squares problem and can be solved as:

\[ D = Y X^T \left( X X^T \right)^{-1} \]

After this, the columns can be normalized.

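The weight update is a sparse-coding problem and can be solved using ADMM. The first step is to introduce a splitting variable \(Z\) that allows the Frobenius norm and the 1-norm to be handled separately:

\[ \min_{X,Z} \; \frac{1}{2}\| Y - D X \|_F^2 + \lambda \| Z \|_1 \quad \text{subject to} \quad X - Z = 0 \]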

Once in this form, the update algorithm for the weights can be written by inspection, based on the link above, using \(F(X) = \frac{1}{2}\| Y - D X \|_F^2\) and \(G(Z) = \lambda \| Z \|_1\).
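For reference, the scaled-form ADMM iteration for a problem of the form \(\min F(X) + G(Z)\) subject to \(X - Z = 0\) is commonly written as below. The scaled Lagrange multiplier \(U\), the penalty parameter \(\rho\) and the iteration index \(k\) are standard ADMM notation assumed here; the linked page may use a different but equivalent convention.

\[
\begin{aligned}
X^{k+1} &= \operatorname{prox}_{F}\left( Z^{k} - U^{k} \right) \\
Z^{k+1} &= \operatorname{prox}_{G}\left( X^{k+1} + U^{k} \right) \\
U^{k+1} &= U^{k} + X^{k+1} - Z^{k+1}
\end{aligned}
\qquad \text{where} \qquad
\operatorname{prox}_{F}(V) = \operatorname*{arg\,min}_{X} \; F(X) + \frac{\rho}{2}\| X - V \|_F^2
\]

The same definition with \(G\) in place of \(F\) gives \(\operatorname{prox}_{G}\).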

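All that remains is to determine the proximal operators and initialization. The proximal operator for \(X\), evaluated at a point \(V\) with ADMM penalty parameter \(\rho\), is a conventional regularized least-squares problem:

\[ \operatorname*{arg\,min}_{X} \; \frac{1}{2}\| Y - D X \|_F^2 + \frac{\rho}{2}\| X - V \|_F^2 = \left( D^T D + \rho I \right)^{-1} \left( D^T Y + \rho V \right) \]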

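The corresponding proximal operator for \(Z\) is for a 'pure' 1-norm, which is just the soft-thresholding operator, applied elementwise to its argument \(V\) with some threshold \(\tau\):

\[ S_{\tau}(V) = \operatorname{sign}(V) \, \max\left( |V| - \tau, \, 0 \right) \]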
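As a minimal sketch of soft-thresholding in NumPy (the function name `soft_threshold` and the argument name `tau` are illustrative choices, not part of any library):

```python
import numpy

def soft_threshold(T, tau):
    """Shrink every entry of T toward zero by tau; entries with |t| <= tau become 0."""
    return numpy.sign(T) * numpy.maximum(numpy.abs(T) - tau, 0.0)

# small example: entries within the threshold are zeroed, the rest shrink by tau
T = numpy.array([-2.0, -0.5, 0.0, 0.3, 1.5])
Z = soft_threshold(T, 1.0)
```

Entries with magnitude below the threshold are set exactly to zero, which is what produces sparsity in the splitting variable.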

If you're in doubt about where either of these comes from, consult the ADMM link above.

This leads to an interesting choice. The split problem has two sets of variables, \(X\) and \(Z\): \(X\) minimizes the data term while \(Z\) enforces sparsity. Nominally the two should be equal, but this only holds in the limit of many iterations, assuming the overall algorithm works at all (the problem is non-convex, after all). Which should be used when updating \(D\)?

For a test problem involving around 5.5M points of 48-dof data, I found that using \(Z\) is vastly preferable. A dictionary with only 100 atoms (~2X overcomplete) yields reconstruction errors of around 1% or less, with only around 25% of the coding weights non-zero, and it settled relatively quickly. Using \(X\), on the other hand, oscillated quite a bit and did not reach these levels of accuracy or sparsity.
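Figures like these can be computed along the following lines. This is only a sketch: the helper name `coding_stats`, the use of a relative Frobenius-norm error, and the non-zero tolerance `tol` are assumptions about how such numbers might be measured, not a record of how the values above were obtained.

```python
import numpy

def coding_stats(Y, D, Z, tol=1e-4):
    """Return (relative reconstruction error, fraction of non-zero weights)."""
    rel_error = numpy.linalg.norm(Y - D @ Z) / numpy.linalg.norm(Y)
    nonzero_fraction = numpy.mean(numpy.abs(Z) > tol)
    return rel_error, nonzero_fraction

# tiny hand-made example: two atoms, three data points, exact reconstruction
D = numpy.eye(2)
Z = numpy.array([[1.0, 0.0, 2.0],
                 [0.0, 0.0, 3.0]])
Y = D @ Z
rel_error, nonzero_fraction = coding_stats(Y, D, Z)
```

In this toy case the reconstruction is exact and half of the weights are non-zero.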

I conclude that it's better to accept a looser fit than looser constraint satisfaction, for this problem at least. Presumably the dictionary/codebooks adapt to each other better when they are consistent. More 'compressible' problems may yield sparser solutions.

```
import numpy


def sparse_coding( Y, D, X, rho, num_iterations, Z=None, U=None ):
    """ADMM sparse coding of Y against a fixed dictionary D.

    Note: the ADMM penalty is fixed at 1 here (the identity added to D^T D),
    so `rho` acts as the soft-threshold level, i.e. the weight on the 1-norm.
    """
    if Z is None:
        Z = X.copy()
    if U is None:
        U = X - Z  # scaled dual variable; zero when Z starts as a copy of X
    # precompute the inverse of the system matrix and the constant part of the RHS
    iDtD = numpy.linalg.inv( numpy.dot(D.transpose(),D) + numpy.eye(D.shape[1]) )
    DtY = numpy.dot( D.transpose(), Y )
    for it in range(num_iterations):
        print(' Sparse coding iteration [{}/{}]'.format(it+1,num_iterations) )
        # primary update: regularized least-squares solve for X
        X = numpy.dot( iDtD, DtY + Z - U )
        # splitting variable update: soft-threshold X + U
        T = X + U
        Z = numpy.maximum( numpy.abs(T) - rho, 0.0)*numpy.sign(T)
        # lagrange multiplier update
        U = T - Z
    return X, Z, U


def dictionary_learning( Y, num_atoms, rho=0.001, num_outer_iterations=10, num_inner_iterations=10, epsilon=1e-8 ):
    """Alternate dictionary and sparse-coding updates; returns (D, Z)."""
    # initialize the dictionary and weights
    X = numpy.random.standard_normal( (num_atoms, Y.shape[1]) )
    D = numpy.random.standard_normal( (Y.shape[0], num_atoms) )
    # splitting and dual variables persist across outer iterations (warm start)
    Z = X.copy()
    U = X - Z
    # outer loop to fit dictionary and weights
    for outer in range(num_outer_iterations):
        print( ' Outer iteration [{}/{}]'.format(outer+1,num_outer_iterations) )
        # dictionary update: least squares on the sparse weights Z,
        # with epsilon*I added to keep the system well conditioned
        D = numpy.linalg.solve(
            numpy.dot(Z,Z.transpose())+epsilon*numpy.eye(X.shape[0]),
            numpy.dot(Z,Y.transpose())
        ).transpose()
        # normalize each atom to unit 2-norm (guarding against unused, all-zero atoms)
        for i in range(D.shape[1]):
            D[:,i] /= max( numpy.linalg.norm(D[:,i]), epsilon )
        # sparse coding weight update
        X, Z, U = sparse_coding( Y, D, X, rho, num_inner_iterations, Z, U )
        # print some stats (both are computed from the sparse variable Z)
        print( ' ||Y-DZ|| RMS error: {}'.format( numpy.sqrt(numpy.mean(numpy.square(Y - numpy.dot(D,Z)))) ) )
        print( ' mean(nnz(Z)): {}'.format( numpy.mean( numpy.sum(numpy.abs(Z)>1e-4, axis=0) ) ) )
    # return dictionary and solution variables
    return D, Z
```