许可优化
许可优化
产品
产品
解决方案
解决方案
服务支持
服务支持
关于
关于
软件库
当前位置:服务支持 >  软件文章 >  Tabular Dyna-Q算法实现(8.1)

Tabular Dyna-Q算法实现(8.1)

阅读数 16
点赞 0
article_banner

算法伪 代码

在这里插入图片描述



使用 maze 环境:maze_env 的代码见这里

import numpy as np
import pandas as pd
from maze_env import Maze
import random


class Q(object):
    def __init__(self, action_space):
        self.nA = action_space
        self.actions = list(range(action_space))

        self.q_table = pd.DataFrame(columns=self.actions)
        self.init_Q()

    def init_Q(self):
    	# 对所有的状态以及动作进行初始化
        for x in range(5, 165, 40):
            for y in range(5, 165, 40):
                if x == 45 and y == 85:
                    s = 'terminal'
                elif x == 85 and y == 45:
                    s = 'terminal'
                elif x == 85 and y == 85:
                    s = 'terminal'
                else:
                    s = [x+0.0, y+0.0, x + 30.0, y + 30.0]
                    s = str(s)
                if s not in self.q_table.index:
                    self.q_table = self.q_table.append(
                        pd.Series([0] * len(self.actions),
                                  index=self.q_table.columns,
                                  name=s)
                    )

    def target_policy(self, s):
        # target_policy is the greedy policy
        # self.check_state_exist(s)
        A = self.target_policy_probs(s)
        return np.random.choice(range(self.nA), p=A)

    def target_policy_probs(self, s, epsilon=.3):
        A = np.ones(self.nA, dtype=float) * epsilon / self.nA
        best_action = np.argmax(self.q_table.loc[s, :])
        A[best_action] += (1.0 - epsilon)
        return A


class Model(object):
    def __init__(self):
        self.model = dict()

    def store(self, s, a, r, s_):
        self.model[s, a] = [r, s_]


if __name__ == '__main__':
    env = Maze()
    action_space = env.n_actions
    RL = Q(action_space)
    model = Model()

    gamma = 0.9
    alpha = 0.01
    n_times = 4

    for episode in range(100):
        state = env.reset()

        while True:
            env.render()
            action = RL.target_policy(str(state))
            state_, reward, done = env.step(action)

            G = reward + gamma * np.max(RL.q_table.loc[str(state_), :])
            RL.q_table.loc[str(state), action] += alpha * (G - RL.q_table.loc[str(state), action])

            model.store(str(state), action, reward, str(state_))

            for i in range(n_times):
                S_A = random.choice(list(model.model.keys()))
                S_A = list(S_A)
                S = S_A[0]
                A = S_A[1]
                R, S_ = model.model[S, A]
                G = R + gamma * np.max(RL.q_table.loc[S_, :])
                RL.q_table.loc[S, A] += alpha * (G - RL.q_table.loc[S, A])

            if done:
                break

            state = state_

    print('game over')
    env.destroy()


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删


相关文章
技术文档
QR Code
微信扫一扫,欢迎咨询~
customer

online

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 board-phone 155-2731-8020
close1
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

姓名不为空

姓名不为空
手机不正确

手机不正确

手机不正确
公司不为空

公司不为空

公司不为空