Skip to content

Commit 0badcd1

Browse files
dagarcia-nvidianvpaulius
authored andcommitted
NVidia's pytorch-based object_detection implementation (#240)
* Delete Caffe2 object_detection * Added new pytorch-based object_detection * object_detection: removed unused configs; deleted misleading code * object_detection Dockerfile now based on public image and specifies exact library versions
1 parent 5309aff commit 0badcd1

File tree

169 files changed

+13547
-197
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

169 files changed

+13547
-197
lines changed

.gitmodules

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
[submodule "object_detection/caffe2"]
2-
path = object_detection/caffe2/caffe2
3-
url = https://github.com/pytorch/pytorch.git
4-
[submodule "object_detection/detectron"]
5-
path = object_detection/caffe2/detectron
6-
url = https://github.com/ddkang/Detectron.git
71
[submodule "community"]
82
path = community
93
url = https://github.com/mlperf/community.git

object_detection/Dockerfile

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-devel
2+
3+
RUN echo 'debconf debconf/frontend select Noninteractive' | debconf-set-selections
4+
5+
# install basics
6+
RUN apt-get update -y \
7+
&& apt-get install -y apt-utils=1.2.29ubuntu0.1 \
8+
libglib2.0-0=2.48.2-0ubuntu4.1 \
9+
libsm6=2:1.2.2-1 \
10+
libxext6=2:1.3.3-1 \
11+
libxrender-dev=1:0.9.9-0ubuntu1
12+
13+
RUN pip install ninja==1.8.2.post2 \
14+
yacs==0.1.5 \
15+
cython==0.29.5 \
16+
matplotlib==3.0.2 \
17+
opencv-python==4.0.0.21 \
18+
mlperf_compliance==0.0.10 \
19+
torchvision==0.2.2
20+
21+
# install pycocotools
22+
RUN git clone https://github.com/cocodataset/cocoapi.git \
23+
&& cd cocoapi/PythonAPI \
24+
&& git reset --hard ed842bffd41f6ff38707c4f0968d2cfd91088688 \
25+
&& python setup.py build_ext install
26+
27+
# For information purposes only, these are the versions of the packages which we've successfully used:
28+
# $ pip list
29+
# Package Version Location
30+
# -------------------- ----------------- -------------------------------------------------
31+
# backcall 0.1.0
32+
# certifi 2018.11.29
33+
# cffi 1.11.5
34+
# cycler 0.10.0
35+
# Cython 0.29.5
36+
# decorator 4.3.2
37+
# fairseq 0.6.0 /scratch/fairseq
38+
# ipython 7.2.0
39+
# ipython-genutils 0.2.0
40+
# jedi 0.13.2
41+
# kiwisolver 1.0.1
42+
# maskrcnn-benchmark 0.1 /scratch/mlperf/training/object_detection/pytorch
43+
# matplotlib 3.0.2
44+
# mkl-fft 1.0.10
45+
# mkl-random 1.0.2
46+
# mlperf-compliance 0.0.10
47+
# ninja 1.8.2.post2
48+
# numpy 1.16.1
49+
# opencv-python 4.0.0.21
50+
# parso 0.3.2
51+
# pexpect 4.6.0
52+
# pickleshare 0.7.5
53+
# Pillow 5.4.1
54+
# pip 19.0.1
55+
# prompt-toolkit 2.0.8
56+
# ptyprocess 0.6.0
57+
# pycocotools 2.0
58+
# pycparser 2.19
59+
# Pygments 2.3.1
60+
# pyparsing 2.3.1
61+
# python-dateutil 2.8.0
62+
# pytorch-quantization 0.2.1
63+
# PyYAML 3.13
64+
# setuptools 40.8.0
65+
# six 1.12.0
66+
# torch 1.0.0.dev20190225
67+
# torchvision 0.2.1
68+
# tqdm 4.31.1
69+
# traitlets 4.3.2
70+
# wcwidth 0.1.7
71+
# wheel 0.32.3
72+
# yacs 0.1.5
Lines changed: 95 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,95 @@
1-
# 1. Problem
2-
Object detection and segmentation. Metrics are mask and box mAP.
3-
4-
# 2. Directions
5-
### Steps to configure machine
6-
Standard script.
7-
8-
### Steps to download and verify data
9-
init and update the submodules in this directory.
10-
11-
Run the provided shell scripts *in this directory*.
12-
13-
### Steps to run and time
14-
Build the docker container.
15-
16-
```
17-
sudo docker build -t detectron .
18-
```
19-
20-
Run the docker container and mount the data appropriately
21-
22-
```
23-
sudo nvidia-docker run
24-
-v /mnt/disks/data/coco/:/packages/detectron/lib/datasets/data/coco
25-
-it detectron /bin/bash
26-
```
27-
28-
(replace /mnt/disks/data/coco/ with the data directory)
29-
30-
Run the command:
31-
```
32-
time stdbuf -o 0 \
33-
python tools/train_net.py --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml \
34-
--box_min_ap 0.377 --mask_min_ap 0.339 \
35-
--seed 3 | tee run.log
36-
```
37-
38-
# 3. Dataset/Environment
39-
### Publication/Attribution
40-
Microsoft COCO: Common Objects in Context
41-
42-
### Data preprocessing
43-
Only horizontal flips are allowed.
44-
45-
### Training and test data separation
46-
As provided by MS-COCO (2017 version).
47-
48-
### Training data order
49-
Randomly.
50-
51-
### Test data order
52-
Any order.
53-
54-
# 4. Model
55-
### Publication/Attribution
56-
He, Kaiming, et al. "Mask r-cnn." Computer Vision (ICCV), 2017 IEEE International Conference on.
57-
IEEE, 2017.
58-
59-
We use a version of Mask R-CNN with a ResNet50 backbone.
60-
61-
### List of layers
62-
Running the timing script will display a list of layers.
63-
64-
### Weight and bias initialization
65-
The ResNet50 base must be loaded from the provided weights. They may be quantized.
66-
67-
### Loss function
68-
Multi-task loss (classification, box, mask). Described in the Mask R-CNN paper.
69-
70-
Classification: Smooth L1 loss
71-
72-
Box: Log loss for true class.
73-
74-
Mask: per-pixel sigmoid, average binary cross-entropy loss.
75-
76-
### Optimizer
77-
Momentum SGD. Weight decay of 0.0001, momentum of 0.9.
78-
79-
# 5. Quality
80-
### Quality metric
81-
As Mask R-CNN can provide both boxes and masks, we evaluate on both box and mask mAP.
82-
83-
### Quality target
84-
Box mAP of 0.377, mask mAP of 0.339
85-
86-
### Evaluation frequency
87-
Once per epoch, 118k.
88-
89-
### Evaluation thoroughness
90-
Evaluate over the entire validation set. Use the official COCO API to compute mAP.
1+
# 1. Problem
2+
Object detection and segmentation. Metrics are mask and box mAP.
3+
4+
# 2. Directions
5+
6+
### Steps to configure machine
7+
8+
1. Checkout the MLPerf repository
9+
```
10+
mkdir -p mlperf
11+
cd mlperf
12+
git clone https://github.com/mlperf/training.git
13+
```
14+
2. Install CUDA and Docker
15+
```
16+
source training/install_cuda_docker.sh
17+
```
18+
3. Build the docker image for the object detection task
19+
```
20+
cd training/object_detection/
21+
nvidia-docker build . -t mlperf/object_detection
22+
```
23+
24+
4. Run docker container and install code
25+
```
26+
nvidia-docker run -v .:/workspace -t -i --rm --ipc=host mlperf/object_detection \
27+
"cd mlperf/training/object_detection && ./install.sh"
28+
```
29+
Now exit the docker container (Ctrl-D) to get back to your host.
30+
31+
### Steps to download data
32+
```
33+
# From training/object_detection/
34+
source download_dataset.sh
35+
```
36+
37+
### Steps to run benchmark.
38+
```
39+
nvidia-docker run -v .:/workspace -t -i --rm --ipc=host mlperf/object_detection \
40+
"cd mlperf/training/object_detection && ./run_and_time.sh"
41+
```
42+
43+
# 3. Dataset/Environment
44+
### Publication/Attribution
45+
Microsoft COCO: Common Objects in Context
46+
47+
### Data preprocessing
48+
Only horizontal flips are allowed.
49+
50+
### Training and test data separation
51+
As provided by MS-COCO (2017 version).
52+
53+
### Training data order
54+
Randomly.
55+
56+
### Test data order
57+
Any order.
58+
59+
# 4. Model
60+
### Publication/Attribution
61+
He, Kaiming, et al. "Mask r-cnn." Computer Vision (ICCV), 2017 IEEE International Conference on.
62+
IEEE, 2017.
63+
64+
We use a version of Mask R-CNN with a ResNet50 backbone.
65+
66+
### List of layers
67+
Running the timing script will display a list of layers.
68+
69+
### Weight and bias initialization
70+
The ResNet50 base must be loaded from the provided weights. They may be quantized.
71+
72+
### Loss function
73+
Multi-task loss (classification, box, mask). Described in the Mask R-CNN paper.
74+
75+
Classification: Smooth L1 loss
76+
77+
Box: Log loss for true class.
78+
79+
Mask: per-pixel sigmoid, average binary cross-entropy loss.
80+
81+
### Optimizer
82+
Momentum SGD. Weight decay of 0.0001, momentum of 0.9.
83+
84+
# 5. Quality
85+
### Quality metric
86+
As Mask R-CNN can provide both boxes and masks, we evaluate on both box and mask mAP.
87+
88+
### Quality target
89+
Box mAP of 0.377, mask mAP of 0.339
90+
91+
### Evaluation frequency
92+
Once per epoch, 118k.
93+
94+
### Evaluation thoroughness
95+
Evaluate over the entire validation set. Use the official COCO API to compute mAP.

object_detection/caffe2/Dockerfile

Lines changed: 0 additions & 63 deletions
This file was deleted.

object_detection/caffe2/caffe2

Lines changed: 0 additions & 1 deletion
This file was deleted.

object_detection/caffe2/detectron

Lines changed: 0 additions & 1 deletion
This file was deleted.

object_detection/caffe2/extract_dataset.sh

Lines changed: 0 additions & 15 deletions
This file was deleted.

object_detection/caffe2/run_and_time.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.

object_detection/download_dataset.sh

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
1-
wget https://s3-us-west-2.amazonaws.com/detectron/coco/coco_annotations_minival.tgz
2-
wget http://images.cocodataset.org/zips/train2014.zip
3-
wget http://images.cocodataset.org/zips/val2014.zip
4-
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
1+
#!/bin/bash
2+
3+
# Get COCO 2014 data sets
4+
mkdir -p pytorch/datasets/coco
5+
pushd pytorch/datasets/coco
6+
7+
curl -O https://dl.fbaipublicfiles.com/detectron/coco/coco_annotations_minival.tgz
8+
tar xzf coco_annotations_minival.tgz
9+
10+
curl -O http://images.cocodataset.org/zips/train2014.zip
11+
unzip train2014.zip
12+
13+
curl -O http://images.cocodataset.org/zips/val2014.zip
14+
unzip val2014.zip
15+
16+
curl -O http://images.cocodataset.org/annotations/annotations_trainval2014.zip
17+
unzip annotations_trainval2014.zip
18+
19+
# TBD: MD5 verification
20+
# $md5sum *.zip *.tgz
21+
#f4bbac642086de4f52a3fdda2de5fa2c annotations_trainval2017.zip
22+
#cced6f7f71b7629ddf16f17bbcfab6b2 train2017.zip
23+
#442b8da7639aecaf257c1dceb8ba8c80 val2017.zip
24+
#2d2b9d2283adb5e3b8d25eec88e65064 coco_annotations_minival.tgz
25+
26+
popd

object_detection/download_weights.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

object_detection/hashes.md5

Lines changed: 0 additions & 4 deletions
This file was deleted.

0 commit comments

Comments
 (0)