Writing your own Keras layers

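For simple, stateless custom operations, you are probably better off using layers.core.Lambda layers. But for any custom operation that has trainable weights, you should implement your own layer.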

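Here is the skeleton of a Keras layer, as of Keras 2.0 (if you have an older version, please upgrade). There are only three methods you need to implement: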

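  • build(input_shape): this is where you define your weights. This method must set self.built = True at the end, which can be done by calling super([Layer], self).build(input_shape), where [Layer] is your layer's class.
  • call(x): this is where the layer's logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to call: the input tensor.
  • compute_output_shape(input_shape): if your layer modifies the shape of its input, you should specify the shape transformation logic here. This allows Keras to do automatic shape inference.
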
from keras import backend as K
from keras.layers import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
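The calling order behind this skeleton can be illustrated without a backend. The sketch below is a hypothetical, framework-free toy. ToyLayer and ToyDense are illustrative names, not Keras classes, and the real keras.layers.Layer does far more (weight tracking, masking, serialization). It mirrors only the protocol described above:

  • build runs once, the first time the layer is called. It uses the input's shape to size the kernel and sets built = True.
  • call then computes the output. Here a plain-Python matrix product stands in for K.dot.
  • compute_output_shape reports (batch, output_dim) without running any computation.

```python
class ToyLayer:
    """Hypothetical stand-in for keras.layers.Layer; models only build/call/shape."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Base implementation just records that the weights now exist.
        self.built = True

    def __call__(self, x):
        # Build lazily on first use, from the shape of the first input seen.
        if not self.built:
            self.build((len(x), len(x[0])))
        return self.call(x)

    def call(self, x):
        raise NotImplementedError

    def compute_output_shape(self, input_shape):
        # Default: a shape-preserving layer.
        return input_shape


class ToyDense(ToyLayer):
    """Pure-Python analogue of MyLayer: output = x . kernel."""

    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

    def build(self, input_shape):
        # The kernel's first dimension is only known once an input arrives.
        input_dim = input_shape[1]
        # Constant weights (instead of 'uniform' random init) keep the demo deterministic.
        self.kernel = [[1.0] * self.output_dim for _ in range(input_dim)]
        super().build(input_shape)  # called last: sets self.built = True

    def call(self, x):
        # (batch, input_dim) x (input_dim, output_dim) -> (batch, output_dim)
        return [
            [sum(xi * self.kernel[i][j] for i, xi in enumerate(row))
             for j in range(self.output_dim)]
            for row in x
        ]

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)


layer = ToyDense(3)
out = layer([[1.0, 2.0], [3.0, 4.0]])
print(layer.built, out)                          # True [[3.0, 3.0, 3.0], [7.0, 7.0, 7.0]]
print(layer.compute_output_shape((None, 2)))     # (None, 3)
```

With every kernel entry fixed at 1.0, each output column is the row sum of the input, so the two rows map to [3.0, 3.0, 3.0] and [7.0, 7.0, 7.0]. compute_output_shape needs no data at all. That is what allows the output shape to be worked out, including an unknown batch dimension (None), before any tensor flows through the layer.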

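It is also possible to define Keras layers that have multiple input tensors and multiple output tensors. To do this, you should assume that the inputs and outputs of the methods build(input_shape), call(x) and compute_output_shape(input_shape) are lists. Here is an example, similar to the one above: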

from keras import backend as K
from keras.layers import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert isinstance(x, list)
        a, b = x
        return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        shape_a, shape_b = input_shape
        return [(shape_a[0], self.output_dim), shape_b[:-1]]
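As a sanity check on the list-based example, the plain-Python reference below reproduces what its call and compute_output_shape return, using nested lists instead of tensors. The function names are hypothetical. The sketch assumes b has exactly the shape (batch, output_dim); the backend addition would also accept broadcast-compatible shapes.

It makes two things explicit that the Keras code leaves implicit:

  • The first output adds b to K.dot(a, kernel), so b's last dimension must match output_dim.
  • K.mean(b, axis=-1) drops b's last axis, which is why the second output shape is shape_b[:-1].

```python
def matmul(a, kernel):
    """(batch, in_dim) x (in_dim, out_dim) -> (batch, out_dim), like K.dot."""
    out_dim = len(kernel[0])
    return [[sum(ai * kernel[i][j] for i, ai in enumerate(row)) for j in range(out_dim)]
            for row in a]


def multi_io_call(inputs, kernel):
    """Hypothetical reference for the two-input, two-output layer's call()."""
    assert isinstance(inputs, list)
    a, b = inputs
    ab = matmul(a, kernel)
    # Assumption: b has exactly shape (batch, output_dim), so addition is elementwise.
    first = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(ab, b)]
    # Mean over the last axis, like K.mean(b, axis=-1): (batch, d) -> (batch,)
    second = [sum(row) / len(row) for row in b]
    return [first, second]


def multi_io_output_shape(input_shape, output_dim):
    """Mirror of compute_output_shape: one shape tuple per output."""
    assert isinstance(input_shape, list)
    shape_a, shape_b = input_shape
    return [(shape_a[0], output_dim), shape_b[:-1]]


kernel = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # (input_dim=3, output_dim=2)
a = [[1.0, 2.0, 3.0]]                          # (batch=1, input_dim=3)
b = [[10.0, 20.0]]                             # (batch=1, output_dim=2)
print(multi_io_call([a, b], kernel))                            # [[[14.0, 25.0]], [15.0]]
print(multi_io_output_shape([(None, 3), (None, 2)], 2))         # [(None, 2), (None,)]
```

Here K.dot(a, kernel) gives [4.0, 5.0], and adding b yields [14.0, 25.0]. The mean of b's single row is 15.0, and its shape (None, 2) loses its last axis to become (None,).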

The existing Keras layers provide examples of how to implement almost anything. Never hesitate to read the source code!