以index访问Tensor元素+ unpool TensorFlow代码

释放双眼,带上耳机,听听看~!

摘要: 本文主要解决了两个难题:1)原本 tensor cannot be accessed by index,本文给出了一个可行的方案:访问/读取/修改指定index的tensor元素的方案。2)Tensorflow官方给出了反向卷积的API,但是没有给出反向max-pool的API,网上没有相关代码,本文给出解决方案及代码。

首先,Unpool操作非常重要,因为在Autoencoder 等需要用反向神经网络来实现invert反转来构建原始数据,例如我们可以用倒过来的卷积神经网络从特征生成一些椅子的图片,而且生成的图片跟之前的图片一定是不同的。就像是一个脑洞大开的你,输入照片描述:宋仲基聚光的小眼睛,雷佳音丧萌丧萌的眉毛,潘粤明呆萌乖巧的腿….,然后就能给你生成一批符合这些描述的照片。倒过来的神经网络架构及应用例子如下图的例子:
以index访问Tensor元素+ unpool TensorFlow代码

unpool操作,也就是max-pool的反向操作并没有现成的函数,本文将已有的方法做一下总结,并将代码放出来,供大家参考:

1. 文献中关于Unpool的两种方法

方法1: we perform unappealing by simply replacing each entry map by an ss block with the entry value in the top corner and zeros elsewhere. This increases the width and the height of the feature map s times. We used s=2 in our networks. 如果max-pool是将一个22的方格里最大值拿出,那么unpooling可以将该值赋给2*2的左上角元素,其它置为0.
以index访问Tensor元素+ unpool TensorFlow代码

方法 2: http://www.matthewzeiler.com/wp-content/uploads/2017/07/arxive2013.pdf it records the locations of maximum activations selected during pooling operation in switch variables, which are em- ployed to place each activation back to its original pooled location  https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Noh_Learning_Deconvolution_Network_ICCV_2015_paper.pdf 这个方法比方法1应该更准确了一些,因为在unpooling的时候将max pool 选择的最大值的 index拿到,那么将值再放回原位置,其它置为0。
以index访问Tensor元素+ unpool TensorFlow代码

2. 代码实现

在实现unpooling的时候,因为Tensor cannot be accessed by index,只有 numpy array才可以用index访问,而我们知道tensor 转 numpy array的唯一方式就是执行它,比如sess.run(Tensor),或者Tessor.eval(),因为Tensor实际上是个容器,没有内容,只有执行一下才有数据,而run Tensor的这件事在我们训练模型定义层结构的时候是不符合逻辑的,当你有好多层,需要定义一个反向卷积层服务于整个神经网络的训练,你不可能单独为某一层拿到sess和feed_dic。

2.1 如何实现 Access tensor by index 修改/读取 Tensor的某个元素

第一个摆在我们面前的困难就是:tensor cannot be accessed by index。

如何Access tensor by index 修改/读取 Tensor中的某个元素?本节先单独讨论这一问题,这个问题也会是其它需要通过index访问tensor的代码都会遇到的基础问题。

解决方案:既然不能用 tensor1来指定index读取元素,例如(1,2)的元素(这里tensor是一个创建的tensor的名字)那么我们可以借助其它方式巧妙的完成它,也就是分为两个步骤:1)指定index来读取tensor的值 (tf.expand_dims),2)为tensor 指定index的位置赋指定的值(使用tf.SparseTensor(indices, values, shape))。

最后,放上实现上述方法1 unpooling的代码

def unpool2(pool, ksize, stride, padding = 'VALID'):
    """
    simple unpool method
    :param pool : the tensor to run unpool operation
    :param ksize : integer
    :param stride : integer
    :return : the tensor after the unpool operation
    """
    pool = tf.transpose(pool, perm=[0,3,1,2])
    pool_shape = pool.shape.as_list()
    if padding == 'VALID':
        size = (pool_shape[2] - 1) * stride + ksize
    else:
        size = pool_shape[2] * stride
    unpool_shape = [pool_shape[0], pool_shape[1], size, size]
    unpool = tf.Variable(np.zeros(unpool_shape), dtype=tf.float32)
    for batch in range(pool_shape[0]):
        for channel in range(pool_shape[1]):
            for w in range(pool_shape[2]):
                for h in range(pool_shape[3]):
                    diff_matrix = tf.sparse_tensor_to_dense(tf.SparseTensor(indices=[[batch,channel,w*stride,h*stride]],values=tf.expand_dims(pool[batch][channel][w][h],axis=0),dense_shape = [pool_shape[0],pool_shape[1],size,size]))
                    unpool = unpool + diff_matrix

相信有了这个代码,方法2大家实现起来就很容易了。

但是这种方法非常慢,所以实现unpooling的过程,可以用下面的方法代替:

*# PI is the 4-dimension Tensor from above layer
unpool1 = tf.image.resize_images(PI, size = [resize_width, resize_width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)*  

另外还有两种unpool的方法:

def max_unpool_2x2(x, output_shape):
    out = tf.concat_v2([x, tf.zeros_like(x)], 3)
    out = tf.concat_v2([out, tf.zeros_like(out)], 2)
    out_size = output_shape
    return tf.reshape(out, out_size)
# max unpool layer 改变输出的 shape 为两倍
# output_shape_d_pool2 = tf.pack([tf.shape(x)[0], 28, 28, 1])
# h_d_pool2 = max_unpool_2x2(h_d_conv2, output_shape_d_pool2)
def max_unpool_2x2(x, shape):
    inference = tf.image.resize_nearest_neighbor(x, tf.pack([shape[1]*2, shape[2]*2]))
    return inference

其他参考:

https://ithelp.ithome.com.tw/articles/10188326

【转自慕课】https://www.imooc.com

Python

Python性能分析指南

2022-3-3 14:44:07

Python

「自动化实战」python对mongodb的CURD介绍

2022-3-3 14:49:23

搜索