CBAM Channel Attention in TensorFlow 2.x

Posted on 2024-9-30 04:23:03
CBAM channel attention for TensorFlow 2.x; reference: https://blog.csdn.net/weixin_39122088/article/details/10719197. Both implementations below compute the CBAM channel attention map M_c(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))): a shared two-layer MLP is applied to the average-pooled and max-pooled channel descriptors, the two results are combined, and a sigmoid produces the per-channel weights.
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Implementation 1: stack the two pooled descriptors and run them through the shared MLP in one pass
class ChannelAttention(layers.Layer):
    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_out = layers.GlobalAveragePooling2D()
        self.max_out = layers.GlobalMaxPooling2D()
        # Shared MLP: reduce to in_planes // ratio, then project back to in_planes
        self.fc1 = layers.Dense(in_planes // ratio, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                activation=tf.nn.relu,
                                use_bias=True, bias_initializer='zeros')
        self.fc2 = layers.Dense(in_planes, kernel_initializer='he_normal',
                                kernel_regularizer=regularizers.l2(5e-4),
                                use_bias=True, bias_initializer='zeros')

    def call(self, inputs):
        avg_out = self.avg_out(inputs)                   # shape=(None, in_planes)
        max_out = self.max_out(inputs)                   # shape=(None, in_planes)
        out = tf.stack([avg_out, max_out], axis=1)       # shape=(None, 2, in_planes)
        out = self.fc2(self.fc1(out))                    # shared MLP applied to both descriptors
        out = tf.reduce_sum(out, axis=1)                 # shape=(None, in_planes)
        out = tf.nn.sigmoid(out)
        out = layers.Reshape((1, 1, out.shape[1]))(out)  # shape=(None, 1, 1, in_planes)
        return out
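For a quick sanity check, the attention map returned by the layer can be multiplied back onto the input feature map, which is how CBAM applies it. A minimal usage sketch for Implementation 1 (the shapes and variable names are illustrative assumptions, not taken from the referenced post):

import tensorflow as tf

x = tf.random.normal((4, 32, 32, 64))          # dummy feature map: (batch, H, W, channels)
ca = ChannelAttention(in_planes=64, ratio=8)   # in_planes must match the channel dimension of x
attn = ca(x)                                   # shape=(4, 1, 1, 64), values in (0, 1)
refined = x * attn                             # channel-wise reweighting, as in CBAM
print(attn.shape, refined.shape)               # (4, 1, 1, 64) (4, 32, 32, 64)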

# Implementation 2: run the avg- and max-pooled descriptors through the shared MLP separately, then add
class ChannelAttention(layers.Layer):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg = layers.GlobalAveragePooling2D()
        self.max = layers.GlobalMaxPooling2D()
        # Shared MLP with a fixed reduction ratio of 16
        self.fc1 = layers.Dense(in_planes // 16, kernel_initializer='he_normal', activation='relu',
                                use_bias=True, bias_initializer='zeros')
        self.fc2 = layers.Dense(in_planes, kernel_initializer='he_normal', use_bias=True,
                                bias_initializer='zeros')

    def call(self, inputs):
        avg_out = self.fc2(self.fc1(self.avg(inputs)))   # MLP on the average-pooled descriptor
        max_out = self.fc2(self.fc1(self.max(inputs)))   # MLP on the max-pooled descriptor
        out = avg_out + max_out
        out = tf.nn.sigmoid(out)
        out = tf.reshape(out, [-1, 1, 1, out.shape[-1]])              # shape=(None, 1, 1, in_planes)
        out = tf.tile(out, [1, inputs.shape[1], inputs.shape[2], 1])  # tile to the input's spatial size (assumes static H and W)
        return out
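To use either layer inside a network, apply it after a convolutional block and reweight the features channel-wise. A minimal functional-API sketch built around Implementation 2 (the input size, filter counts, and layer names are illustrative assumptions, not from the referenced post):

import tensorflow as tf
from tensorflow.keras import layers, Model

inputs = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(64, 3, padding='same', activation='relu')(inputs)
attn = ChannelAttention(in_planes=64)(x)       # Implementation 2 returns a map tiled to (None, 32, 32, 64)
x = layers.Multiply()([x, attn])               # channel-wise reweighting of the feature map
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation='softmax')(x)

model = Model(inputs, outputs)
model.summary()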
