Skip to content

Commit 0a2b81d

Browse files
committed
QLearning play Taxi-v2 game
1 parent df018bf commit 0a2b81d

File tree

8 files changed

+75
-1
lines changed

8 files changed

+75
-1
lines changed

QLearning/QLearning_Taxi_v2.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import numpy as np
2+
import gym
3+
import random
4+
5+
# Tabular Q-learning on the Taxi-v2 environment: train a Q-table with an
# epsilon-greedy policy, then replay the greedy policy and report the
# average episode return.
#
# The code unpacks `env.step` into four values and uses the return of
# `env.reset()` directly as the state index, i.e. it targets the classic
# gym API.
env = gym.make("Taxi-v2")

action_size = env.action_space.n
state_size = env.observation_space.n
# One row per discrete state, one column per discrete action.
qtable = np.zeros((state_size, action_size))

# Hyperparameters
total_episodes = 50000       # number of training episodes
total_test_episodes = 100    # number of evaluation episodes
max_steps = 99               # step cap per episode (training and evaluation)
learning_rate = 0.7          # step size of the Q-table update
gamma = 0.618                # discount factor
epsilon = 1.0                # current exploration probability
max_epsilon = 1.0            # exploration probability at the start
min_epsilon = 0.01           # floor of the exploration probability
decay_rate = 0.01            # exponential decay rate of epsilon per episode

# Train
for episode in range(total_episodes):
    state = env.reset()

    for step in range(max_steps):
        # Epsilon-greedy: exploit the table with probability (1 - epsilon),
        # otherwise take a uniformly sampled action.
        exp_exp_tradeoff = random.uniform(0, 1)
        if exp_exp_tradeoff > epsilon:
            action = np.argmax(qtable[state, :])
        else:
            action = env.action_space.sample()

        new_state, reward, done, info = env.step(action)

        # A terminal transition has no successor to bootstrap from, so its
        # target is the immediate reward alone.
        future_value = 0.0 if done else np.max(qtable[new_state, :])
        td_target = reward + gamma * future_value
        qtable[state, action] += learning_rate * (td_target - qtable[state, action])

        state = new_state
        if done:
            break

    # Decay exploration exponentially from max_epsilon towards min_epsilon.
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * (episode + 1))


# Play the Game
rewards = []
for episode in range(total_test_episodes):
    state = env.reset()
    total_rewards = 0

    print('=' * 20)
    print("[*] Episode", episode)
    print('=' * 20)

    for step in range(max_steps):
        env.render()
        # Evaluation is purely greedy with respect to the learned table.
        action = np.argmax(qtable[state, :])
        state, reward, done, info = env.step(action)
        total_rewards += reward

        if done:
            break

    # Record the return of every episode, including ones that hit the step
    # cap without finishing; otherwise the reported average silently drops
    # their (negative) returns from the sum while still counting them in
    # the denominator.
    rewards.append(total_rewards)
    print('[*] Score', total_rewards)

env.close()
# Average over the episodes actually recorded; guard against zero episodes.
average_score = sum(rewards) / len(rewards) if rewards else 0.0
print('[*] Average Score: ' + str(average_score))

README.md

+11-1
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,18 @@ players draw rate: 0.09528
8585

8686
Taxi v2
8787
-------
88-
[to be done]
8988

89+
<div align=center>
90+
<img width="93" height="133" src="imgs/taxi1.png" alt="Taxi v2">
91+
<img width="93" height="133" src="imgs/taxi2.png" alt="Taxi v2">
92+
<img width="93" height="133" src="imgs/taxi3.png" alt="Taxi v2">
93+
<img width="93" height="133" src="imgs/taxi4.png" alt="Taxi v2">
94+
<img width="93" height="133" src="imgs/taxi5.png" alt="Taxi v2">
95+
<img width="93" height="133" src="imgs/taxi6.png" alt="Taxi v2">
96+
<img width="93" height="133" src="imgs/taxi7.png" alt="Taxi v2">
97+
</div>
98+
99+
基于 `Q-Learning` 玩 `Taxi v2` 游戏:[[code]](QLearning/QLearning_Taxi_v2.py)
90100

91101

92102
[0]. [Diving deeper into Reinforcement Learning with Q-Learning](https://medium.freecodecamp.org/diving-deeper-into-reinforcement-learning-with-q-learning-c18d0db58efe)<br/>

imgs/taxi1.png

3.26 KB
Loading

imgs/taxi2.png

2.19 KB
Loading

imgs/taxi3.png

1.71 KB
Loading

imgs/taxi4.png

1.68 KB
Loading

imgs/taxi5.png

1.66 KB
Loading

imgs/taxi6.png

1.61 KB
Loading

0 commit comments

Comments
 (0)