【强化学习】SarsaLambda算法详解以及用于二维空间探索【Python实现】

SarsaLambda算法

本文工作基于之前的几篇文章的项目,如果有疑问可以看下面文章:

回到正题上。

无论是在Sarsa算法还是Q-Learning中,每次学习都是只迭代Q表中的[S, A]这个位置的节点。
之前也说过,这样的迭代效率非常低,因为这样每次都只有下一个能直接获取到奖励的节点,其[S, A]组合才能获得更新。

而后续的路径上的节点的更新,只能通过不同的episode的迭代反复试错来往后传导数据。
这样算法的学习效率就很低。

因此,提出了SarsaLambda算法。

Lambda

我看到很多中文教程,都是说是往后传Lambda个节点,包括mofan大神的文章。
其实这种说法是有问题的。

准确说,应该是以Lambda系数的衰减往后传递信息。

具体算法迭代

  • 而Sarsa的迭代公式是:

一般的更新的公式
D e l t a = R + γ ∗ Q [ S n e x t , A n e x t ] − Q [ S , A ] Delta = R + \gamma * {Q[S_{next}, A_{next}]} - Q[S, A]Delta=R+γQ[Snext,Anext]Q[S,A]

对于下一步是终点的更新公式
D e l t a = R Delta = RDelta=R

其实到这里都还是Sarsa的部分。(或者整个架构还是原来的Q-Learning)

多出了一个新的Table,不妨称之为E表E表是用来记录路径衰减信息的。

论文中给出的更新公式是

E [ S , A ] + = 1 E[S, A] += 1E[S,A]+=1

这时候,更新再更新Q表。值得注意的是,这里不是单纯的更新该step下涉及的[S, A]组合,而是更新整个Q表

Q = α ∗ D e l t a ∗ E Q = \alpha * Delta * EQ=αDeltaE

更新完之后,再衰减整个E表(如下,但是我觉得乘不乘 γ \gammaγ 问题都不大)。

E ∗ = λ ∗ γ E *= \lambda * \gammaE=λγ

然后每个episode之后,都需要重新初始化一下E表,毕竟他是记录每次的状态转移的路径的。之后不同路径,肯定是需要重新刷新一下。

lambda更新思路的改进

上面用到E [ S , A ] + = 1 E[S, A] += 1E[S,A]+=1的算法。其实有个小问题。

以二维图探索为例,如果一个路径中存在在某个节点反复摇摆,即存在环。那么该地方的E值会被不断拉大。甚至有可能将比直接获得奖励的节点,受到的影响还要大。而这就与设计E的初衷不一致了。

因此一个可能的操作就是,类似于clips操作一样,做截断。

E [ S , : ] = 0 E [ S , A ] = 1 E[S, :] = 0\\ E[S, A] = 1E[S,:]=0E[S,A]=1

这样就相当于初始化了一个环。从而降低了循环引用的问题。(这个我看各路文章都没有讲过这个问题。还是有必要指出来的

代码实现

  • env.py 发现envr中random函数有点问题。理论上,随机到初始化的节点上的。这样的图是没有意义的。这里做个简单修改。
import time
import numpy as np


class Maze(object):
    def __init__(self, shape=None, hell_num=2):
        if (shape is None) or (not isinstance(shape, (tuple, list))) or (len(shape) > 2):
            shape = (5, 5)
        self.shape = shape
        self.map = np.zeros(shape)
        self.actions = {
            'u': [-1, 0],
            'd': [1, 0],
            'l': [0, -1],
            'r': [0, 1]
        }

        for _ in range(hell_num):
            self._random_num(shape, -1)
        self._random_num(shape, 1)

        self.point = None
        self.refresh()

    def _random_num(self, shape, v):
        n = shape[0] * shape[1]
        while True:
            rd_num = np.random.randint(1, n - 1)
            y = rd_num // shape[0]
            x = rd_num % shape[0]
            if self.map[x][y] == 0:
                self.map[x][y] = v
                break

    def refresh(self):
        self.point = [0, 0]

    def point_check(self, point):
        flags = [0, 1]
        for f in flags:
            if (point[f] < 0) or (point[f] >= self.shape[f]):
                return False
        return True

    def get_env_feedback(self, A):
        if A not in self.actions:
            raise Exception("Wrong Action")
        A = self.actions[A]
        point_ = [
            self.point[0] + A[0],
            self.point[1] + A[1]
        ]
        if self.point_check(point_):
            self.point = point_
            R = self.map[self.point[0]][self.point[1]]
            done = (R != 0)
        else:
            R, done = -1, False
        return self.point, R, done

    def show_matrix(self, m):
        for x in m:
            print(' '.join(list(map(lambda i: str(int(i)) if not isinstance(i, str) else i, x))))

    def update(self, done, episode, step, r=None):
        # os.system("cls")
        m = self.map.tolist()
        m[self.point[0]][self.point[1]] = 'x'
        self.show_matrix(m)
        print("==========")
        if done:
            print("episode: %s; step: %s; reward: %s" % (episode, step, r))
            time.sleep(3)
        else:
            time.sleep(0.3)
  • RL_Brain.py(learn函数,有个参数algo,可以用来控制, E迭代的策略)
import pandas as pd
import numpy as np


class RLBrain(object):
    def __init__(self, actions, lr=0.1, gamma=0.9, epsilon=0.9, trace=0.9):
        self.actions = actions
        self.q_table = pd.DataFrame(
            [],
            columns=self.actions
        )
        self.e_table = pd.DataFrame(
            [],
            columns=self.actions
        )
        # trace is what called lambda.
        self.lr, self.gamma, self.epsilon, self.trace = lr, gamma, epsilon, trace

    def check_state(self, s):
        if (s not in self.q_table.index) or (s not in self.e_table.index):
            to_append = pd.Series(
                [0] * len(self.actions),
                index=self.actions,
                name=s
            )
            if s not in self.q_table.index:
                self.q_table = self.q_table.append(
                    to_append
                )
            if s not in self.e_table.index:
                self.e_table = self.e_table.append(
                    to_append
                )

    def choose_action(self, s):
        self.check_state(s)
        state_table = self.q_table.loc[s, :]

        if (np.random.uniform() >= self.epsilon) or (state_table == 0).all():
            return np.random.choice(self.actions)
        else:
            return np.random.choice(state_table[state_table == np.max(state_table)].index)

    def learn(self, s, s_, a, r, done, a_, algo=0):
        self.check_state(s_)
        q_old = self.q_table.loc[s, a]
        if done:
            q_new = r
        else:
            q_new = r + self.gamma * self.q_table.loc[s_, a_]
        delta = q_new - q_old

        # update e_table
        if algo == 1:
            self.e_table.loc[s, a] += 1
        else:
            self.e_table.loc[s, :] = 0
            self.e_table.loc[s, a] = 1

        self.q_table += self.lr * delta * self.e_table

        self.e_table *= self.gamma * self.trace

    def refresh(self):
        self.e_table *= 0

  • treasure_maze_main.py
from RL_Brain import RLBrain
from env import Maze

if __name__ == '__main__':
    ALPHA = 0.1
    GAMMA = 0.9
    EPSILON = 0.9
    MAX_EPISODE = 15

    env = Maze(shape=(3, 4))
    RL = RLBrain(actions=list(env.actions.keys()))
    for episode in range(MAX_EPISODE):
        env.refresh()

        step_counter = 0
        done = False

        env.update(done, episode, step_counter)

        s = env.point
        a = RL.choose_action(str(s))
        while not done:
            s_, r, done = env.get_env_feedback(a)
            if not done:
                a_ = RL.choose_action(str(s_))
            else:
                a_ = None
            RL.learn(str(s), str(s_), a, r, done, a_)
            s = s_
            a = a_
            step_counter += 1
            env.update(done, episode, step_counter, r)


版权声明:本文为a19990412原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。