torch学习笔记3.1:实现自定义模块(lua)

在使用torch时,如果想自己实现一个层,则可以按照《torch学习笔记1:实现自定义层》 中的方法来实现。但是如果想要实现一个比较复杂的网络,往往需要自己实现多个层(或类),并且有时可能需要重写其他模块中已有的函数来达到自己的目的,如果还是在nn模块中添加,会比较混乱,并且不利于本地git仓库统一管理,这个时候,我们可以自己实现一个像nn一样的模块,在代码中使用时 require即可。

我们来实现一个名为nxn的自定义模块,以及它的cuda版本cunxn模块,其中包含一个自定义的Hello类(lua实现),ReLU类(分别用CPU和GPU实现)。

由于篇幅原因,这里把torch自定义模块的lua实现,cpu实现,gpu实现分别写一篇文章,本文先介绍lua实现的Hello类。

1 总目录结构

模板源代码可在我的资源中下载。

.../myproj/
      |----scripts/
           |---- demo.lua
      |----nxn/
           |---- CMakeLists.txt
           |---- nxn-scm-1.rockspec
           |---- init.lua
           |---- init.c
           |---- ReLU.lua
           |---- Hello.lua
           |---- generic/
                 |---- ReLU.c
           |---- test/
                 |---- test.lua
      |----cunxn/
           |---- CMakeLists.txt
           |---- cunxn-scm-1.rockspec
           |---- init.lua
           |---- init.cu   
           |---- ReLU.cu
           |---- test/
                 |---- test.lua                

2 使用

  1. 成功安装了torch。
  2. 在nxn目录下运行
luarocks make nxn-scm-1.rockspec
  1. 在cunxn目录下运行
luarocks make cunxn-scm-1.rockspec
  1. 在scripts目录下运行
th demo.lua
  1. 输出
    result

3 文件说明

demo.lua

是使用自定义类的示例代码。

require 'cunxn'

local module = nxn.Hello()
module:updateOutput()

input = torch.rand(3,3)
print(input)

local module = nxn.ReLU(false)
output = module:updateOutput(input)
print(output)

cutorch.setDevice(2)
input = input:cuda()
print(input)

local module = nxn.ReLU(true)
output = module:updateOutput(input)
print(output)

CMakeLists.txt

一般和nn之类的模块没有太大区别,仿照着写即可,需要注意的是以下几句:

......
# 编译时从init.c找cpu实现的代码文件
SET(src init.c) 
# 指定要编译的lua文件
FILE(GLOB luasrc *.lua)
SET(luasrc ${luasrc} test/test.lua)
# 把cpp和lua文件加入模块nxn
ADD_TORCH_PACKAGE(nxn "${src}" "${luasrc}")
# 链接lua库
TARGET_LINK_LIBRARIES(nxn luaT TH)
......

nxn-scm-1.rockspec

注意dependencies里面还可以添加已有模块,比如nn,cunn,格式如下:

......
dependencies = {
   "torch >= 7.0",
   "cunn",
   "nn"
}
......

init.lua

内容如下,要include自定义类的lua文件,以及这里把cpp实现编译成了一个lib,也要添加进来。

require('torch')
require('libnxn')

include('ReLU.lua')
include('Hello.lua')

Hello.lua

自定义类的文件,该类由lua实现,这里提供一个简单的模板。

local Hello = torch.class('nxn.Hello')

function Hello:__init()
end

function Hello:updateOutput()
   print("hello in updateOutput")
end

function Hello:updateGradInput(input, gradOutput)
   print("hello in updateGradInput")
end

未完,后续说明见 CPU实现,GPU实现。


版权声明:本文为happyer88原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。