Extract Image Patches in Python
There are a number of imaging algorithms that work with patch representations. Non-local means, dictionary learning and Gaussian Mixture Models are some examples.
In order to use these, you need to be able to extract image patches and reconstruct images efficiently. There are probably faster methods but the following does the trick for me without much in the way of dependencies:
import numpy
def extract_grayscale_patches( img, shape, offset=(0,0), stride=(1,1) ):
    """Extracts (typically) overlapping regular patches from a grayscale image

    Changing the offset and stride parameters will result in images
    reconstructed by reconstruct_from_grayscale_patches having different
    dimensions! Callers should pad and unpad as necessary!

    Args:
        img (HxW arraylike): input image from which to extract patches
        shape (2-element arraylike): shape of the patches as (h,w)
        offset (2-element arraylike): non-negative offset of the initial
            point as (y,x)
        stride (2-element arraylike): vertical and horizontal strides

    Returns:
        patches (ndarray): output image patches as (N,shape[0],shape[1])
            array; N is 0 when no patch fits inside the image
        origin (2-tuple): array of top and array of left coordinates

    Raises:
        ValueError: if img is not two-dimensional or an offset is negative
    """
    img = numpy.asarray( img )
    if img.ndim != 2:
        raise ValueError(
            'img must be a 2-D (HxW) array, got {} dimension(s)'.format(img.ndim) )
    # Negative start coordinates would not fail: numpy would silently wrap
    # them around to the opposite image border and return bogus patches.
    if offset[0] < 0 or offset[1] < 0:
        raise ValueError( 'offset must be non-negative, got {}'.format(tuple(offset)) )

    # Top-left corner of every patch that fits entirely inside the image,
    # enumerated row by row.
    l, t = numpy.meshgrid(
        numpy.arange(offset[1],img.shape[1]-shape[1]+1,stride[1]),
        numpy.arange(offset[0],img.shape[0]-shape[0]+1,stride[0]) )
    l = l.ravel()
    t = t.ravel()

    # Per-patch pixel coordinates: (N,h,1) rows and (N,1,w) columns are
    # broadcast against each other by the indexing operation itself, so no
    # tiled copies of the index grids are needed.
    y = t[:,None,None] + numpy.arange(shape[0])[None,:,None]
    x = l[:,None,None] + numpy.arange(shape[1])[None,None,:]
    return img[y,x], (t,l)
def reconstruct_from_grayscale_patches( patches, origin, epsilon=1e-12 ):
    """Rebuild an image from a set of patches by averaging

    The reconstructed image will have different dimensions than the
    original image if the strides and offsets of the patches were changed
    from the defaults!

    Args:
        patches (ndarray): input patches as (N,patch_height,patch_width) array
        origin (2-tuple): top and left coordinates of each patch; several
            patches may share the same origin
        epsilon (scalar): lower bound applied to the per-pixel patch count
            before dividing, so that pixels not covered by any patch come
            out as 0 instead of triggering a division by zero

    Returns:
        image (ndarray): output image reconstructed from patches of
            size ( max(origin[0])+patches.shape[1], max(origin[1])+patches.shape[2])
        weight (ndarray): output weight matrix consisting of the count
            of patches covering each pixel

    Raises:
        ValueError: if patches is not a non-empty 3-D array, if origin does
            not hold one (top,left) pair per patch, or if a coordinate is
            negative
    """
    patches = numpy.asarray( patches )
    if patches.ndim != 3 or patches.shape[0] == 0:
        raise ValueError(
            'patches must be a non-empty (N,patch_height,patch_width) array' )
    top = numpy.asarray( origin[0] ).ravel()
    left = numpy.asarray( origin[1] ).ravel()
    if top.size != patches.shape[0] or left.size != patches.shape[0]:
        raise ValueError( 'origin must hold one top and one left coordinate per patch' )
    # Negative coordinates would silently wrap around to the far border.
    if top.min() < 0 or left.min() < 0:
        raise ValueError( 'origin coordinates must be non-negative' )

    patch_height = patches.shape[1]
    patch_width = patches.shape[2]
    img_height = int( top.max() ) + patch_height
    img_width = int( left.max() ) + patch_width
    out = numpy.zeros( (img_height,img_width) )
    wgt = numpy.zeros( (img_height,img_width) )
    for i in range(patch_height):
        for j in range(patch_width):
            # ufunc.at accumulates unbuffered: a plain `out[rows,cols] += ...`
            # keeps only one contribution per pixel when two patches share
            # an origin, which would corrupt both the sum and the count.
            rows = top + i
            cols = left + j
            numpy.add.at( out, (rows,cols), patches[:,i,j] )
            numpy.add.at( wgt, (rows,cols), 1.0 )
    return out/numpy.maximum( wgt, epsilon ), wgt
if __name__ == '__main__':
    # Demo: extract every 8x8 patch from one channel of an image, rebuild
    # the channel from the patches and show input, one patch and the result.
    import cv2
    import time
    import matplotlib.pyplot as plt

    # cv2.imread returns None (it does not raise) when the file is missing
    # or cannot be decoded, so check before slicing the result.
    loaded = cv2.imread('lena.png')
    if loaded is None:
        raise SystemExit( "Could not read image file 'lena.png'" )
    img = loaded[:,:,::-1]  # reverse the channel axis (OpenCV's BGR -> RGB)

    # perf_counter is a monotonic high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock changes.
    start = time.perf_counter()
    p, origin = extract_grayscale_patches( img[:,:,2], (8,8), stride=(1,1) )
    end = time.perf_counter()
    print( 'Patch extraction took: {}s'.format(numpy.round(end-start,2)) )

    start = time.perf_counter()
    r, w = reconstruct_from_grayscale_patches( p, origin )
    end = time.perf_counter()
    print('Image reconstruction took: {}s'.format(numpy.round(end-start,2)) )

    # Compare against the input cropped to the reconstructed size.
    print( 'Reconstruction error is: {}'.format( numpy.linalg.norm( img[:r.shape[0],:r.shape[1],2]-r ) ) )

    plt.subplot( 131 )
    plt.imshow( img[:,:,2] )
    plt.title('Input image')
    plt.subplot( 132 )
    plt.imshow( p[p.shape[0]//2] )
    plt.title('Central patch')
    plt.subplot( 133 )
    plt.imshow( r )
    plt.title('Reconstructed image')
    plt.show()
Remove the final `if __name__ == '__main__':` demo block (which needs OpenCV and matplotlib) to leave numpy as the only dependency.