标签:div represent col 选择 this ges update help 移动
强化学习:当前的奖励值:
#当前奖励 = 当前的概率*(及时奖励 + 衰减系数 * 下一次的奖励)
不断迭代,直到当前的奖励值不发生变换
import numpy as np from gridworld import GridworldEnv env = GridworldEnv() def value_iteration(env, theta=0.0001, discount_factor=1.0): """ Value Iteration Algorithm. Args: env: OpenAI environment. env.P represents the transition probabilities of the environment. theta: Stopping threshold. If the value of all states changes less than theta in one iteration we are done. discount_factor: lambda time discount factor. Returns: A tuple (policy, V) of the optimal policy and the optimal value function. """ def one_step_lookahead(state, V): """ Helper function to calculate the value for all action in a given state. Args: state: The state to consider (int) V: The value to use as an estimator, Vector of length env.nS Returns: A vector of length env.nA containing the expected value of each action. """ # 每个位置的4个方向,计算当前位置的奖励值 A = np.zeros(env.nA) # 迭代四次 for a in range(env.nA): for prob, next_state, reward, done in env.P[state][a]: #当前奖励 = 当前的概率*(及时奖励 + 衰减系数 * 下一次的奖励) A[a] += prob * (reward + discount_factor * V[next_state]) return A V = np.zeros(env.nS) while True: # Stopping condition delta = 0 # Update each state... for s in range(env.nS): # Do a one-step lookahead to find the best action A = one_step_lookahead(s, V) # 选择奖励值最高的数 best_action_value = np.max(A) # Calculate delta across all states seen so far delta = max(delta, np.abs(best_action_value - V[s])) # Update the value function #V[s]使用最好的奖励值表示 V[s] = best_action_value # Check if we can stop # 如果奖励值不发生变化,跳出循环 if delta < theta: break # Create a deterministic policy using the optimal value function # 获得当前位置对应的最佳移动方向 policy = np.zeros([env.nS, env.nA]) for s in range(env.nS): # One step lookahead to find the best action for this state A = one_step_lookahead(s, V) # 最好的方向 best_action = np.argmax(A) # Always take the best action # s表示位置,best_action表示方向,用于后续的操作 policy[s, best_action] = 1.0 return policy, V policy, v = value_iteration(env) print("Policy Probability Distribution:") print(policy) print("") print("Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):") print(np.reshape(np.argmax(policy, axis=1), env.shape)) print("")
标签:div represent col 选择 this ges update help 移动
原文地址:https://www.cnblogs.com/my-love-is-python/p/10082200.html