博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
SEnet --se module
阅读量:4331 次
发布时间:2019-06-06

本文共 1582 字,大约阅读时间需要 5 分钟。

 

这是SEnet 的特征融合部分,

import tensorflow as tffrom tflearn.layers.conv import global_avg_poolclass SE_layer():    def __init__(self, x, training=True):        self.training = training    def Global_Average_Pooling(self, x):        return global_avg_pool(x, name='Global_avg_pooling')    def Fully_connected(self, x, units=3, layer_name='fully_connected') :        with tf.name_scope(layer_name) :            return tf.layers.dense(inputs=x, use_bias=False, units=units)    def Relu(self, x):        return tf.nn.relu(x)    def Sigmoid(self, x) :        return tf.nn.sigmoid(x)    def squeeze_excitation_layer(self, input_x, ratio, layer_name):        with tf.name_scope(layer_name) :            squeeze = self.Global_Average_Pooling(input_x)            excitation = self.Fully_connected(squeeze, units=int(input_x.shape[3])/ratio, layer_name=layer_name+'_fully_connected1')            excitation = self.Relu(excitation)            excitation = self.Fully_connected(excitation, units=int(input_x.shape[3]), layer_name=layer_name+'_fully_connected2')            excitation = self.Sigmoid(excitation)                    dim3 = int(input_x.shape[3])            excitation = tf.reshape(excitation, [-1,1,1, dim3])            scale = input_x * excitation            return scaleif __name__=="__main__":    input_data = tf.random_uniform([2, 70, 80, 3], 0, 255)    semodule = SE_layer(input_data)    output = semodule.squeeze_excitation_layer(input_data, 1, 'first')    print(output.shape)

注意:

暂时还不知道效果如何,可以测试一下。这个和cfe一样都是不改变shape的module,可以多关注一下。
在这里插入图片描述

 

转载于:https://www.cnblogs.com/o-v-o/p/9975350.html

你可能感兴趣的文章
实验三
查看>>
机器码和字节码
查看>>
环形菜单的实现
查看>>
【解决Chrome浏览器和IE浏览器上传附件兼容的问题 -- Chrome关闭flash后,uploadify插件不可用的解决办法】...
查看>>
34 帧动画
查看>>
二次剩余及欧拉准则
查看>>
Centos 7 Mysql 最大连接数超了问题解决
查看>>
thymeleaf 自定义标签
查看>>
关于WordCount的作业
查看>>
C6748和音频ADC连接时候的TDM以及I2S格式问题
查看>>
UIView的layoutSubviews,initWithFrame,initWithCoder方法
查看>>
STM32+IAP方案 实现网络升级应用固件
查看>>
用74HC165读8个按键状态
查看>>
jpg转bmp(使用libjpeg)
查看>>
linear-gradient常用实现效果
查看>>
sql语言的一大类 DML 数据的操纵语言
查看>>
VMware黑屏解决方法
查看>>
JS中各种跳转解析
查看>>
JAVA 基础 / 第八课:面向对象 / JAVA类的方法与实例方法
查看>>
Ecust OJ
查看>>