Deep Q Learning Code Analyze (2)
分析的源码来自于deepmind在Natrue上发表的论文Human-level control through deep reinforcement learning所附的源码。源码下载
文件结构 续上
NeuralQLearner.lua
该文件定义了一个dqn.NerualQLearner的类,该类主要制定了深度Q学习的学习规则。同样地,这里对该类的成员函数一一进行解读。
nql:__init(args)
类对象的初始化。由于初始化的对象很多,这里就不一一介绍,主要介绍几个难以理解的成员变量,其他成员变量可参考源文件。
另外,在初始化的函数中,还调用了lua语言内置的pcall函数来载入网络和预处理网络。例如如下代码:
pcall函数是lua的内置处理函数,一般的使用方法是msg,err=pcall(func,param),通过调用func(param)函数,如果调用成功,则msg返回true,err返回func(param)的返回值,如果出现错误和异常,则msg返回nil,err返回错误的信息。该函数在lua相当于try,out的作用。
通过使用pcall调用载入函数,可以事先对self.network和self.preproc进行初始化。
nql:reset(state)
重置类对象,主要是载入state.best_network和state.model。然后将self.dw归零,将执行步数self.numSteps置零。
nql:preprocess(rawstate)
将原始状态进行预处理
nql:getQUpdate(args)
该函数主要以一个transition: < s,a,r,s’,term>作为输入,然后通过计算获得Q值,以及targets,残差等。
首先该函数会载入args的参数,包括args.s,args.a,args.r,args.s2,args.term。然后按照下面的程序计算:
这里得到了delta=r + (1-terminal) gamma max_a Q(s2, a) - Q(s, a),注意到,如果定义了self.clip_delta,那么将残差进行限幅操作,将幅度不在[-self.clip_delta,self.clip_delta]的delta值强行clip。
同时,函数定义了targets矩阵,其中target是一个二维矩阵,第一维表示batch_size,第二维表示actions。这里,我们将delta的值赋给target对应的action位置,其他action处,target=0。
最后函数返回targets,delta以及q2_max的值。
nql:qLearnMinibatch()
这个函数的主要目的是执行一个minibatch的Q-learning的update,其中采用的更新权重的方法是PMSProp,这里w += alpha (r + gamma max Q(s2,a2) - Q(s,a)) dQ(s,a)/dw
a=addcdiv(b,c,d)表示a=a+b*d/c
nql:sample_validation_data()
利用transition类的sample函数,采样self.validsize个样本,并将数据存储到self.valid\(s,a,r,s2,term)中。
nql:compute_validation_statistics()
计算得到validation上的平均Q_max值,和平均误差(误差指target和destination之间的差)
nql:eGreedy()
该函数主要的目的按照greed expolation的方式去选择一个action
nql:greedy(state)
这个函数的目的就是用来根据最大Q值选择一个action的值,注意到,如果有几个action的Q值均为最大,那么随机选择一个action执行。
nql:perceive(reward,rawstate,terminal,testing,testing_ep)
这个函数会与transition类之间进行交互,然后更新Q值,选择action,并进行参数的优化。
首先,将rawstate进行预处理,并定义当前状态
然后根据self.max_reward,self.min_reward和self.rescale_r将reward进行限幅。
调用transition类,将state(这里的state只包含一帧图像)加入recent存储区,然后从transition中采样得到新的state,此时的state是由多帧构成的。接下来将新的transition存储到存储区内。
|
|
利用eGreedy算法得到新的action
进行Q-learning更行权重,这里更新每隔self.update_freq步才进行一次权重的的更新,也就是说每两次更新之间执行self.update_freq次action,然后每次更新会重复连续学习self.n_replay次
更新学习步数
学习完之后,此时的状态和action都发生的改变,我们需要将last的状态和action进行一个更新。
每隔self.target_q个步骤,将参数copy到target网络。
最后返回要执行的actionIndex值
nql:createNetwork()
创建一个三个线性层的网络,这是一个三层的多层感知器
nql:_loadNet()
载入网络,返回self.network
nql:init(arg)
手动初始化
nql:report()
调用nnutil.lua中的get_weight_norms,get_grad_norms函数,输出network的信息。
train_agent.lua
这是训练的主程序,这里对其进行解析。
初始化
调用setup.lua进行初始化,得到game_env,game_actions,agent,opt。
然后初始化参数列表
训练
调用nql:perceive()函数进行训练,得到执行的action_index。
如果游戏已经结束,那么重新进入下一个游戏
每隔opt.prog_freq步就输出网络的信息
在lua语言中,不会自动处理垃圾,需要调用collectgarbage()手动处理。
训练(在特定的步数上进行验证)
每隔opt.eval_freq步就进行验证。首先进行初始化。
然后调用nql:perceive()进行验证,注意到这里的testing参数为true,ep固定为0.05。
计算时间
获得统计数据,注意到由于每次测试都有可能执行了不同的eposide,我们这里计算每个eposide的平均值。
输出信息
训练(在特定的步数上进行保存)
每隔opt.save_freq步或者训练完之后,将网络进行保存。对于保存的程序,这里就不进行分析了。