Skip to content

Commit b7c1acf

Browse files
committed
Update to PyTorch 0.4.0
1 parent 2ccc48d commit b7c1acf

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

MADDPG.py

+6 -4
Original file line number | Diff line number | Diff line change
@@ -103,14 +103,16 @@ def update_policy(self):
103103

104104
target_Q = th.zeros(
105105
self.batch_size).type(FloatTensor)
106+
106107
target_Q[non_final_mask] = self.critics_target[agent](
107108
non_final_next_states.view(-1, self.n_agents * self.n_states),
108109
non_final_next_actions.view(-1,
109-
self.n_agents * self.n_actions))
110-
110+
self.n_agents * self.n_actions)
111+
).squeeze()
111112
# scale_reward: to scale reward in Q functions
112-
target_Q = (target_Q * self.GAMMA) + (
113-
reward_batch[:, agent] * scale_reward)
113+
114+
target_Q = (target_Q.unsqueeze(1) * self.GAMMA) + (
115+
reward_batch[:, agent].unsqueeze(1) * scale_reward)
114116

115117
loss_Q = nn.MSELoss()(current_Q, target_Q.detach())
116118
loss_Q.backward()

0 commit comments

Comments
 (0)