Deep Q Learning Code Analyze (1)
分析的源码来自于deepmind在Natrue上发表的论文Human-level control through deep reinforcement learning所附的源码。源码下载
文件结构
代码采用torch框架进行组织,编写的语言均为lua语言,其中包括convnet.lua, convnet_atari3.lua, initenv.lua, net_downsample_2x_full_y.lua, NeuralQLearner.lua, nnutils.lua, Rectifier.lua, Scale.lua, train_agent.lua, TransitionTable.lua。
训练的主程序是从train_agent.lua(具体的train_agent.lua的解析见这里)开始。训练时的参数表如下:
训练的开始会调用initenv.lua初始化game_env, game_actions, agent, opt。
initenv.lua
initenv文件是在训练的初始阶段,用来初始化gamEnv,gameActions,agent,以及opt参数。提供了torchSetup函数和Setup函数,在这里torchSetup函数用来初始化一些与torch相关的参数,包括gpu参数,计算线程,以及tensorType等。
而Setup参数用来调用torchSetup函数,并对gameEnv,gameActions,agent进行了初始化操作。
gameEnv表示游戏的环境,通过调用getState()方法可以得到screen, reward和terminal参数。screen表示屏幕状态,这是DQN中的输入,terminal是布尔型变量,表示是否游戏结束。
nnutils.lua
nnutils文件主要提供了一些辅助函数。
该文件首先提供了recursive_map的函数,该函数接受module, field, func作为输入,返回一个字符串,其中module表示训练的模型,field指模型中的某类参数名,比如field=’weight’时,module[field]表示模型中的权重。该函数会返回字符串,包含了模型的类型名,对module[field]的统计数据(统计的方法视func而定)。
由于模型中包含了子模型,因此recusive_map函数会递归调用子模型,因此会形成模型的树状表示。
在nnuils的文件中,定义了abs_mean()和abs_max()的函数,表示平均值和最大值。另外也定义了get_weight_norms()和get_grad_norms()的函数,这两个函数会调用recursive_map函数,分别对权重和梯度值求均值和最大值。
Scale.lua
scale.lua文件定义了训练时的scale层(此时的torch并没有内置scale的层),并定义了forward和updateOutput方法,实际上这两个方法都是相同的功能。
在scale:forward(x)函数中,x表示输入的图像,该函数会调用image.rgb2y(x)将输入的图像变成灰度图,然后将它按照初始化的宽高进行放缩。
Rectifier.lua
同样地,Rectifier.lua文件定义了训练时的ReLU函数层,这里对前向传播和反向传播都进行了定义。
这里self.output.resizeAs(input)的意思就是将output,resize成和input同样的size。cmul()表示矩阵对应元素相乘。
convnet.lua
convnet.lua文件的目的是建立CNN结构,该文件仅仅包含一个函数:create_network。输入层的定义由初始化时的input_dims给出。注意到,在函数里对GPU和CPU的卷积层的实现方式有所区分。
卷积层的数量由初始化时的arg.n_units的长度给出(arg.n_units的每个元素的数值表示每一层的输出的feature map个数),如下所示,这里arg.nl()表示非线性层的意思。
在卷积的最后一层通过人为构造0的输入的方式,进行前向传播,并对输出层进行nElement()的方法可以求得卷积最后一层的神经元数量。
然后加入多个线性层,同样的,线性层的数量由arg.n_hid的长度给出(arg.n_hid的每个元素的数值表示每个线性层输出的神经元数量)
最后加入一个线性层,其输出神经元的额数量等于actions的数量
convnet_atari3.lua
这个文件主要是调用convnet.lua文件,并设置了一些对应的参数。
net_downsample_2x_full_y.lua
这个文件会在构建网络时,在输入层增加一个Scale层,此时设置的长和宽均为84,Scale层会将输入的图像先变成灰度图,然后放缩成84x84的大小。
TransitionTable.lua
该文件主要创造了一个dqn.TransitionTable类,每个transition表示<s,a,r,s’>,其中s表示state,a表示actions,r表示rewards,s’表示在s状态下执行a,得到的下一个状态s’。这个类用来存储一定数量的transitions,充当replay memory的角色。在CNN训练时,从这个replay memory中进行sample,sample出来的样本作为了网络的输入。
对于dqn.TransitionTable类,该文件中设计了不少的方法,这里进行一一的解读。
trans:__init(args)
首先通过读args直接进行对象的初始化,这里包含的参数如下,在这里hist表示history的意思,每一个history中存储的帧图像合并才构成一个状态(这样做的原因是因为单独的某一帧的图像无法得到运动物体的速度信息等):
然后函数会针对不同的self.histType来设定不同的self.histIndices,同时,self.recentMemSize表示存储时的history的跨度,也就是histIndices[histLen]的值。
在self.histLen=5的情况下,如果self.histType=”linear”,且self.histSpacing=2时,那么self.histIndices={2,4,6,8,10},self.recentMemSize=10。如果self.histType=”exp2”,那么self.histIndices={1,2,4,8,16},self.recentMemSize=16。
接下来对self.s,self.a,self.r,self.t进行初始化设置。
然后初始化了recent存储区,用来存储最近recentMemSize个帧的图像,也就是说在采样时这里只能采样一个状态,这可以用来建立最新的状态。
另外初始化时也定义了buffer区,在训练时的transition即来自buffer区。
buffer区的state是由几个frame连接得到的,而self.s仅仅指一帧。
trans:reset()
重置transition memory
trans:size()
返回self.numEntries
trans.empty()
将self.numEntries置0
trans.concatFrames(index,use_recent)
该函数负责将histLen个Frames的图像连接在一起,组成一个状态。至于Frames的选取方法,由self.histIndices的值来决定。
use_recent是一个bool型的变量,这个变量决定是否使用recent table
函数新建了一个局部变量fullstate,用来存储histLen个Frames的数据。函数的输入变量index表示在s中采样的Frames的初始下标。
这个函数会在index与index+self.histIndice[histLen]-1之间的Frames,按照index+self.histIndice的方式进行采样,然而,如果在这些帧图像之间出现了terminal状态,也就是说游戏重新开始了一遍,这里会将出现terminal状态前的采样帧进行归零处理。也就是说最后得到的fullstate只包含最新的episode(每次从游戏开始到结束称为一个episode)。最终得到的一个fullstate称为一个状态。
trans:concatActions(index,use_recent)
该函数的作用类似于trans:concatFrames,唯一的区别是它作用的对象是actions。
trans:get(index)
调用self:concatFrames(index)得到s和s2,我们取s中的最后一帧的action和reward作为整个state的action和reward,terminal取整个state后的第一帧的t值。
trans:sample_one()
在(2,self.numEntries-self.recentMemSize)之间进行均匀采样得到一个index,从2开始的原因是保证有一个previous action,index的最大值是self.numEntries-self.rencentMemSize,这样设置是因为训练的状态的最后一帧的下标与第一帧的下标之间相差recentMemSize。
同时如果self.nonTermProb和self.nonEventProb不等于1的情况下,采样的状态会被随机抛弃。
trans:fill_buffer()
这个函数通过调用trans:sample_one()的函数来进行采样,然后将这些随机采样的样本加入到buffer区。执行这个函数会刷新buffer区的数据。
注意到这里必须保证原存储区的样本个数大于buffer区。
然后进行采样,注意到该函数调用后会初始化一个类成员变量self.buf_ind,这个变量表示在buffer中训练时的下标指示器。每次调用该函数就会使这个变量置为1,即表示现在的buffer区的数据还没有被训练。
trans:sample(batch_size)
在buffer区得到batch_size个tansition,注意到如果buffer区中所剩下的数据少于batch_size时会重新更新buffer区。
trans:add(s,a,r,term)
该文件会将一组新的s,a,r,term(terminal)写进存储区,每写进一个数据self.numEntries会加1,直到self.maxSize为止。
这里用self.inserIndex来控制写入的下标,当存储区写满后,又从头开始写入。
写入存储区
trans:add_recent_state(s,term),trans:add_recent_action(a)
这两个函数分别将s,term和a加入recent存储区,注意到由于recent存储区只存储一个状态,因此函数里面有维持recent存储区的大小等于self.recentMemSize的操作。
trans:get_recent()
从recent存储区取一个状态
trans:write(file)
将trans类的参数序列化写入文件
trans:read(file)
执行反序列化,从文件中读取参数