File tree 4 files changed +44
-6
lines changed
4 files changed +44
-6
lines changed Original file line number Diff line number Diff line change 26
26
27
27
import datasets
28
28
29
- from IPython import embed
30
-
31
29
32
30
class Trainer :
33
31
def __init__ (self , options ):
Original file line number Diff line number Diff line change @@ -35,11 +35,16 @@ def __init__(self, opts):
35
35
self .use_sparse = False
36
36
37
37
if opts .use_wavelets :
38
- if opts .use_sparse :
39
- self .use_sparse = True
40
- if opts .use_224 :
41
- raise NotImplementedError
38
+ try :
39
+ if opts .use_sparse :
40
+ self .use_sparse = True
41
+ if opts .use_224 :
42
+ raise NotImplementedError
43
+ except AttributeError :
44
+ opts .use_sparse = False
45
+ self .use_sparse = False
42
46
47
+ if opts .use_sparse :
43
48
self .decoder = SparseDecoderWave (enc_features = self .encoder .num_ch_enc , decoder_width = decoder_width )
44
49
else :
45
50
if opts .use_224 :
Original file line number Diff line number Diff line change @@ -38,6 +38,24 @@ are then replaced with **sparse** ones.
38
38
39
39
This is because the network first needs to learn to predict sparse wavelet coefficients before we can use sparse convolutions.
40
40
41
+ ## 🗂 Environment Requirements 🗂 ##
42
+
43
+ We recommend creating a new Anaconda environment to use WaveletMonoDepth. Use the following to setup a new environment:
44
+
45
+ ```
46
+ conda env create -f environment.yml
47
+ conda activate wavelet-mdp
48
+ ```
49
+
50
+ Our work uses [ Pytorch Wavelets] ( https://github.com/fbcotter/pytorch_wavelets ) , a great package from Fergal Cotter
51
+ which implements the Inverse Discrete Wavelet Transform (IDWT) used in our work, and a lot more!
52
+ To install Pytorch Wavelets, simply run:
53
+ ```
54
+ git clone https://github.com/fbcotter/pytorch_wavelets
55
+ cd pytorch_wavelets
56
+ pip install .
57
+ ```
58
+
41
59
## 🚗🚦 KITTI 🌳🛣
42
60
[ Depth Hints] ( https://github.com/nianticlabs/depth-hints ) was used as a baseline for KITTI.
43
61
Original file line number Diff line number Diff line change
1
+ name : wavelet-mdp
2
+ channels :
3
+ - default
4
+ - pytorch
5
+ - conda-forge
6
+ dependencies :
7
+ - opencv=3.4.2
8
+ - matplotlib=3.1.2
9
+ - numpy=1.19.5
10
+ - scikit-learn=0.24.2
11
+ - pip
12
+ - pip :
13
+ - Pillow==6.2.1
14
+ - tensorboardX==1.5
15
+ - scikit-image==0.16.2
16
+ - torch==1.7.1
17
+ - torchvision==0.8.2
You can’t perform that action at this time.
0 commit comments