Skip to content

Commit 46df033

Browse files
committed
Fix sac
1 parent 7c98b1c commit 46df033

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

torchbenchmark/models/soft_actor_critic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .envs import load_gym
1313
from .sac import SACAgent
1414
from .replay import PrioritizedReplayBuffer, ReplayBuffer
15-
from .utils import hard_update, soft_update
15+
from .sac_utils import hard_update, soft_update
1616

1717

1818
def learn_standard(

torchbenchmark/models/soft_actor_critic/nets.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import distributions as pyd
55
from torch import nn
66

7-
from . import utils
7+
from . import sac_utils
88
from torchbenchmark.util.distribution import SquashedNormal
99

1010
def weight_init(m):
@@ -30,11 +30,11 @@ def __init__(self, obs_shape, out_dim=50):
3030
self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
3131
self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
3232

33-
output_height, output_width = utils.compute_conv_output(
33+
output_height, output_width = sac_utils.compute_conv_output(
3434
obs_shape[1:], kernel_size=(3, 3), stride=(2, 2)
3535
)
3636
for _ in range(3):
37-
output_height, output_width = utils.compute_conv_output(
37+
output_height, output_width = sac_utils.compute_conv_output(
3838
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
3939
)
4040

@@ -63,15 +63,15 @@ def __init__(self, obs_shape, out_dim=50):
6363
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
6464
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
6565

66-
output_height, output_width = utils.compute_conv_output(
66+
output_height, output_width = sac_utils.compute_conv_output(
6767
obs_shape[1:], kernel_size=(8, 8), stride=(4, 4)
6868
)
6969

70-
output_height, output_width = utils.compute_conv_output(
70+
output_height, output_width = sac_utils.compute_conv_output(
7171
(output_height, output_width), kernel_size=(4, 4), stride=(2, 2)
7272
)
7373

74-
output_height, output_width = utils.compute_conv_output(
74+
output_height, output_width = sac_utils.compute_conv_output(
7575
(output_height, output_width), kernel_size=(3, 3), stride=(1, 1)
7676
)
7777

torchbenchmark/models/soft_actor_critic/sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch
55

6-
from . import envs, nets, replay, utils
6+
from . import envs, nets, replay, sac_utils
77

88

99
class SACAgent:

0 commit comments

Comments
 (0)