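The code below comes from a maze example solved with Sarsa and then Q-learning. It relies on numpy (as np), a random initial policy pi_0, and an action-value table Q that are defined earlier in the series but not shown in this excerpt. The following is a minimal setup sketch, assuming the 3x3 maze implied by the s±3 / s±1 moves used later (states 0-8, goal at state 8); the wall layout in theta_0 and the helper name simple_convert_into_pi_from_theta are illustrative assumptions, not necessarily the original definitions.

import numpy as np

# Assumed 3x3 maze: one row per non-goal state 0..7, columns = [up, right, down, left];
# 1 marks an allowed move, np.nan marks a wall. State 8 is the goal.
theta_0 = np.array([[np.nan, 1, 1, np.nan],       # s0
                    [np.nan, 1, np.nan, 1],       # s1
                    [np.nan, np.nan, 1, 1],       # s2
                    [1, 1, 1, np.nan],            # s3
                    [np.nan, np.nan, 1, 1],       # s4
                    [1, np.nan, np.nan, np.nan],  # s5
                    [1, np.nan, np.nan, np.nan],  # s6
                    [1, 1, np.nan, np.nan]])      # s7


def simple_convert_into_pi_from_theta(theta):
    # Give every allowed move in a state equal probability.
    pi = np.zeros(theta.shape)
    for i in range(theta.shape[0]):
        pi[i, :] = theta[i, :] / np.nansum(theta[i, :])
    return np.nan_to_num(pi)  # walls get probability 0


pi_0 = simple_convert_into_pi_from_theta(theta_0)

# Random initial action values; wall entries stay nan so nanargmax/nanmax skip them.
Q = np.random.rand(*theta_0.shape) * theta_0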

# epsilon-greedy: get the action
def get_action(s, Q, epsilon, pi_0):
    direction = ["up", "right", "down", "left"]
    # print("s = " + str(s))

    # with probability epsilon, explore at random according to pi_0
    if np.random.rand() < epsilon:
        next_direction = np.random.choice(direction, p=pi_0[s, :])
    else:
        # otherwise move greedily, taking the action with the maximum Q value
        next_direction = direction[np.nanargmax(Q[s, :])]

    # map the direction name to an action index
    if next_direction == "up":
        action = 0
    elif next_direction == "right":
        action = 1
    elif next_direction == "down":
        action = 2
    elif next_direction == "left":
        action = 3

    return action

# get the next state given the action
def get_s_next(s, a, Q, epsilon, pi_0):
    direction = ["up", "right", "down", "left"]
    next_direction = direction[a]

    # the maze is a 3x3 grid, so moving up/down shifts the state index by 3
    if next_direction == "up":
        s_next = s - 3
    elif next_direction == "right":
        s_next = s + 1
    elif next_direction == "down":
        s_next = s + 3
    elif next_direction == "left":
        s_next = s - 1

    return s_next

The Sarsa algorithm updates the Q matrix from the observed state and action. Derivation: start from the incremental-mean identity

\mu_k = \frac{1}{k}\sum_{j=1}^{k} x_j = \frac{1}{k}\Big(x_k + \sum_{j=1}^{k-1} x_j\Big) = \frac{1}{k}\big(x_k + (k-1)\mu_{k-1}\big) = \mu_{k-1} + \frac{1}{k}\big(x_k - \mu_{k-1}\big)

Applied to each visited state S_t with return G_t:

N(S_t) \leftarrow N(S_t) + 1
V(S_t) \leftarrow V(S_t) + \frac{1}{N(S_t)}\big(G_t - V(S_t)\big), \qquad G_t = R_{t+1} + \gamma V(S_{t+1})

Replacing the running count 1/N(S_t) with a fixed learning rate \eta and working on action values gives the Sarsa update:

Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \eta\big(R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)\big)

where R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) is called the TD error. Q is updated according to this formula:
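The training loop below calls goal_maze_ret_s_a_Q, whose definition is not included in this excerpt. The following is a minimal sketch of the per-transition Sarsa update and that episode runner, assuming the agent starts at state 0, the goal is state 8, and the reward is 1 only on reaching the goal; the signatures follow the call sites below but should be read as a reconstruction rather than the author's verbatim code.

# Sarsa update for a single (s, a, r, s_next, a_next) transition (sketch).
def Sarsa(s, a, r, s_next, a_next, Q, eta, gamma):
    if s_next == 8:  # assumed goal state: no bootstrap term at the terminal state
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        Q[s, a] = Q[s, a] + eta * (r + gamma * Q[s_next, a_next] - Q[s, a])
    return Q


# Run one episode from the start to the goal, updating Q with Sarsa (sketch).
def goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0):
    s = 0                                 # assumed start state
    a = a_next = get_action(s, Q, epsilon, pi_0)
    s_a_history = [[0, np.nan]]           # [state, action] pairs along the path

    while True:
        a = a_next
        s_a_history[-1][1] = a            # record the action taken in the current state
        s_next = get_s_next(s, a, Q, epsilon, pi_0)
        s_a_history.append([s_next, np.nan])

        if s_next == 8:                   # reached the goal
            r = 1
            a_next = np.nan
        else:
            r = 0
            a_next = get_action(s_next, Q, epsilon, pi_0)  # action actually taken next

        Q = Sarsa(s, a, r, s_next, a_next, Q, eta, gamma)

        if s_next == 8:
            break
        s = s_next

    return [s_a_history, Q]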

eta = 0.1
gamma = 0.9
epsilon = 0.5
v = np.nanmax(Q, axis=1)  # select the maximum Q value for each state
is_continue = True
episode = 1

while is_continue:
    print("episode: " + str(episode))

    epsilon = epsilon / 2  # decay epsilon (epsilon-greedy)

    [s_a_history, Q] = goal_maze_ret_s_a_Q(Q, epsilon, eta, gamma, pi_0)

    new_v = np.nanmax(Q, axis=1)  # maximum value for each state

    print(np.sum(np.abs(new_v - v)))  # total change in the state values

    v = new_v

    print("steps to reach goal is: " + str(len(s_a_history) - 1))

    episode = episode + 1
    if episode > 100:
        break

Although the loop is written to run for 100 episodes, the output shows that it converges to the optimal path very quickly:

episode: 1
0.227489819094
steps to reach goal is: 14
episode: 2
0.1
steps to reach goal is: 10
episode: 3
0.5
steps to reach goal is: 4
episode: 4
0.78
steps to reach goal is: 4
episode: 5
0.78
steps to reach goal is: 4
episode: 6
0.18
steps to reach goal is: 4
episode: 7
0.4
steps to reach goal is: 4
episode: 8

Let's look at the path and the Q matrix after convergence.
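The converged path and Q values can be printed directly from s_a_history and Q. The next loop also calls goal_maze_ret_Q_learning, which is not defined in this excerpt; below is a minimal sketch of a Q-learning episode runner under the same assumptions as before (start state 0, goal state 8, reward 1 only at the goal). Since the call passes neither epsilon nor pi_0, the sketch assumes they are read from the enclosing scope. The essential difference from Sarsa is that the update bootstraps on max_a Q(s_next, a) instead of the Q value of the action actually taken next.

# Inspect the result of the Sarsa run
print([step[0] for step in s_a_history])  # states visited on the final episode
print(np.round(Q, 3))                     # learned action values (nan = wall)


# Q-learning update for a single (s, a, r, s_next) transition (sketch).
def Q_learning(s, a, r, s_next, Q, eta, gamma):
    if s_next == 8:  # assumed goal state: no bootstrap term at the terminal state
        Q[s, a] = Q[s, a] + eta * (r - Q[s, a])
    else:
        # bootstrap on the best next action, not the action that will actually be taken
        Q[s, a] = Q[s, a] + eta * (r + gamma * np.nanmax(Q[s_next, :]) - Q[s, a])
    return Q


# Run one episode from the start to the goal, updating Q with Q-learning (sketch).
def goal_maze_ret_Q_learning(Q, eta, gamma):
    # epsilon and pi_0 are assumed to be globals here, as they do not appear in the call.
    s = 0
    s_a_history = [[0, np.nan]]  # [state, action] pairs along the path

    while True:
        a = get_action(s, Q, epsilon, pi_0)
        s_a_history[-1][1] = a                       # record the chosen action
        s_next = get_s_next(s, a, Q, epsilon, pi_0)
        s_a_history.append([s_next, np.nan])

        r = 1 if s_next == 8 else 0                  # reward only at the goal
        Q = Q_learning(s, a, r, s_next, Q, eta, gamma)

        if s_next == 8:
            break
        s = s_next

    return [s_a_history, Q]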

The full run, iterating up to 100 episodes with the Q-learning episode runner:

eta = 0.1     # learning rate
gamma = 0.9   # discount rate
v = np.nanmax(Q, axis=1)  # maximum value for each state
is_continue = True
episode = 1
V = []  # state values recorded after each episode
V.append(np.nanmax(Q, axis=1))  # get the initial maximum value for each state

while is_continue:
    print("episode " + str(episode))

    [s_a_history, Q] = goal_maze_ret_Q_learning(Q, eta, gamma)  # run one episode and get the path

    new_v = np.nanmax(Q, axis=1)

    print(np.sum(np.abs(new_v - v)))  # get the change in the state values

    v = new_v

    V.append(v)  # record the state values for this episode

    print("steps to reach goal: " + str(len(s_a_history) - 1))
    episode = episode + 1
    if episode > 100:
        is_continue = False

Convergence results:

episode 1
0.1
steps to reach goal: 20
episode 2
0.6
steps to reach goal: 16
episode 3
0.6
steps to reach goal: 8
episode 4
0.8
steps to reach goal: 6
episode 5
0.3
steps to reach goal: 6
episode 6
0.2
steps to reach goal: 4
episode 7
0.3
steps to reach goal: 4
episode 8
0.6
steps to reach goal: 4
episode 9
0.1
steps to reach goal: 4
episode 10
0.1
steps to reach goal: 4
episode 11
0.5
steps to reach goal: 4
episode 12
0.7
steps to reach goal: 4
episode 13
0.6
steps to reach goal: 4
episode 14
0.8
steps to reach goal: 4
episode 15
0.5
steps to reach goal: 4
episode 16
0.1
steps to reach goal: 4
episode 17
0.9
steps to reach goal: 4
episode 18
0.9
steps to reach goal: 4
episode 19
0.2
steps to reach goal: 4
episode 20
0.8
steps to reach goal: 4
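Since V collects the per-state values after every episode, the convergence can also be visualized. A minimal sketch with matplotlib, assuming V is the list built in the loop above (one array of per-state maximum Q values per episode); matplotlib itself is not used elsewhere in this excerpt, so treat this as an optional addition.

import matplotlib.pyplot as plt

# V[i] holds the per-state maximum Q value after episode i (one entry per non-goal state).
V_arr = np.array(V)  # shape: (num_episodes, num_states)
for s in range(V_arr.shape[1]):
    plt.plot(V_arr[:, s], label="state " + str(s))
plt.xlabel("episode")
plt.ylabel("state value (max_a Q)")
plt.legend(loc="lower right", fontsize=8)
plt.show()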

