Skip to content

Commit d910b8a

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Fix the soft_actor_critic model (#2326)
Summary: Unfortunately, #2318 has a bug that breaks the `soft_actor_critic` model. Pull Request resolved: #2326 Reviewed By: aaronenyeshi Differential Revision: D58871386 Pulled By: xuzhao9 fbshipit-source-id: 5f8b5fbe00722ccb647b08a8089fd52a7719208c
1 parent 1425f68 commit d910b8a

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

torchbenchmark/models/soft_actor_critic/nets.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import distributions as pyd
55
from torch import nn
66

7-
from . import utils
7+
from . import sac_utils
88
from torchbenchmark.util.distribution import SquashedNormal
99

1010
def weight_init(m):
@@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
3030
self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
3131
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
3232

33-
output_height, output_width = utils.compute_conv_output(
33+
output_height, output_width = sac_utils.compute_conv_output(
3434
obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
3535
)
3636
for _ in range(3):
37-
output_height, output_width = utils.compute_conv_output(
37+
output_height, output_width = sac_utils.compute_conv_output(
3838
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
3939
)
4040

@@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
6363
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
6464
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
6565

66-
output_height, output_width = utils.compute_conv_output(
66+
output_height, output_width = sac_utils.compute_conv_output(
6767
obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
6868
)
6969

70-
output_height, output_width = utils.compute_conv_output(
70+
output_height, output_width = sac_utils.compute_conv_output(
7171
(output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
7272
)
7373

74-
output_height, output_width = utils.compute_conv_output(
74+
output_height, output_width = sac_utils.compute_conv_output(
7575
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
7676
)
7777

torchbenchmark/models/soft_actor_critic/sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch
55

6-
from . import envs, nets, replay, utils
6+
from . import envs, nets, replay, sac_utils
77

88

99
class SACAgent:

0 commit comments

Comments
 (0)