import numpy as np

# epsilon-greedy: get the action
def get_action(s, Q, epsilon, pi_0):
    direction = ["up", "right", "down", "left"]
    # print("s = " + str(s))

    # with probability epsilon, explore by sampling from the random policy pi_0
    if np.random.rand() < epsilon:
        next_direction = np.random.choice(direction, p=pi_0[s, :])
    else:
        # otherwise act greedily: move in the direction with the maximum Q value
        # (np.nanargmax skips the NaN entries that mark walls)
        next_direction = direction[np.nanargmax(Q[s, :])]

    if next_direction == "up":
        action = 0
    elif next_direction == "right":
        action = 1
    elif next_direction == "down":
        action = 2
    elif next_direction == "left":
        action = 3

    return action
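Both get_action above and get_s_next below assume a Q matrix and a random policy pi_0 built from the maze's wall structure, which this excerpt does not show. Here is a minimal sketch of that setup; the 3x3 layout with start S0 and goal S8 matches the usual form of this maze problem, but the exact theta_0 wall matrix and the helper simple_convert_into_pi_from_theta are assumptions for illustration, not code from the original post.

# A minimal sketch of the setup assumed by these helpers (not shown in this
# excerpt). The maze is a 3x3 grid with states S0..S8; S8 is the goal.
# theta_0 marks the allowed moves [up, right, down, left] from each non-goal
# state, with np.nan where a wall blocks the move. This particular wall
# layout is an assumption chosen so the example runs end to end.
theta_0 = np.array([[np.nan, 1, 1, np.nan],        # S0
                    [np.nan, 1, np.nan, 1],        # S1
                    [np.nan, np.nan, 1, 1],        # S2
                    [1, 1, 1, np.nan],             # S3
                    [np.nan, np.nan, 1, 1],        # S4
                    [1, np.nan, np.nan, np.nan],   # S5
                    [1, np.nan, np.nan, np.nan],   # S6
                    [1, 1, np.nan, np.nan]])       # S7

def simple_convert_into_pi_from_theta(theta):
    # turn theta into a policy: equal probability over the allowed moves
    m, n = theta.shape
    pi = np.zeros((m, n))
    for i in range(m):
        pi[i, :] = theta[i, :] / np.nansum(theta[i, :])
    return np.nan_to_num(pi)   # walls become probability 0

pi_0 = simple_convert_into_pi_from_theta(theta_0)

# initial Q: random values for allowed moves, NaN for walls
# (which is why the greedy step uses np.nanargmax)
a, b = theta_0.shape
Q = np.random.rand(a, b) * theta_0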
# get the next state from the chosen action
def get_s_next(s, a, Q, epsilon, pi_0):
    direction = ["up", "right", "down", "left"]
    next_direction = direction[a]

    # the maze is a 3x3 grid, so moving up or down shifts the state index by 3
    if next_direction == "up":
        s_next = s - 3
    elif next_direction == "right":
        s_next = s + 1
    elif next_direction == "down":
        s_next = s + 3
    elif next_direction == "left":
        s_next = s - 1

    return s_next

The Sarsa algorithm updates the Q matrix from the current state and action.

Derivation (incremental mean):

$$\mu_k = \frac{1}{k}\sum_{j=1}^{k} x_j = \frac{1}{k}\Big(x_k + \sum_{j=1}^{k-1} x_j\Big) = \frac{1}{k}\big(x_k + (k-1)\mu_{k-1}\big) = \mu_{k-1} + \frac{1}{k}\big(x_k - \mu_{k-1}\big)$$

For each state $S_t$ with return $G_t$:

$$N(S_t) \leftarrow N(S_t) + 1$$
$$V(S_t) \leftarrow V(S_t) + \frac{1}{N(S_t)}\big(G_t - V(S_t)\big), \qquad G_t = R_{t+1} + \gamma V(S_{t+1})$$

This gives the Sarsa update

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \eta\big(R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)\big)$$

where the term $R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)$ is called the TD error. Q is updated according to this formula:
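The training loop that follows calls goal_maze_ret_s_a_Q, which runs one episode from the start state to the goal and applies the update above after every step; that function is not included in this excerpt. A minimal sketch of what it might look like, assuming state 8 is the goal, the reward is 1 only on reaching it, and the [state, action] history is returned together with the updated Q:

def Sarsa(s, a, r, s_next, a_next, Q, eta, gamma):
    # the Sarsa update derived above
    if s_next == 8:   # goal reached: no bootstrap term
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])
    return Q

def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi):
    s = 0                                 # start state
    a = a_next = get_action(s, Q, epsilon, pi)
    s_a_history = [[0, np.nan]]           # list of [state, action] pairs

    while True:
        a = a_next
        s_a_history[-1][1] = a            # record the action taken in the current state

        s_next = get_s_next(s, a, Q, epsilon, pi)
        s_a_history.append([s_next, np.nan])

        if s_next == 8:                   # goal: reward 1, no next action
            r = 1
            a_next = np.nan
        else:
            r = 0
            a_next = get_action(s_next, Q, epsilon, pi)

        # update Q with the Sarsa rule
        Q = Sarsa(s, a, r, s_next, a_next, Q, eta, gamma)

        if s_next == 8:
            break
        s = s_next

    return [s_a_history, Q]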
eta = 0.1        # learning rate
gamma = 0.9      # discount rate
epsilon = 0.5
v = np.nanmax(Q, axis=1)    # select the maximum Q value for each state
is_continue = True
episode = 1

while is_continue:
    print("episode: " + str(episode))

    epsilon = epsilon / 2    # decay epsilon (epsilon-greedy)

    # run one Sarsa episode and get the visited path and updated Q
    [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)

    new_v = np.nanmax(Q, axis=1)          # maximum Q value for each state
    print(np.sum(np.abs(new_v - v)))      # total change in the state values
    v = new_v

    print("steps to reach goal is: " + str(len(s_a_history) - 1))

    episode = episode + 1
    if episode > 100:
        break

Although the loop is written to run for up to 100 episodes, the output shows that it converges to the optimal path very quickly:

episode: 1
0.227489819094
steps to reach goal is: 14
episode: 2
0.1
steps to reach goal is: 10
episode: 3
0.5
steps to reach goal is: 4
episode: 4
0.78
steps to reach goal is: 4
episode: 5
0.78
steps to reach goal is: 4
episode: 6
0.18
steps to reach goal is: 4
episode: 7
0.4
steps to reach goal is: 4
episode: 8

Let's look at the converged path and the Q matrix.
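The display code itself is not part of this excerpt; a minimal sketch of how the converged path and the Q matrix could be inspected after the loop finishes:

# show the state/action pairs visited in the last (converged) episode
print("converged path (state, action):")
for step in s_a_history:
    print(step)

# show the learned action values and the resulting state values
print("Q matrix:")
print(Q)
print("state values:")
print(np.nanmax(Q, axis=1))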
The Q-learning version, run for the full 100 episodes:

eta = 0.1        # learning rate
gamma = 0.9      # discount rate
v = np.nanmax(Q, axis=1)          # maximum Q value for each state
is_continue = True
episode = 1
V = []                            # state values for each episode
V.append(np.nanmax(Q, axis=1))    # get the maximum Q value for each state

while is_continue:
    print("episode " + str(episode))

    # run one Q-learning episode and get the visited path and updated Q
    [s_a_history, Q] = goal_maze_ret_Q_learning(Q, eta, gamma)

    new_v = np.nanmax(Q, axis=1)
    print(np.sum(np.abs(new_v - v)))    # total change in the state values
    v = new_v
    V.append(v)                         # record this episode's state values

    print("steps to reach goal: " + str(len(s_a_history) - 1))

    episode = episode + 1
    if episode > 100:
        is_continue = False

Convergence results:

episode 1
0.1
steps to reach goal: 20
episode 2
0.6
steps to reach goal: 16
episode 3
0.6
steps to reach goal: 8
episode 4
0.8
steps to reach goal: 6
episode 5
0.3
steps to reach goal: 6
episode 6
0.2
steps to reach goal: 4
episode 7
0.3
steps to reach goal: 4
episode 8
0.6
steps to reach goal: 4
episode 9
0.1
steps to reach goal: 4
episode 10
0.1
steps to reach goal: 4
episode 11
0.5
steps to reach goal: 4
episode 12
0.7
steps to reach goal: 4
episode 13
0.6
steps to reach goal: 4
episode 14
0.8
steps to reach goal: 4
episode 15
0.5
steps to reach goal: 4
episode 16
0.1
steps to reach goal: 4
episode 17
0.9
steps to reach goal: 4
episode 18
0.9
steps to reach goal: 4
episode 19
0.2
steps to reach goal: 4
episode 20
0.8
steps to reach goal: 4
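The loop above depends on goal_maze_ret_Q_learning, which is also not defined in this excerpt. A minimal sketch of what it might look like: the episode structure is the same as the Sarsa version, and the only substantive change is that the update bootstraps from the maximum of Q(s_{t+1}, ·) over actions rather than from the action actually taken. Since the call above passes no exploration parameter, this sketch assumes epsilon and pi_0 are the globals defined earlier.

def Q_learning(s, a, r, s_next, Q, eta, gamma):
    # Q-learning update: bootstrap from the best action in the next state
    if s_next == 8:   # goal reached: no bootstrap term
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        Q[s, a] = Q[s, a] + eta * (r + gamma * np.nanmax(Q[s_next, :]) - Q[s, a])
    return Q

def goal_maze_ret_Q_learning(Q, eta, gamma):
    # epsilon and pi_0 are assumed to be the globals defined earlier,
    # to match the call signature used in the loop above
    s = 0
    a = a_next = get_action(s, Q, epsilon, pi_0)
    s_a_history = [[0, np.nan]]

    while True:
        a = a_next
        s_a_history[-1][1] = a
        s_next = get_s_next(s, a, Q, epsilon, pi_0)
        s_a_history.append([s_next, np.nan])

        if s_next == 8:
            r = 1
            a_next = np.nan
        else:
            r = 0
            a_next = get_action(s_next, Q, epsilon, pi_0)

        Q = Q_learning(s, a, r, s_next, Q, eta, gamma)

        if s_next == 8:
            break
        s = s_next

    return [s_a_history, Q]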