博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
mxnet系列 全连接层代码阅读
阅读量:4320 次
发布时间:2019-06-06

本文共 3204 字,大约阅读时间需要 10 分钟。

全连接操作(全连接层)也具有前向和反向,代码解析如下:

virtual void Forward(const OpContext &ctx,                       const std::vector
&in_data, const std::vector
&req, const std::vector
&out_data, const std::vector
&aux_args) { using namespace mshadow; using namespace mshadow::expr; if (req[fullc::kOut] == kNullOp) return; CHECK_EQ(req[fullc::kOut], kWriteTo); size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context // TODO(bing): judge shape to remove flatten op Stream
*s = ctx.get_stream
();#if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream
::OwnHandle) << "Must init CuBLAS handle in stream";#endif // __CUDACC__ const TShape& ishape = in_data[fullc::kData].shape_; const TShape& oshape = out_data[fullc::kOut].shape_; Tensor
data = in_data[fullc::kData].get_with_shape
( //输入 Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s); Tensor
wmat = in_data[fullc::kWeight].get
(s); //权重 Tensor
out = out_data[fullc::kOut].get_with_shape
( //输出 Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s); out = dot(data, wmat.T()); //点乘 if (!param_.no_bias) { Tensor
bias = in_data[fullc::kBias].get
(s); out += repmat(bias, data.size(0)); } } virtual void Backward(const OpContext &ctx, const std::vector
&out_grad, const std::vector
&in_data, const std::vector
&out_data, const std::vector
&req, const std::vector
&in_grad, const std::vector
&aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); size_t expected = param_.no_bias ? 2 : 3; CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context Stream
*s = ctx.get_stream
(); const TShape& ishape = in_data[fullc::kData].shape_; const TShape& oshape = out_grad[fullc::kOut].shape_; Tensor
data = in_data[fullc::kData].get_with_shape
( //输入 Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s); Tensor
wmat = in_data[fullc::kWeight].get
(s); //权重 Tensor
grad = out_grad[fullc::kOut].get_with_shape
( //梯度 Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);#if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream
::OwnHandle) << "Must init CuBLAS handle in stream";#endif // backprop CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace"; // gradient of weight Tensor
gwmat = in_grad[fullc::kWeight].get
(s); //权重梯度 Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data)); //求权重梯度 // gradient of bias if (!param_.no_bias) { Tensor
gbias = in_grad[fullc::kBias].get
(s);//偏置梯度 Assign(gbias, req[fullc::kBias], sum_rows(grad)); } // gradient of data Tensor
gdata = in_grad[fullc::kData].get_with_shape
( //输入梯度 Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s); Assign(gdata, req[fullc::kData], dot(grad, wmat)); //求权重梯度 }

转载于:https://www.cnblogs.com/hellokittyblog/p/8186128.html

你可能感兴趣的文章
Android面试题集合
查看>>
Android NDK开发
查看>>
Centos中安装和配置vsftp简明教程
查看>>
spring源码学习之AOP(一)
查看>>
AES加密算法动画演示
查看>>
三种方法实现调用Restful接口
查看>>
php第五节(字符串函数和时间、日期函数)
查看>>
magento主页限制某个目录的产品显示数量
查看>>
SpringBoot整合Netty
查看>>
MongoDB数据库的基本操作
查看>>
PAT乙级1014
查看>>
ORACLE wm_concat自定义
查看>>
[Zend PHP5 Cerification] Lectures -- 6. Database and SQL
查看>>
[Drupal] Using the Administrator theme whenever you want.
查看>>
【Hibernate框架】关联映射(一对一关联映射)
查看>>
【算法】大数乘法
查看>>
WPF解析PPT为图片
查看>>
JavaScrict中的断言调试
查看>>
密码服务
查看>>
结构体在内存中的存储
查看>>