commit b7c1acf · 1 parent 2ccc48d
MADDPG.py
@@ -103,14 +103,16 @@ def update_policy(self):
             target_Q = th.zeros(
                 self.batch_size).type(FloatTensor)
+
             target_Q[non_final_mask] = self.critics_target[agent](
                 non_final_next_states.view(-1, self.n_agents * self.n_states),
                 non_final_next_actions.view(-1,
-                                            self.n_agents * self.n_actions))
-
+                                            self.n_agents * self.n_actions)
+            ).squeeze()
             # scale_reward: to scale reward in Q functions
-            target_Q = (target_Q * self.GAMMA) + (
-                reward_batch[:, agent] * scale_reward)
+            target_Q = (target_Q.unsqueeze(1) * self.GAMMA) + (
+                reward_batch[:, agent].unsqueeze(1) * scale_reward)
 
             loss_Q = nn.MSELoss()(current_Q, target_Q.detach())
             loss_Q.backward()
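A hedged reading of the change, since the commit carries no message: a critic head typically outputs shape (batch, 1), so .squeeze() flattens it to 1-D before the masked assignment into the 1-D target_Q buffer, and .unsqueeze(1) then lifts target_Q and the reward column back to (batch, 1) so both sides of nn.MSELoss match current_Q's shape instead of silently broadcasting. A minimal standalone sketch of that shape fix; batch_size, GAMMA, scale_reward, the mask, and the random tensors are hypothetical stand-ins for the values MADDPG.py computes from its replay buffer:

import torch as th
import torch.nn as nn

# Hypothetical stand-in values; the real code derives these elsewhere.
batch_size = 5
GAMMA, scale_reward = 0.95, 0.01
FloatTensor = th.FloatTensor

# 1-D buffer of bootstrapped targets, one entry per sampled transition.
target_Q = th.zeros(batch_size).type(FloatTensor)

# The target critic is evaluated only on non-final next states, and its
# output has shape (n_non_final, 1); .squeeze() makes it 1-D so it fits
# the masked slots of the 1-D target_Q buffer.
non_final_mask = th.tensor([True, True, False, True, True])
n_non_final = int(non_final_mask.sum())
critic_out = th.randn(n_non_final, 1)     # stand-in for critics_target[agent](...)
target_Q[non_final_mask] = critic_out.squeeze()

# current_Q from the trained critic also has shape (batch, 1), so
# .unsqueeze(1) restores that shape on the target side; without it,
# MSELoss would broadcast (batch,) against (batch, 1).
current_Q = th.randn(batch_size, 1)       # stand-in for the critic's current estimate
reward = th.randn(batch_size)             # stand-in for reward_batch[:, agent]
target = (target_Q.unsqueeze(1) * GAMMA) + (reward.unsqueeze(1) * scale_reward)

loss_Q = nn.MSELoss()(current_Q, target.detach())
print(current_Q.shape, target.shape)      # both torch.Size([5, 1])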