This post explains how to implement a function to generate a crop of a particular object from an image given the position of the centroid of the object. The images below show an example use case: cropping all the individual oranges from an image.

Download an image

This example uses the image of oranges from the NMS and IoU posts.

!curl -o oranges.jpg https://lh3.googleusercontent.com/pw/ACtC-3fLFL_y58xx1GVy6jLQ0quLpoctt-WG5yo5dR1N3RurI4Qodnnj_JeCEQG-kzILCAUNgZmcA5QlkuLYnbW33Y1XTj48knehvFywJoz1ni3U6MtGiiJzvz4edv0kU0y7RzYRvuWXbewA5glVbkx_Ja-PXg=w1312-h1393-no

Define the function

Torchvision has several image transforms, such as CenterCrop, RandomResizedCrop and FiveCrop. To generate a custom crop from an image, however, torchvision.transforms.functional provides a function called crop. It takes the position of the top-left corner of the crop region, along with the height and width of the cropped image. If the crop region extends beyond the image borders, the missing area is padded with zeros.
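To make the zero-padding behaviour concrete, here is a minimal NumPy sketch of the same top/left/height/width semantics. The helper crop_with_zero_pad is hypothetical and for illustration only. It is not torchvision's implementation, just a reimplementation of the behaviour described above for arrays shaped (..., H, W).

```python
import numpy as np


def crop_with_zero_pad(image: np.ndarray, top: int, left: int,
                       height: int, width: int) -> np.ndarray:
    """Hypothetical helper: crop image[..., top:top+height, left:left+width],
    filling any part of the requested region outside the image with zeros."""
    h, w = image.shape[-2:]
    out = np.zeros(image.shape[:-2] + (height, width), dtype=image.dtype)

    # Intersection of the requested region with the image (image coordinates)
    src_top, src_left = max(top, 0), max(left, 0)
    src_bottom, src_right = min(top + height, h), min(left + width, w)
    if src_bottom <= src_top or src_right <= src_left:
        return out  # the region lies entirely outside the image

    # Where that intersection lands inside the output array
    dst_top, dst_left = src_top - top, src_left - left
    out[..., dst_top:dst_top + (src_bottom - src_top),
        dst_left:dst_left + (src_right - src_left)] = \
        image[..., src_top:src_bottom, src_left:src_right]
    return out


img = np.arange(1, 17).reshape(4, 4)
patch = crop_with_zero_pad(img, top=-1, left=2, height=3, width=3)
print(patch.tolist())  # [[0, 0, 0], [3, 4, 0], [7, 8, 0]]
```

The requested 3x3 window starts one row above the image and runs one column past its right edge. As a result, the first row and last column of the result are zeros.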

In this example, the position of the centroid of the cropped region is known. Below, we define a wrapper that uses the torchvision crop function.

import math
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import crop

def generate_crop(
    image: torch.Tensor,
    centroid_x: int,
    centroid_y: int,
    crop_width: int = 256,
    crop_height: int = 256,
) -> torch.Tensor:
    """
    Generate a crop of an image around a central point
    with a specified width and height.

    Args:
      image: Tensor of shape (H, W, 3), e.g. the array
        returned by plt.imread converted with torch.from_numpy
      centroid_x: int, x position (column) of the
        centre of the crop area
      centroid_y: int, y position (row) of the
        centre of the crop area
      crop_width: int, width of the cropped image
      crop_height: int, height of the cropped image
    Returns:
      cropped_image: Tensor of shape
        (crop_height, crop_width, 3), zero-filled where the
        crop region extends beyond the image
    """
    # Find the top and left positions of the crop
    # region in the input image
    left = math.floor(centroid_x - crop_width/2)
    top = math.floor(centroid_y - crop_height/2)

    # Reshape image to (3, H, W)
    image = image.permute(2, 0, 1)

    # Generate the crop
    cropped_image = crop(image, top, left, crop_height, crop_width)

    # Reshape back to (H, W, 3)
    return cropped_image.permute(1, 2, 0)
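The corner arithmetic at the top of generate_crop can be checked by hand. The hypothetical helper below repeats only that step, using the same math.floor rule, so it runs without PyTorch. For the first orange in the example below, a 268 x 248 crop centred at (285, 310) starts at left = 285 - 134 = 151 and top = 310 - 124 = 186.

```python
import math
from typing import Tuple


def centroid_to_top_left(centroid_x: int, centroid_y: int,
                         crop_width: int, crop_height: int) -> Tuple[int, int]:
    """Hypothetical helper: return (top, left) of a crop_height x crop_width
    window centred on (centroid_x, centroid_y), floored as in generate_crop."""
    left = math.floor(centroid_x - crop_width / 2)
    top = math.floor(centroid_y - crop_height / 2)
    return top, left


# First orange: centroid (285, 310), crop 268 wide x 248 high
print(centroid_to_top_left(285, 310, 268, 248))  # (186, 151)

# Odd size near the border: 100 - 255/2 = -27.5, floored to -28
print(centroid_to_top_left(100, 100, 255, 255))  # (-28, -28)
```

A negative result means the window starts outside the image. The crop function zero-pads that region instead of raising an error.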

Example use case

Now we can apply this function to generate crops of all the oranges in the input image.

img = plt.imread('oranges.jpg')
image_tensor = torch.from_numpy(img)

# The boxes are defined as [center_x, center_y, width, height]
boxes = [
  [285, 310, 268, 248],
  [1130, 330, 258, 258],
  [135, 550, 250, 250],
  [670, 490, 230, 240],
  [840, 265, 240, 200],
  [830, 1258, 245, 250],
  [1120, 670, 150, 150],
  [245, 750, 120, 170]
]

# Generate crop for each orange
crops = [generate_crop(image_tensor, *box) for box in boxes]

# Visualize each cropped image
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(16, 16))
axs = np.array(axs).reshape(-1)

for ax, crop_im in zip(axs, crops):
    ax.imshow(crop_im)
    # Hide ticks on this subplot (plt.axis would only affect the last one)
    ax.axis('off')

plt.tight_layout()
plt.show()
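The boxes above are written directly as [center_x, center_y, width, height]. Your boxes might instead come in corner format [x1, y1, x2, y2], which many detection pipelines use (an assumption here). In that case, a small conversion like the hypothetical corners_to_center below can be applied before calling generate_crop. Because generate_crop floors the computed corner with math.floor, float centres such as 285.0 are fine.

```python
from typing import List, Sequence


def corners_to_center(box: Sequence[float]) -> List[float]:
    """Hypothetical helper: convert an [x1, y1, x2, y2] corner box into the
    [center_x, center_y, width, height] format used by the boxes list."""
    x1, y1, x2, y2 = box
    width, height = x2 - x1, y2 - y1
    return [x1 + width / 2, y1 + height / 2, width, height]


# Corner box that corresponds to the first orange in the boxes list
print(corners_to_center([151, 186, 419, 434]))  # [285.0, 310.0, 268, 248]
```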