
Commit 7af54d6

Add the assignments for lecture 6.
1 parent a34a2eb commit 7af54d6

File tree

5 files changed, +331 -0 lines changed


labs/06/bboxes_utils.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
#!/usr/bin/env python3
import argparse
from math import log
from typing import Callable
import unittest

import torch

# Bounding boxes and anchors are expected to be PyTorch tensors,
# where the last dimension has size 4.

# For bounding boxes in pixel coordinates, the 4 values correspond to:
TOP: int = 0
LEFT: int = 1
BOTTOM: int = 2
RIGHT: int = 3


def bboxes_area(bboxes: torch.Tensor) -> torch.Tensor:
    """Compute area of given set of bboxes.

    Each bbox is parametrized as a four-tuple (top, left, bottom, right).

    If the bboxes.shape is [..., 4], the output shape is bboxes.shape[:-1].
    """
    return torch.relu(bboxes[..., BOTTOM] - bboxes[..., TOP]) \
        * torch.relu(bboxes[..., RIGHT] - bboxes[..., LEFT])


def bboxes_iou(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    """Compute IoU of corresponding pairs from two sets of bboxes `xs` and `ys`.

    Each bbox is parametrized as a four-tuple (top, left, bottom, right).

    Note that broadcasting is supported, so passing inputs with
    `xs.shape=[num_xs, 1, 4]` and `ys.shape=[1, num_ys, 4]` produces an output with
    shape `[num_xs, num_ys]`, computing IoU for all pairs of bboxes from `xs` and `ys`.
    Formally, the output shape is `torch.broadcast_shapes(xs.shape, ys.shape)[:-1]`.
    """
    intersections = torch.stack([
        torch.maximum(xs[..., TOP], ys[..., TOP]),
        torch.maximum(xs[..., LEFT], ys[..., LEFT]),
        torch.minimum(xs[..., BOTTOM], ys[..., BOTTOM]),
        torch.minimum(xs[..., RIGHT], ys[..., RIGHT]),
    ], dim=-1)

    xs_area, ys_area, intersections_area = bboxes_area(xs), bboxes_area(ys), bboxes_area(intersections)

    return intersections_area / (xs_area + ys_area - intersections_area)


def bboxes_to_rcnn(anchors: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
    """Convert `bboxes` to an R-CNN-like representation relative to `anchors`.

    The `anchors` and `bboxes` are arrays of four-tuples (top, left, bottom, right);
    you can use the TOP, LEFT, BOTTOM, RIGHT constants as indices of the
    respective coordinates.

    The resulting representation of a single bbox is a four-tuple with:
    - (bbox_y_center - anchor_y_center) / anchor_height
    - (bbox_x_center - anchor_x_center) / anchor_width
    - log(bbox_height / anchor_height)
    - log(bbox_width / anchor_width)

    If the `anchors.shape` is `[anchors_len, 4]` and `bboxes.shape` is `[anchors_len, 4]`,
    the output shape is `[anchors_len, 4]`.
    """
    # TODO: Implement according to the docstring.
    raise NotImplementedError()


def bboxes_from_rcnn(anchors: torch.Tensor, rcnns: torch.Tensor) -> torch.Tensor:
    """Convert R-CNN-like representations relative to `anchors` back to bboxes.

    If the `anchors.shape` is `[anchors_len, 4]` and `rcnns.shape` is `[anchors_len, 4]`,
    the output shape is `[anchors_len, 4]`.
    """
    # TODO: Implement according to the docstring.
    raise NotImplementedError()


def bboxes_training(
    anchors: torch.Tensor, gold_classes: torch.Tensor, gold_bboxes: torch.Tensor, iou_threshold: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute training data for object detection.

    Arguments:
    - `anchors` is an array of four-tuples (top, left, bottom, right)
    - `gold_classes` is an array of zero-based classes of the gold objects
    - `gold_bboxes` is an array of four-tuples (top, left, bottom, right)
      of the gold objects
    - `iou_threshold` is a given threshold

    Returns:
    - `anchor_classes` contains for every anchor either 0 for background
      (if no gold object is assigned) or `1 + gold_class` if a gold object
      with `gold_class` is assigned to it
    - `anchor_bboxes` contains for every anchor a four-tuple
      `(center_y, center_x, height, width)` representing the gold bbox of
      a chosen object using parametrization of R-CNN; zeros if no gold object
      was assigned to the anchor
    If the `anchors` shape is `[anchors_len, 4]`, the `anchor_classes` shape
    is `[anchors_len]` and the `anchor_bboxes` shape is `[anchors_len, 4]`.

    Algorithm:
    - First, for each gold object, assign it to an anchor with the largest IoU
      (the anchor with smaller index if there are several). In case several gold
      objects are assigned to a single anchor, use the gold object with smaller
      index.
    - For each unused anchor, find the gold object with the largest IoU
      (again the gold object with smaller index if there are several), and if
      the IoU is >= iou_threshold, assign the object to the anchor.
    """
    # TODO: First, for each gold object, assign it to an anchor with the
    # largest IoU (the anchor with smaller index if there are several). In case
    # several gold objects are assigned to a single anchor, use the gold object
    # with smaller index.

    # TODO: For each unused anchor, find the gold object with the largest IoU
    # (again the gold object with smaller index if there are several), and if
    # the IoU is >= threshold, assign the object to the anchor.

    anchor_classes, anchor_bboxes = ..., ...

    return anchor_classes, anchor_bboxes


def main(args: argparse.Namespace) -> tuple[Callable, Callable, Callable]:
    return bboxes_to_rcnn, bboxes_from_rcnn, bboxes_training


class Tests(unittest.TestCase):
    def test_bboxes_to_from_rcnn(self):
        data = [
            [[0, 0, 10, 10], [0, 0, 10, 10], [0, 0, 0, 0]],
            [[0, 0, 10, 10], [5, 0, 15, 10], [.5, 0, 0, 0]],
            [[0, 0, 10, 10], [0, 5, 10, 15], [0, .5, 0, 0]],
            [[0, 0, 10, 10], [0, 0, 20, 30], [.5, 1, log(2), log(3)]],
            [[0, 9, 10, 19], [2, 10, 5, 16], [-0.15, -0.1, -1.20397, -0.51083]],
            [[5, 3, 15, 13], [7, 7, 10, 9], [-0.15, 0, -1.20397, -1.60944]],
            [[7, 6, 17, 16], [9, 10, 12, 13], [-0.15, 0.05, -1.20397, -1.20397]],
            [[5, 6, 15, 16], [7, 7, 10, 10], [-0.15, -0.25, -1.20397, -1.20397]],
            [[6, 3, 16, 13], [8, 5, 12, 8], [-0.1, -0.15, -0.91629, -1.20397]],
            [[5, 2, 15, 12], [9, 6, 12, 8], [0.05, 0, -1.20397, -1.60944]],
            [[2, 10, 12, 20], [6, 11, 8, 17], [0, -0.1, -1.60944, -0.51083]],
            [[10, 9, 20, 19], [12, 13, 17, 16], [-0.05, 0.05, -0.69315, -1.20397]],
            [[6, 7, 16, 17], [10, 11, 12, 14], [0, 0.05, -1.60944, -1.20397]],
            [[2, 2, 12, 12], [3, 5, 8, 8], [-0.15, -0.05, -0.69315, -1.20397]],
        ]
        # First run on individual anchors, and then on all together
        for anchors, bboxes, rcnns in [map(lambda x: [x], row) for row in data] + [zip(*data)]:
            anchors, bboxes, rcnns = [torch.tensor(data, dtype=torch.float32) for data in [anchors, bboxes, rcnns]]
            torch.testing.assert_close(bboxes_to_rcnn(anchors, bboxes), rcnns, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(bboxes_from_rcnn(anchors, rcnns), bboxes, atol=1e-3, rtol=1e-3)

    def test_bboxes_training(self):
        anchors = torch.tensor([[0, 0, 10, 10], [0, 10, 10, 20], [10, 0, 20, 10], [10, 10, 20, 20]])
        for gold_classes, gold_bboxes, anchor_classes, anchor_bboxes, iou in [
            [[1], [[14, 14, 16, 16]], [0, 0, 0, 2], [[0, 0, 0, 0]] * 3 + [[0, 0, log(.2), log(.2)]], 0.5],
            [[2], [[0, 0, 20, 20]], [3, 0, 0, 0], [[.5, .5, log(2), log(2)]] + [[0, 0, 0, 0]] * 3, 0.26],
            [[2], [[0, 0, 20, 20]], [3, 3, 3, 3],
             [[y, x, log(2), log(2)] for y in [.5, -.5] for x in [.5, -.5]], 0.24],
            [[0, 1], [[3, 3, 20, 18], [10, 1, 18, 21]], [0, 0, 0, 1],
             [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [-0.35, -0.45, 0.53062, 0.40546]], 0.5],
            [[0, 1], [[3, 3, 20, 18], [10, 1, 18, 21]], [0, 0, 2, 1],
             [[0, 0, 0, 0], [0, 0, 0, 0], [-0.1, 0.6, -0.22314, 0.69314], [-0.35, -0.45, 0.53062, 0.40546]], 0.3],
            [[0, 1], [[3, 3, 20, 18], [10, 1, 18, 21]], [0, 1, 2, 1],
             [[0, 0, 0, 0], [0.65, -0.45, 0.53062, 0.40546], [-0.1, 0.6, -0.22314, 0.69314],
              [-0.35, -0.45, 0.53062, 0.40546]], 0.17],
        ]:
            gold_classes, anchor_classes = torch.tensor(gold_classes), torch.tensor(anchor_classes)
            gold_bboxes, anchor_bboxes = torch.tensor(gold_bboxes), torch.tensor(anchor_bboxes)
            computed_classes, computed_bboxes = bboxes_training(anchors, gold_classes, gold_bboxes, iou)
            torch.testing.assert_close(computed_classes, anchor_classes, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(computed_bboxes, anchor_bboxes, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
    unittest.main()
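
Both conversions follow directly from the formulas in the docstrings. Below is a minimal vectorized sketch of one possible way to fill in the two TODOs (not the reference solution; the `_sketch` suffix only avoids clashing with the template names, and TOP/LEFT/BOTTOM/RIGHT are the constants defined in the module):

```python
import torch

from bboxes_utils import TOP, LEFT, BOTTOM, RIGHT


def bboxes_to_rcnn_sketch(anchors: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
    # Heights, widths and centers of the anchors and of the bboxes.
    anchor_height = anchors[..., BOTTOM] - anchors[..., TOP]
    anchor_width = anchors[..., RIGHT] - anchors[..., LEFT]
    anchor_y = anchors[..., TOP] + anchor_height / 2
    anchor_x = anchors[..., LEFT] + anchor_width / 2
    bbox_height = bboxes[..., BOTTOM] - bboxes[..., TOP]
    bbox_width = bboxes[..., RIGHT] - bboxes[..., LEFT]
    bbox_y = bboxes[..., TOP] + bbox_height / 2
    bbox_x = bboxes[..., LEFT] + bbox_width / 2
    return torch.stack([
        (bbox_y - anchor_y) / anchor_height,
        (bbox_x - anchor_x) / anchor_width,
        torch.log(bbox_height / anchor_height),
        torch.log(bbox_width / anchor_width),
    ], dim=-1)


def bboxes_from_rcnn_sketch(anchors: torch.Tensor, rcnns: torch.Tensor) -> torch.Tensor:
    # Invert the parametrization above: recover centers and sizes, then corners.
    anchor_height = anchors[..., BOTTOM] - anchors[..., TOP]
    anchor_width = anchors[..., RIGHT] - anchors[..., LEFT]
    anchor_y = anchors[..., TOP] + anchor_height / 2
    anchor_x = anchors[..., LEFT] + anchor_width / 2
    bbox_y = rcnns[..., 0] * anchor_height + anchor_y
    bbox_x = rcnns[..., 1] * anchor_width + anchor_x
    bbox_height = torch.exp(rcnns[..., 2]) * anchor_height
    bbox_width = torch.exp(rcnns[..., 3]) * anchor_width
    return torch.stack([
        bbox_y - bbox_height / 2,
        bbox_x - bbox_width / 2,
        bbox_y + bbox_height / 2,
        bbox_x + bbox_width / 2,
    ], dim=-1)
```

For instance, the anchor `[0, 0, 10, 10]` and bbox `[0, 0, 20, 30]` map to `[.5, 1, log(2), log(3)]`, matching the corresponding row of the unit-test data above.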

labs/06/svhn_competition.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import re

import numpy as np
import timm
import torch
import torchvision.transforms.v2 as v2

import bboxes_utils
import npfl138
npfl138.require_version("2425.6")
from npfl138.datasets.svhn import SVHN

# TODO: Define reasonable defaults and optionally more parameters.
# Also, you can set the number of threads to 0 to use all your CPU cores.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", default=..., type=int, help="Batch size.")
parser.add_argument("--epochs", default=..., type=int, help="Number of epochs.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")


def main(args: argparse.Namespace) -> None:
    # Set the random seed and the number of threads.
    npfl138.startup(args.seed, args.threads)
    npfl138.global_keras_initializers()

    # Create logdir name.
    args.logdir = os.path.join("logs", "{}-{}-{}".format(
        os.path.basename(globals().get("__file__", "notebook")),
        datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
        ",".join(("{}={}".format(re.sub("(.)[^_]*_?", r"\1", k), v) for k, v in sorted(vars(args).items())))
    ))

    # Load the data. The individual examples are dictionaries with the keys:
    # - "image", a `[3, SIZE, SIZE]` tensor of `torch.uint8` values in [0-255] range,
    # - "classes", a `[num_digits]` PyTorch vector with classes of image digits,
    # - "bboxes", a `[num_digits, 4]` PyTorch vector with bounding boxes of image digits.
    # The `decode_on_demand` argument can be set to `True` to save memory and decode
    # each image only when accessed, but it will most likely slow down training.
    svhn = SVHN(decode_on_demand=False)

    # Load the EfficientNetV2-B0 model without the classification layer.
    # Apart from calling the model as in the classification task, you can call it using
    #   output, features = efficientnetv2_b0.forward_intermediates(batch_of_images)
    # obtaining (assuming the input images have 224x224 resolution):
    # - `output` is a `[N, 1280, 7, 7]` tensor with the final features before global average pooling,
    # - `features` is a list of intermediate features with resolution 112x112, 56x56, 28x28, 14x14, 7x7.
    efficientnetv2_b0 = timm.create_model("tf_efficientnetv2_b0.in1k", pretrained=True, num_classes=0)

    # Create a simple preprocessing performing necessary normalization.
    preprocessing = v2.Compose([
        v2.ToDtype(torch.float32, scale=True),  # The `scale=True` also rescales the image to [0, 1].
        v2.Normalize(mean=efficientnetv2_b0.pretrained_cfg["mean"], std=efficientnetv2_b0.pretrained_cfg["std"]),
    ])

    # TODO: Create the model and train it.
    model = ...

    # Generate test set annotations, but in `args.logdir` to allow parallel execution.
    os.makedirs(args.logdir, exist_ok=True)
    with open(os.path.join(args.logdir, "svhn_competition.txt"), "w", encoding="utf-8") as predictions_file:
        # TODO: Predict the digits and their bounding boxes on the test set.
        # Assume that for a single test image we get
        # - `predicted_classes`: a 1D array with the predicted digits,
        # - `predicted_bboxes`: a [len(predicted_classes), 4] array with bboxes;
        for predicted_classes, predicted_bboxes in ...:
            output = []
            for label, bbox in zip(predicted_classes, predicted_bboxes):
                output += [int(label)] + list(map(float, bbox))
            print(*output, file=predictions_file)


if __name__ == "__main__":
    main_args = parser.parse_args([] if "__file__" not in globals() else None)
    main(main_args)
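
For the `model = ...` TODO, one possible shape of the RetinaNet-like baseline mentioned in `tasks/svhn_competition.md` is sketched below: a single-level head with one single-scale, single-aspect anchor per feature-map cell. The class and function names, the layer choices, and the 1280-channel input (the `forward_intermediates` output documented in the template comments) are illustrative assumptions, not part of the template:

```python
import torch


class RetinaLikeHead(torch.nn.Module):
    """A single-level detection head: one anchor per spatial cell of the feature map."""
    def __init__(self, in_channels: int = 1280, num_classes: int = 10):
        super().__init__()
        self.classification = torch.nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)
        self.regression = torch.nn.Conv2d(in_channels, 4, kernel_size=3, padding=1)

    def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # [N, C, H, W] -> [N, H*W, num_classes] and [N, H*W, 4], one row per anchor.
        classes = self.classification(features).flatten(2).permute(0, 2, 1)
        bboxes = self.regression(features).flatten(2).permute(0, 2, 1)
        return classes, bboxes


def make_anchors(image_size: int, grid_size: int, anchor_size: float) -> torch.Tensor:
    """One square anchor per grid cell, as (TOP, LEFT, BOTTOM, RIGHT) in pixels."""
    stride = image_size / grid_size
    centers = (torch.arange(grid_size) + 0.5) * stride
    ys, xs = torch.meshgrid(centers, centers, indexing="ij")
    half = anchor_size / 2
    return torch.stack([ys - half, xs - half, ys + half, xs + half], dim=-1).reshape(-1, 4)
```

Per-image training targets for such a head can be built with `bboxes_utils.bboxes_training(anchors, example["classes"], example["bboxes"], iou_threshold)`, and the classification branch can be trained with `torchvision.ops.sigmoid_focal_loss`, as suggested in the task description.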

lectures/lecture06.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 #### Video: https://lectures.ms.mff.cuni.cz/video/rec/npfl138/2425/npfl138-2425-06-czech.mp4, CZ Lecture
 #### Video: https://lectures.ms.mff.cuni.cz/video/rec/npfl138/2425/npfl138-2425-06-english.mp4, EN Lecture
 #### Questions: #lecture_6_questions
+#### Lecture assignment: bboxes_utils
+#### Lecture assignment: svhn_competition

 - R-CNN [[R-CNN](https://arxiv.org/abs/1311.2524)]
 - Fast R-CNN [[Fast R-CNN](https://arxiv.org/abs/1504.08083)]

tasks/bboxes_utils.md

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
### Assignment: bboxes_utils
#### Date: Deadline: Apr 09, 22:00
#### Points: 2 points

This is a preparatory assignment for `svhn_competition`. The goal is to
implement several bounding box manipulation routines in the
[bboxes_utils.py](https://github.com/ufal/npfl138/tree/master/labs/06/bboxes_utils.py)
module. Notably, you need to implement the following methods:
- `bboxes_to_rcnn`: convert given bounding boxes to an R-CNN-like
  representation relative to the given anchors;
- `bboxes_from_rcnn`: convert R-CNN-like representations relative to
  given anchors back to bounding boxes;
- `bboxes_training`: given a list of anchors and gold objects, assign gold
  objects to anchors and generate suitable training data (the exact algorithm
  is described in the template; a possible implementation sketch follows below).

The [bboxes_utils.py](https://github.com/ufal/npfl138/tree/master/labs/06/bboxes_utils.py)
module contains simple unit tests, which are evaluated when the module is
executed and which you can use to check the validity of your implementation.
Note that the type annotations in the template do not describe the tensor
shapes, because the Python typing system is not flexible enough to capture the
shape changes.

When submitting to ReCodEx, the method `main` is executed, returning the
implemented `bboxes_to_rcnn`, `bboxes_from_rcnn` and `bboxes_training`
methods. These methods are then executed and compared to the reference
implementation.
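
For the two-stage assignment algorithm of `bboxes_training` (described in detail in the template docstring), a hedged loop-based sketch is shown below. It assumes `bboxes_iou` and an already implemented `bboxes_to_rcnn` from `bboxes_utils`, and its tie-breaking relies on `torch.argmax` returning the first maximal index; a fully vectorized solution is of course also possible.

```python
import torch

from bboxes_utils import bboxes_iou, bboxes_to_rcnn


def bboxes_training_sketch(anchors, gold_classes, gold_bboxes, iou_threshold):
    anchors, gold_bboxes = anchors.float(), gold_bboxes.float()
    anchor_classes = torch.zeros(len(anchors), dtype=torch.int64)
    anchor_bboxes = torch.zeros(len(anchors), 4)

    # IoU of every anchor with every gold object, shape [num_anchors, num_golds].
    iou = bboxes_iou(anchors.unsqueeze(1), gold_bboxes.unsqueeze(0))

    # Stage 1: every gold object claims the anchor with the largest IoU. Iterating
    # from the last gold object to the first means that when several gold objects
    # claim the same anchor, the one with the smallest index wins.
    for gold in reversed(range(len(gold_bboxes))):
        anchor = int(torch.argmax(iou[:, gold]))
        anchor_classes[anchor] = 1 + gold_classes[gold]
        anchor_bboxes[anchor] = bboxes_to_rcnn(anchors[anchor], gold_bboxes[gold])

    # Stage 2: every still-unassigned anchor takes its best gold object,
    # but only if the IoU reaches the threshold.
    for anchor in range(len(anchors)):
        if anchor_classes[anchor] == 0:
            gold = int(torch.argmax(iou[anchor]))
            if iou[anchor, gold] >= iou_threshold:
                anchor_classes[anchor] = 1 + gold_classes[gold]
                anchor_bboxes[anchor] = bboxes_to_rcnn(anchors[anchor], gold_bboxes[gold])

    return anchor_classes, anchor_bboxes
```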

tasks/svhn_competition.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
### Assignment: svhn_competition
#### Date: Deadline: Apr 09, 22:00
#### Points: 5 points+5 bonus

The goal of this assignment is to implement a system performing object
recognition, optionally utilizing the pretrained EfficientNetV2-B0 backbone
(or any other model from the [timm](https://huggingface.co/docs/timm) library).

The [Street View House Numbers (SVHN) dataset](https://ufal.mff.cuni.cz/~straka/courses/npfl138/2425/demos/svhn_train.html)
annotates for every photo all digits appearing on it, including their bounding
boxes. The dataset can be loaded using the [npfl138.datasets.svhn](https://github.com/ufal/npfl138/blob/master/labs/npfl138/datasets/svhn.py)
module. Similarly to the `CAGS` dataset, the `train/dev/test` splits are PyTorch
`torch.utils.data.Dataset`s, and every element is a dictionary with the following keys:
- `"image"`: a square 3-channel image stored as a `torch.Tensor` of type `torch.uint8`,
- `"classes"`: a 1D `torch.Tensor` with all digit labels appearing in the image,
- `"bboxes"`: a `[num_digits, 4]` 2D `torch.Tensor` with bounding boxes of every
  digit in the image, each represented as `[TOP, LEFT, BOTTOM, RIGHT]`.
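
A minimal sketch of loading the data and inspecting one training element (the field names and shapes are those documented above; everything else is illustrative):

```python
import torch
import torchvision.transforms.v2 as v2
from npfl138.datasets.svhn import SVHN

svhn = SVHN(decode_on_demand=False)
example = svhn.train[0]        # a dict with "image", "classes" and "bboxes"
print(example["image"].shape)  # [3, SIZE, SIZE], dtype torch.uint8
print(example["classes"])      # 1D tensor of digit labels
print(example["bboxes"])       # [num_digits, 4] tensor of [TOP, LEFT, BOTTOM, RIGHT]

# The uint8 image is typically converted to float in [0, 1] before normalization.
image = v2.ToDtype(torch.float32, scale=True)(example["image"])
```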

Each test set image annotation consists of a sequence of space-separated
five-tuples _label top left bottom right_, and the annotation is considered
correct if exactly the gold digits are predicted, each with IoU at least 0.5.
The whole test set score is then the prediction accuracy over individual images.
You can again evaluate your predictions using the
[npfl138.datasets.svhn](https://github.com/ufal/npfl138/blob/master/labs/npfl138/datasets/svhn.py)
module, either by running `python3 -m npfl138.datasets.svhn --evaluate=path --dataset=dev/test`
or by using the `svhn.evaluate` method. Furthermore, you can visualize your
predictions using `python3 -m npfl138.datasets.svhn --visualize=path --dataset=dev/test`.

The task is a [_competition_](https://ufal.mff.cuni.cz/courses/npfl138/2425-summer#competitions).
Everyone who submits a solution achieving at least _20%_ test set accuracy gets
5 points; the remaining 5 bonus points are distributed depending on the relative ordering
of your solutions. Note that I usually need at least _35%_ development set
accuracy to achieve the required test set performance.

You should start with the
[svhn_competition.py](https://github.com/ufal/npfl138/tree/master/labs/06/svhn_competition.py)
template, which generates the test set annotation in the required format.

_A baseline solution can use a RetinaNet-like single-stage detector,
using only a single level of convolutional features (no FPN)
with single-scale and single-aspect anchors. Focal loss is available as
[torchvision.ops.sigmoid_focal_loss](https://pytorch.org/vision/main/generated/torchvision.ops.sigmoid_focal_loss.html)
and non-maximum suppression as
[torchvision.ops.nms](https://pytorch.org/vision/main/generated/torchvision.ops.nms.html) or
[torchvision.ops.batched_nms](https://pytorch.org/vision/main/generated/torchvision.ops.batched_nms.html)._
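
As a hedged sketch of how these pieces can be combined at inference time for a single image (the function, its arguments, and the thresholds are assumptions; only `bboxes_from_rcnn` and `torchvision.ops.nms` come from the module and the links above):

```python
import torch
import torchvision.ops

import bboxes_utils


def decode_predictions(class_logits: torch.Tensor, bbox_rcnns: torch.Tensor, anchors: torch.Tensor,
                       score_threshold: float = 0.5, nms_iou: float = 0.5):
    """Turn raw per-image head outputs into (predicted_classes, predicted_bboxes)."""
    # Assumes `class_logits` of shape [num_anchors, num_classes] trained with a sigmoid
    # (focal) loss, and `bbox_rcnns` of shape [num_anchors, 4] in the R-CNN parametrization.
    scores, labels = torch.sigmoid(class_logits).max(dim=-1)
    boxes = bboxes_utils.bboxes_from_rcnn(anchors, bbox_rcnns)  # back to (top, left, bottom, right)

    keep = scores > score_threshold
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

    # torchvision.ops.nms expects (x1, y1, x2, y2), but it only compares IoUs, so passing
    # (top, left, bottom, right) consistently for all boxes yields the same kept indices.
    keep = torchvision.ops.nms(boxes, scores, iou_threshold=nms_iou)
    return labels[keep], boxes[keep]
```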
