4
4
from torch import distributions as pyd
5
5
from torch import nn
6
6
7
- from . import utils
7
+ from . import sac_utils
8
8
from torchbenchmark .util .distribution import SquashedNormal
9
9
10
10
def weight_init (m ):
@@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
30
30
self .conv3 = nn .Conv2d (32 , 32 , kernel_size = 3 , stride = 1 )
31
31
self .conv4 = nn .Conv2d (32 , 32 , kernel_size = 3 , stride = 1 )
32
32
33
- output_height , output_width = utils .compute_conv_output (
33
+ output_height , output_width = sac_utils .compute_conv_output (
34
34
obs_shape [1 :], kernel_size = (3 , 3 ), stride = (2 , 2 )
35
35
)
36
36
for _ in range (3 ):
37
- output_height , output_width = utils .compute_conv_output (
37
+ output_height , output_width = sac_utils .compute_conv_output (
38
38
(output_height , output_width ), kernel_size = (3 , 3 ), stride = (1 , 1 )
39
39
)
40
40
@@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
63
63
self .conv2 = nn .Conv2d (32 , 64 , kernel_size = 4 , stride = 2 )
64
64
self .conv3 = nn .Conv2d (64 , 64 , kernel_size = 3 , stride = 1 )
65
65
66
- output_height , output_width = utils .compute_conv_output (
66
+ output_height , output_width = sac_utils .compute_conv_output (
67
67
obs_shape [1 :], kernel_size = (8 , 8 ), stride = (4 , 4 )
68
68
)
69
69
70
- output_height , output_width = utils .compute_conv_output (
70
+ output_height , output_width = sac_utils .compute_conv_output (
71
71
(output_height , output_width ), kernel_size = (4 , 4 ), stride = (2 , 2 )
72
72
)
73
73
74
- output_height , output_width = utils .compute_conv_output (
74
+ output_height , output_width = sac_utils .compute_conv_output (
75
75
(output_height , output_width ), kernel_size = (3 , 3 ), stride = (1 , 1 )
76
76
)
77
77
0 commit comments