This post explains how to implement a function that generates a crop of a particular object from an image, given the position of the object's centroid. The images below show an example use case: cropping each individual orange from an image.
Download an image
```
!curl -o oranges.jpg https://lh3.googleusercontent.com/pw/ACtC-3fLFL_y58xx1GVy6jLQ0quLpoctt-WG5yo5dR1N3RurI4Qodnnj_JeCEQG-kzILCAUNgZmcA5QlkuLYnbW33Y1XTj48knehvFywJoz1ni3U6MtGiiJzvz4edv0kU0y7RzYRvuWXbewA5glVbkx_Ja-PXg=w1312-h1393-no
```
Define the function
Torchvision has several image transforms, such as `CenterCrop` and `FiveCrop`, but if you want to generate a custom crop from an image, there is a functional transform called `crop` in `torchvision.transforms.functional`. It takes as input the image, the position of the top-left corner of the crop region, and the height and width of the cropped image: `crop(img, top, left, height, width)`. If the crop region extends beyond the image boundaries, the part that falls outside the image is padded with zeros.
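To make these semantics concrete, here is a minimal NumPy sketch of what `crop(img, top, left, height, width)` computes, assuming the zero-padding behaviour described above (recent torchvision releases document it; older ones may simply slice). The helper `crop_with_zero_padding` is hypothetical and is not torchvision's implementation; it also works on channels-last `(H, W)` or `(H, W, C)` arrays, whereas torchvision expects the spatial dimensions last, `(..., H, W)`.

```python
import numpy as np


def crop_with_zero_padding(img: np.ndarray, top: int, left: int,
                           height: int, width: int) -> np.ndarray:
    """Hypothetical NumPy stand-in for torchvision's `crop`.

    Works on (H, W) or (H, W, C) arrays and always returns an array of
    shape (height, width, ...). Any part of the requested region that
    falls outside the image is filled with zeros.
    """
    h, w = img.shape[:2]
    out = np.zeros((height, width) + img.shape[2:], dtype=img.dtype)

    # Intersection of the requested region with the image (image coordinates)
    y0, y1 = max(top, 0), min(top + height, h)
    x0, x1 = max(left, 0), min(left + width, w)
    if y0 >= y1 or x0 >= x1:
        return out  # the region lies entirely outside the image

    # Copy the overlapping pixels into the matching position of the output
    out[y0 - top:y1 - top, x0 - left:x1 - left] = img[y0:y1, x0:x1]
    return out


img = np.arange(1, 17).reshape(4, 4)
print(crop_with_zero_padding(img, 1, 1, 2, 2).tolist())    # [[6, 7], [10, 11]]
print(crop_with_zero_padding(img, -1, -1, 3, 3).tolist())  # [[0, 0, 0], [0, 1, 2], [0, 5, 6]]
```

The second call asks for a 3 x 3 region starting one pixel above and to the left of the image, so the first row and column of the result are zeros.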
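In this example, we know the position of the centroid of the region to crop rather than its top-left corner. Below, we define a wrapper around the torchvision `crop` function that converts the centroid into a top-left corner and then generates a crop of the requested size.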
```python
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms.functional import crop


def generate_crop(
    image: torch.Tensor,
    centroid_x: int,
    centroid_y: int,
    crop_width: int = 256,
    crop_height: int = 256,
) -> torch.Tensor:
    """
    Generate a crop of an image around a central point with a specified width and height.

    Args:
        image: Tensor of shape (H, W, 3), e.g. a uint8 RGB image
        centroid_x: x position (column) of the centre of the crop area
        centroid_y: y position (row) of the centre of the crop area
        crop_width: width of the cropped image
        crop_height: height of the cropped image

    Returns:
        cropped_image: Tensor of shape (crop_height, crop_width, 3)
    """
    # Find the top and left positions of the crop
    # region in the input image
    left = math.floor(centroid_x - crop_width / 2)
    top = math.floor(centroid_y - crop_height / 2)

    # crop expects channels-first input, so reshape the image to (3, H, W)
    image = image.permute(2, 0, 1)

    # Generate the crop
    cropped_image = crop(image, top, left, crop_height, crop_width)

    # Reshape back to (H, W, 3)
    return cropped_image.permute(1, 2, 0)
```
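As a quick check of the corner arithmetic inside `generate_crop`, the snippet below reproduces `left = floor(centroid_x - crop_width / 2)` and `top = floor(centroid_y - crop_height / 2)` for two of the boxes from the example use case. The helper `top_left_from_centroid` is hypothetical and exists only for this illustration; it returns `(top, left)` in the same order that `crop` takes them.

```python
import math


def top_left_from_centroid(centroid_x, centroid_y, crop_width, crop_height):
    """Hypothetical helper reproducing the corner arithmetic in generate_crop."""
    left = math.floor(centroid_x - crop_width / 2)
    top = math.floor(centroid_y - crop_height / 2)
    return top, left


# Box format: [center_x, center_y, width, height]
print(top_left_from_centroid(285, 310, 268, 248))   # (186, 151)
print(top_left_from_centroid(830, 1258, 245, 250))  # (1133, 707)

# Middle pixel column of the odd-width (245 px) crop
top, left = top_left_from_centroid(830, 1258, 245, 250)
cols = range(left, left + 245)
print(cols[len(cols) // 2])  # 829
```

One design detail this exposes: for odd crop sizes, flooring `830 - 122.5` gives 707, so the middle column of the crop is 829, one pixel left of the centroid. If exact centring matters, `centroid_x - crop_width // 2` (for integer inputs) would place the centroid on the middle column; for even sizes the two formulas give the same result.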
Example use case
Now we can apply this function to generate a crop of each orange in the input image.
```python
img = plt.imread('oranges.jpg')
image_tensor = torch.from_numpy(img)

# The boxes are defined as [center_x, center_y, width, height]
boxes = [
    [285, 310, 268, 248],
    [1130, 330, 258, 258],
    [135, 550, 250, 250],
    [670, 490, 230, 240],
    [840, 265, 240, 200],
    [830, 1258, 245, 250],
    [1120, 670, 150, 150],
    [245, 750, 120, 170],
]

# Generate a crop for each orange
crops = [generate_crop(image_tensor, *box) for box in boxes]

# Visualize each cropped image in a 2 x 4 grid
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(16, 16))
axs = np.array(axs).reshape(-1)
for ax, crop_im in zip(axs, crops):
    ax.imshow(crop_im)
    # Hide the axes of this subplot (plt.axis would only affect the current axes)
    ax.axis('off')
plt.tight_layout()
plt.show()
```