non maximal suppression

Tags: CV

19) Implement non maximal suppression as efficiently as you can. [src]

What is non-max suppression?

Objects in an image come in different sizes and shapes, and to capture each of them, object detection algorithms typically predict multiple overlapping bounding boxes per object (left image). Ideally, each object should end up with a single bounding box, as in the image on the right.

To select the best bounding box from the multiple predicted boxes, object detection algorithms use non-max suppression (NMS). This technique “suppresses” the less likely bounding boxes and keeps only the best one.

So we now understand why we need NMS and what it is used for. Let us now look at how exactly it is implemented.

Non-max suppression first selects the bounding box with the highest objectness score, and then removes all other boxes that overlap it heavily. So, in the above image:

  1. We select the green bounding box for the dog (since it has the highest objectness score, 98%)
  2. And remove the yellow and red boxes for the dog (because they have a high overlap with the green box)
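
The overlap between two boxes is usually measured as intersection over union (IoU): the area shared by the two boxes divided by the total area they cover together. Below is a minimal pure-Python sketch, assuming boxes are given as (x1, y1, x2, y2) corner coordinates. The function name `iou` is a hypothetical helper for illustration, not part of any library.

```python
def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    # corners of the overlapping rectangle (if there is one)
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])

    # width and height are clamped to zero when the boxes do not overlap
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# two 10x10 boxes shifted by half a width share 50 of 150 units of area
print(iou((0, 0, 10, 10), (5, 0, 15, 10)))  # 0.333...
```

An IoU of 0 means the boxes are disjoint and an IoU of 1 means they coincide, so a threshold such as 0.5 separates "probably the same object" from "probably a different one".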

The process of selecting the best bounding boxes with NMS is as follows:

Step 1: Select the box with the highest objectness score

Step 2: Compare the overlap (intersection over union, IoU) of this box with every other box

Step 3: Remove the bounding boxes whose overlap (IoU) with it is >50%

Step 4: Move to the box with the next highest objectness score among those that remain

Step 5: Repeat steps 2-4 until no boxes are left to process
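
The steps above can be sketched in plain Python. This is a minimal illustration, not torchvision's implementation: the names `iou` and `greedy_nms` are hypothetical, boxes are assumed to be in (x1, y1, x2, y2) format, and the threshold is a parameter (step 3 uses 0.5). The example data are the six boxes and scores used with torchvision later in this note, with the same threshold of 0.2.

```python
def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold=0.5):
    """Return indices of the boxes kept by greedy NMS, best score first."""
    # candidates sorted by score, highest first (ties keep input order)
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)  # steps 1 and 4: best remaining box
        keep.append(best)
        # steps 2-3: drop every remaining box that overlaps it too much
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep


boxes = [(190, 380, 490, 530), (300, 420, 450, 630), (320, 360, 520, 590),
         (390, 50, 690, 380), (490, 45, 690, 545), (480, 130, 630, 530)]
scores = [0.90, 0.98, 0.82, 0.87, 0.98, 0.82]
print(greedy_nms(boxes, scores, iou_threshold=0.2))  # [1, 4]
```

Each pass is linear in the number of remaining boxes, so the worst case is quadratic in the number of boxes. Library implementations such as torchvision's nms() follow the same greedy idea but do the work in compiled code.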

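For this image, we are going to use the non-max suppression function nms() from the torchvision library. This function requires three parameters:

  1. boxes: the bounding box coordinates in (x1, y1, x2, y2) format
  2. scores: the objectness score of each bounding box
  3. iou_threshold: the overlap (IoU) threshold above which a box is suppressed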

Since the box coordinates for this image are in (x1, y1, width, height) format, while nms() expects (x1, y1, x2, y2), we compute x2 = x1 + width and y2 = y1 + height, as done inline in the code below.
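
As a small illustration of that conversion, here is a sketch with a hypothetical helper `xywh_to_xyxy` (not a torchvision function) that turns one (x1, y1, width, height) box into corner format:

```python
def xywh_to_xyxy(box):
    """Convert (x1, y1, width, height) to (x1, y1, x2, y2)."""
    x1, y1, w, h = box
    return (x1, y1, x1 + w, y1 + h)


print(xywh_to_xyxy((190, 380, 300, 150)))  # (190, 380, 490, 530)
```

For tensors, torchvision.ops also provides box_convert, which can perform the same conversion with in_fmt='xywh' and out_fmt='xyxy'.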

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
from torchvision.ops import nms

# boxes converted from (x1, y1, width, height) to (x1, y1, x2, y2)
boxes = torch.tensor([[190,380,(190+300),(380+150)],
                      [300,420,(300+150),(420+210)],
                      [320,360,(320+200),(360+230)],

                      [390,50,(390+300),(50+330)],
                      [490,45,(490+200),(45+500)],
                      [480,130,(480+150),(130+400)]], dtype=torch.float32)

# nms() expects a 1-D tensor with one score per box
scores = torch.tensor([0.90, 0.98, 0.82, 0.87, 0.98, 0.82], dtype=torch.float32)

# indices of the boxes that survive suppression: boxes 1 and 4 for this data
keep = nms(boxes=boxes, scores=scores, iou_threshold=0.2)
print(keep)

image = plt.imread('index.jpg')

# draw empty figure
fig = plt.figure()

# define an axis that fills the whole figure
ax = fig.add_axes([0, 0, 1, 1])

# plot image
ax.imshow(image)

# create rectangular patches for the two kept boxes: (x1, y1), width, height
rect_1 = patches.Rectangle((300, 420), 150, 210, edgecolor='green', facecolor='none', linewidth=2)
rect_4 = patches.Rectangle((490, 45), 200, 500, edgecolor='green', facecolor='none', linewidth=2)

# add patches and their score labels
ax.add_patch(rect_1)
ax.text(328, 416, 'Box1:98%', color='green')

ax.add_patch(rect_4)
ax.text(492, 33, 'Box4:98%', color='green')

# show figure
plt.show()