
Commit 7079db0

committed
+ pg pytorch implement
1 parent 73124c5 commit 7079db0

File tree

10 files changed: +185 -23 lines changed
1.73 KB
Binary file not shown.

PG/Cartpole_pytorch/PG_CartPole.py

+173
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import gym
from gym.spaces import Discrete, Box
import argparse
import random

seed = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    DEBUG = False
else:
    DEBUG = True

def weight_init(m):
    '''
    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
    Usage:
        model = Model()
        model.apply(weight_init)
    '''
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight.data)
        nn.init.normal_(m.bias.data)

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

class MLP(nn.Module):
    def __init__(self, sizes, activation=nn.Tanh, output_activation=None):
        super().__init__()

        net = []
        for i in range(len(sizes)-1):
            net.append(nn.Linear(sizes[i], sizes[i+1]))
            if i == len(sizes) - 2:
                if output_activation is not None:
                    net.append(output_activation())
            else:
                net.append(activation())

        self.mlp = nn.Sequential(
            *net,
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.mlp(x)

def train(env_name='CartPole-v0', hidden_sizes=[32], lr=1e-2,
          epochs=50, batch_size=5000, render=False):

    # make environment, check spaces, get obs / act dims
    env = gym.make(env_name)
    assert isinstance(env.observation_space, Box), \
        "This example only works for envs with continuous state spaces."
    assert isinstance(env.action_space, Discrete), \
        "This example only works for envs with discrete action spaces."

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    policy = MLP(sizes=[obs_dim]+hidden_sizes+[n_acts])
    policy.apply(weight_init)
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    # for training policy
    def train_one_epoch():
        # make some empty lists for logging.
        batch_obs = []          # for observations
        batch_acts = []         # for actions
        batch_weights = []      # for reward-to-go weighting in policy gradient
        batch_rets = []         # for measuring episode returns
        batch_lens = []         # for measuring episode lengths

        # reset episode-specific variables
        obs = env.reset()       # first obs comes from starting distribution
        done = False            # signal from environment that episode is over
        ep_rews = []            # list for rewards accrued throughout ep

        # render first episode of each epoch
        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with current policy
        policy.eval()
        while True:
            # rendering
            if (not finished_rendering_this_epoch) and render:
                env.render()

            # save obs
            batch_obs.append(obs.copy())

            # act in the environment
            with torch.no_grad():
                act_probs = policy(torch.tensor(obs, dtype=torch.float))
                dist = Categorical(act_probs)
                act = dist.sample().item()

            obs, rew, done, _ = env.step(act)

            # save action, reward
            batch_acts.append(act)
            ep_rews.append(rew)

            if done:
                # if episode is over, record info about episode
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # the weight for each logprob(a_t|s_t) is reward-to-go from t
                batch_weights += list(reward_to_go(ep_rews))

                # reset episode-specific variables
                obs, done, ep_rews = env.reset(), False, []

                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end experience loop if we have enough of it
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        policy.train()
        batch_obs = torch.tensor(batch_obs, dtype=torch.float)
        batch_acts = torch.tensor(batch_acts)
        batch_weights = torch.tensor(batch_weights)

        batch_act_probs = policy(batch_obs)
        dist = Categorical(batch_act_probs)
        log_probs = dist.log_prob(batch_acts)
        loss = (- log_probs * batch_weights).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss, batch_rets, batch_lens

    # training loop
    max_avg_ret = 0
    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print(f'epoch: {i:2d} loss: {batch_loss:.3f} episode average rewards: {np.mean(batch_rets):.3f} episode average len: {np.mean(batch_lens):.3f}')

        if np.mean(batch_rets) > max_avg_ret:
            max_avg_ret = np.mean(batch_rets)
            torch.save(policy.state_dict(), 'PG_{}.pth'.format(env_name))

    env.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', '--env', type=str, default='CartPole-v0')
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--epochs', type=int, default=50)
    args = parser.parse_args()
    print('\nUsing reward-to-go formulation of policy gradient.\n')
    train(env_name=args.env_name, render=args.render, lr=args.lr, epochs=args.epochs)
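A quick way to try the new script, with the flags taken from the argparse block above (assuming it is run from the repository root with gym and PyTorch installed): `python PG/Cartpole_pytorch/PG_CartPole.py --env CartPole-v0 --epochs 50`. The `reward_to_go` helper computes, for each timestep, the sum of rewards from that step to the end of the episode; a minimal self-contained check of the same logic:

```python
# Same reward-to-go logic as in PG_CartPole.py above, shown on two tiny episodes.
import numpy as np

def reward_to_go(rews):
    rtgs = np.zeros_like(rews)
    for i in reversed(range(len(rews))):
        rtgs[i] = rews[i] + (rtgs[i + 1] if i + 1 < len(rews) else 0)
    return rtgs

print(reward_to_go([1.0, 1.0, 1.0]))  # [3. 2. 1.]  -- CartPole gives +1 per step
print(reward_to_go([0.0, 1.0, 2.0]))  # [3. 3. 2.]
```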
File renamed without changes.

README.md

+12 -23
@@ -33,6 +33,7 @@ Table of contents
* [Gomoku](#gomoku)
* [AlphaGomoku](#alphagomoku)
* [RNA Folding Path](#rna-folding-path)
+* [Atari Game Roms](#atari-game-roms)


Q-Learning
@@ -208,23 +209,16 @@ Doom Deadly Corridor
<img src="imgs/play_doom_deadly_corridor.gif" alt="play Doom Deadly Corridor">
</div>

-The Dueling DQN network is shown in the figure below: [[code]](DDDQN/Doom-Deadly-Corridor/.py)
+The Dueling DQN network is shown in the figure below: [[code]](DDDQN/Doom-Deadly-Corridor/)

![Dueling DQN](imgs/dueling_DQN2.png)

Prioritized Experience Replay uses the SumTree approach:

![SumTree](imgs/sumtree.png)

-The results after training for about xxx rounds are as follows:
-
-![]()
-
-```
-```
-
-[0]. [Improvements in Deep Q Learning: Dueling Double DQN, Prioritized Experience Replay, and fixed Q-targets](https://medium.freecodecamp.org/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682)
-[1]. [Let’s make a DQN: Double Learning and Prioritized Experience Replay](https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/)
+[0]. [Improvements in Deep Q Learning: Dueling Double DQN, Prioritized Experience Replay, and fixed Q-targets](https://medium.freecodecamp.org/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682)<br/>
+[1]. [Let’s make a DQN: Double Learning and Prioritized Experience Replay](https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/)<br/>
[2]. [Double Dueling Deep Q Learning with Prioritized Experience Replay - Notebook](https://github.com/simoninithomas/Deep_reinforcement_learning_Course/blob/master/Dueling%20Double%20DQN%20with%20PER%20and%20fixed-q%20targets/Dueling%20Deep%20Q%20Learning%20with%20Doom%20(%2B%20double%20DQNs%20and%20Prioritized%20Experience%20Replay).ipynb)

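The SumTree mentioned in the Doom Deadly Corridor hunk above is the data structure behind proportional prioritized sampling: leaves hold transition priorities, internal nodes hold the sum of their children, so drawing a uniform number in [0, total priority) and walking down the tree selects a transition with probability proportional to its priority. A minimal, hypothetical sketch of that idea (illustration only, not the implementation in DDDQN/Doom-Deadly-Corridor/):

```python
# Hypothetical minimal SumTree for proportional prioritized replay (illustration only).
import numpy as np

class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity                  # number of leaves (stored transitions)
        self.tree = np.zeros(2 * capacity - 1)    # internal nodes hold sums of their children
        self.data = np.empty(capacity, dtype=object)
        self.write = 0                            # index of the next leaf to overwrite

    def add(self, priority, transition):
        idx = self.write + self.capacity - 1      # position of this leaf in the tree array
        self.data[self.write] = transition
        self.update(idx, priority)
        self.write = (self.write + 1) % self.capacity

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        while idx != 0:                           # propagate the change up to the root
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def total(self):
        return self.tree[0]                       # root holds the sum of all priorities

    def get(self, s):
        """Pick a leaf given s drawn uniformly from [0, total())."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):       # descend until a leaf is reached
            left, right = 2 * idx + 1, 2 * idx + 2
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = right
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]
```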

@@ -242,7 +236,7 @@ CartPole Game
<img src="imgs/play_cartpole.gif" alt="Play CartPole Game">
</div>

-The Policy Gradient network is shown in the figure below. [[code]](PG/Cartpole/PG_Cartpole.py)
+The Policy Gradient network is shown in the figure below.

![Policy Gradient Network](imgs/pg_network.png)

@@ -280,6 +274,8 @@ Max reward so far: 111837.0
[*] Model Saved: ./model/model.ckpt
```

+For the full code, see: [[tensorflow]](PG/Cartpole_tensorflow/PG_Cartpole.py) [[pytorch]](PG/Cartpole_pytorch/PG_Cartpole.py)
+

Doom Deathmatch
---------------
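The loss in the newly added PyTorch script is the reward-to-go form of the policy gradient estimator,

$$\hat{g} = \frac{1}{|B|} \sum_{t \in B} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{R}_t, \qquad \hat{R}_t = \sum_{t'=t}^{T} r_{t'},$$

where the sum runs over all timesteps in the collected batch. This is what `loss = (- log_probs * batch_weights).mean()` computes, with `batch_weights` filled from `reward_to_go(ep_rews)` for each finished episode.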
@@ -288,20 +284,9 @@ Doom Deathmatch
<img src="imgs/play_doom_deathmatch.gif" alt="play Doom Deathmatch">
</div>

-The network is shown in the figure below. [[code]](PG/Doom-Deathmatch/PG_Doom_Deathmatch.py)
-
![](imgs/pg_doom_deathmatch.png)

-The results after training for 5000 episodes are as follows:
-
-![]()
-![]()
-
-```
-
-```
-
-[to be done]
+The network is shown above; for the full code, see: [[code]](PG/Doom-Deathmatch/PG_Doom_Deathmatch.py)


[0]. [An introduction to Policy Gradients with Cartpole and Doom](https://medium.freecodecamp.org/an-introduction-to-policy-gradients-with-cartpole-and-doom-495b5ef2207f)<br/>
@@ -467,3 +452,7 @@ RNA Folding Path

Deep reinforcement learning is used to learn the secondary-structure folding path of RNA molecules. The detailed description is not repeated here; see: [[link]](RNA_Secondary_Structure_Folding_Path/)

+Atari Game Roms
+===============
+
+Here are some Atari game ROMs that can be imported into the retro environment to make playing these games convenient. [[link]](Roms/)
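A hedged sketch of how the ROMs in Roms/ would typically be used with gym-retro (assuming `gym-retro` is installed and Roms.zip has been extracted; the game id below is only an example and depends on which ROMs the archive actually contains):

```python
# First import the ROM files into retro's game library from a shell, e.g.:
#   python -m retro.import Roms/
# Then an imported game can be created like any other retro environment.
import retro

env = retro.make(game='SpaceInvaders-Atari2600')  # example id; use any game whose ROM was imported
obs = env.reset()
for _ in range(100):
    obs, rew, done, info = env.step(env.action_space.sample())  # random actions
    if done:
        obs = env.reset()
env.close()
```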

Roms/Roms.zip

18.8 MB
Binary file not shown.
