ぱたへね

Hatena Diary can't syntax-highlight Rust, so I moved over here.

Serializing custom layers in Keras

Converting Keras SSD to JSON and then trying to load it back produces an error.
https://github.com/rykov8/ssd_keras

Traceback (most recent call last):
  File "/home/natu/proj/myproj/cocytus/cocytus/cocytus.py", line 70, in <module>
    main(sys.argv)
  File "/home/natu/proj/myproj/cocytus/cocytus/cocytus.py", line 54, in main
    compiler = CocytusCompiler(config)
  File "/media/natu/data/proj/myproj/cocytus/cocytus/compiler/compiler.py", line 33, in __init__
    self.model = model_from_json(json_string, custom_objects={"Normalize": Normalize, "PriorBox": PriorBox})
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 325, in model_from_json
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 46, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 140, in deserialize_keras_object
    list(custom_objects.items())))
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2370, in from_config
    process_layer(layer_data)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2339, in process_layer
    custom_objects=custom_objects)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 46, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 1202, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'scale'

Cause

When model.to_json() is used, the custom layers' parameters are not written to the JSON, so restoring the model fails. Here is the entry that model.to_json() produced for the Normalize layer:

{"class_name": "Normalize", "inbound_nodes": [[["conv4_3", 0, 0, {}]]], "name": "conv4_3_norm", "config": {"name": "conv4_3_norm", "trainable": true}}

The "config" entry needs to include scale.
I couldn't find any documentation on this, so I used
https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py
as a reference.
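
To see the mechanism without installing Keras, here is a dependency-free toy imitation. ToyLayer and ToyNormalize are hypothetical stand-ins, not Keras classes; the only thing borrowed from Keras is the `cls(**config)` call visible in the traceback above. Because the base get_config knows nothing about scale, the JSON lacks it, and rebuilding the object raises the same kind of TypeError.

```python
import json


class ToyLayer:
    """Hypothetical stand-in for a Keras base layer (not keras.layers.Layer)."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # The base class only knows about its own arguments.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Mirrors the `return cls(**config)` line seen in the traceback.
        return cls(**config)


class ToyNormalize(ToyLayer):
    """Hypothetical Normalize-like layer that does NOT override get_config."""

    def __init__(self, scale, **kwargs):
        self.scale = scale
        super().__init__(**kwargs)


layer = ToyNormalize(20, name="conv4_3_norm")
json_string = json.dumps(layer.get_config())
print(json_string)  # {"name": "conv4_3_norm", "trainable": true}  (no "scale")

try:
    ToyNormalize.from_config(json.loads(json_string))
except TypeError as e:
    print(e)  # ...missing 1 required positional argument: 'scale'
```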

Fix

Add a get_config method to each custom layer to tell Keras which information serialization needs.

Normalize

What needs to be saved can be read off the arguments of __init__.

def __init__(self, scale, **kwargs):

For Normalize, scale alone is enough to rebuild the layer.
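
"Look at the __init__ arguments" can be done mechanically with the standard inspect module. init_arg_names below is a hypothetical helper, not part of Keras; NormalizeSig and PriorBoxSig are empty stand-ins that copy only the two signatures quoted in this post (the PriorBox one appears in the next section).

```python
import inspect


def init_arg_names(cls):
    """Hypothetical helper: names of cls.__init__ parameters that get_config
    should probably record (skips self, *args and **kwargs)."""
    params = inspect.signature(cls.__init__).parameters.values()
    return [
        p.name
        for p in params
        if p.name != "self"
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                           inspect.Parameter.VAR_KEYWORD)
    ]


# Stand-ins that copy only the signatures quoted in this post.
class NormalizeSig:
    def __init__(self, scale, **kwargs):
        pass


class PriorBoxSig:
    def __init__(self, img_size, min_size, max_size=None, aspect_ratios=None,
                 flip=True, variances=[0.1], clip=True, **kwargs):
        pass


print(init_arg_names(NormalizeSig))  # ['scale']
print(init_arg_names(PriorBoxSig))
# ['img_size', 'min_size', 'max_size', 'aspect_ratios', 'flip', 'variances', 'clip']
```

For PriorBox the helper also lists flip, which the PriorBox get_config in the next section does not record, so a layer built with a non-default flip is worth double-checking after a round trip.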

    def get_config(self):
        config = {
            'scale': self.scale,
        }
        base_config = super(Normalize, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

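
The last line of get_config concatenates the two lists of key/value pairs and builds a new dict; since later pairs overwrite earlier ones, the layer's own keys take precedence over the base class's. A plain-dict illustration, using the values that show up in the JSON output later in this post:

```python
# The merge idiom from get_config, applied to plain dicts.
base_config = {"name": "conv4_3_norm", "trainable": True}
config = {"scale": 20}

merged = dict(list(base_config.items()) + list(config.items()))
print(merged)  # {'name': 'conv4_3_norm', 'trainable': True, 'scale': 20}

# Later pairs win, so a key in `config` overrides the same key from the base.
override = dict(list({"a": 1}.items()) + list({"a": 2}.items()))
print(override)  # {'a': 2}

# On Python 3.5+ the same merge can also be written with dict unpacking.
assert merged == {**base_config, **config}
```
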
PriorBox

PriorBox is a bit more involved.

    def __init__(self, img_size, min_size, max_size=None, aspect_ratios=None,
                 flip=True, variances=[0.1], clip=True, **kwargs):

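Everything that appears in these __init__ arguments is needed. Looking at the source, it is tempting to save attributes such as self.waxis as well, but those are set up again inside __init__ when the layer is rebuilt, so there is no need to save them.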
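
The same idea in miniature. In this hypothetical ToyPriorBox (not the ssd_keras class), num_sizes plays the role of attributes like self.waxis that __init__ computes on its own: get_config records only constructor arguments, and from_config gets the computed value back simply by running __init__ again.

```python
import json


class ToyPriorBox:
    """Hypothetical PriorBox-like class: only constructor arguments are saved;
    computed attributes are rebuilt by __init__ on load."""

    def __init__(self, img_size, min_size, max_size=None):
        self.img_size = img_size
        self.min_size = min_size
        self.max_size = max_size
        # Computed attribute (stand-in for things like self.waxis): never stored.
        self.num_sizes = 1 if max_size is None else 2

    def get_config(self):
        return {"img_size": self.img_size, "min_size": self.min_size,
                "max_size": self.max_size}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


original = ToyPriorBox((300, 300), 30.0)
saved = json.dumps(original.get_config())
restored = ToyPriorBox.from_config(json.loads(saved))
print(saved)  # {"img_size": [300, 300], "min_size": 30.0, "max_size": null}
print(restored.num_sizes == original.num_sizes)  # True
```

JSON turns the img_size tuple into a list, which is why the output further down shows "img_size": [300, 300].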

    def get_config(self):
        config = {
            'img_size': self.img_size,
            'min_size': self.min_size,
            'max_size': self.max_size,
            'aspect_ratios': self.aspect_ratios,
            'variances': list(self.variances),
            'clip': self.clip
        }
        base_config = super(PriorBox, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Adding this is all that's needed. self.variances is an array, so it is converted to a list.
The fixed source is here:

https://gist.github.com/natsutan/ba0abee68b23bb042602ddc5a35217f4
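
Why the list() conversion for self.variances: the standard json module cannot encode a NumPy array. A small check, assuming self.variances is a NumPy array (which the post's "array" suggests) and that NumPy is installed:

```python
import json

import numpy as np

variances = np.array([0.1, 0.1, 0.2, 0.2])  # values as they appear in the JSON below

try:
    json.dumps({"variances": variances})
except TypeError as e:
    print(e)  # the ndarray is rejected as not JSON serializable

# list() works here because np.float64 is a subclass of Python float.
print(json.dumps({"variances": list(variances)}))  # {"variances": [0.1, 0.1, 0.2, 0.2]}

# .tolist() converts every element to a native Python scalar, which also
# covers integer arrays (NumPy integer scalars are not subclasses of int).
ints = np.array([1, 2])
print(json.dumps({"ints": ints.tolist()}))  # {"ints": [1, 2]}
```

For float variances list() is enough; tolist() is the more general choice if a config value might hold integers.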

Result

With the fix above applied, run model.to_json() again. This time the parameters are included in the JSON.

{"name": "conv4_3_norm", "class_name": "Normalize", "config": {"name": "conv4_3_norm", "trainable": true, "scale": 20}, "inbound_nodes": [[["conv4_3", 0, 0, {}]]]}

{"name": "conv4_3_norm_mbox_priorbox", "class_name": "PriorBox", "config": {"name": "conv4_3_norm_mbox_priorbox", "max_size": null, "img_size": [300, 300], "trainable": true, "clip": true, "min_size": 30.0, "variances": [0.1, 0.1, 0.2, 0.2], "aspect_ratios": [1.0, 2, 0.5]}, "inbound_nodes": [[["conv4_3_norm", 0, 0, {}]]]}
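
Instead of eyeballing the JSON, the custom layers' configs can be pulled out programmatically. custom_layer_configs is a hypothetical helper, and it assumes the functional-model layout {"config": {"layers": [...]}} that model.to_json() produces; the input here is a cut-down model JSON built from the Normalize entry above.

```python
import json


def custom_layer_configs(model_json, class_names):
    """Hypothetical helper: collect the "config" of every layer whose
    class_name is in class_names. Assumes {"config": {"layers": [...]}}."""
    model = json.loads(model_json)
    return {
        layer["name"]: layer["config"]
        for layer in model["config"]["layers"]
        if layer["class_name"] in class_names
    }


# A cut-down model JSON built from the Normalize entry shown above.
model_json = json.dumps({
    "class_name": "Model",
    "config": {"layers": [
        {"name": "conv4_3_norm", "class_name": "Normalize",
         "config": {"name": "conv4_3_norm", "trainable": True, "scale": 20},
         "inbound_nodes": [[["conv4_3", 0, 0, {}]]]},
    ]},
})

configs = custom_layer_configs(model_json, {"Normalize", "PriorBox"})
print(configs["conv4_3_norm"]["scale"])  # 20
```
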

Loading from JSON

Start by importing the custom layers.

from ssd_layers import Normalize
from ssd_layers import PriorBox

When loading, don't forget to pass custom_objects.

model = model_from_json(json_string, custom_objects={"Normalize": Normalize, "PriorBox": PriorBox})
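
custom_objects is needed because the JSON stores only the class name as a string ("class_name": "Normalize"), so the loader has to map that name back to a Python class. A toy version of that lookup; ToyNormalize and toy_deserialize are hypothetical, not Keras functions, and Keras's own error text may differ:

```python
import json


class ToyNormalize:
    """Hypothetical layer class standing in for ssd_layers.Normalize."""

    def __init__(self, scale, name=None, trainable=True):
        self.scale = scale
        self.name = name
        self.trainable = trainable


def toy_deserialize(layer_json, custom_objects=None):
    """Hypothetical, simplified loader: the JSON holds only the class *name*,
    so a name -> class mapping is required to rebuild the object."""
    entry = json.loads(layer_json)
    registry = dict(custom_objects or {})
    if entry["class_name"] not in registry:
        raise ValueError("Unknown layer: " + entry["class_name"])
    cls = registry[entry["class_name"]]
    return cls(**entry["config"])


layer_json = ('{"name": "conv4_3_norm", "class_name": "Normalize", '
              '"config": {"name": "conv4_3_norm", "trainable": true, "scale": 20}}')

layer = toy_deserialize(layer_json, custom_objects={"Normalize": ToyNormalize})
print(layer.scale)  # 20

try:
    toy_deserialize(layer_json)  # custom_objects forgotten
except ValueError as e:
    print(e)  # Unknown layer: Normalize
```
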


Result

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 300, 300, 3)   0                                            
____________________________________________________________________________________________________
conv1_1 (Conv2D)                 (None, 300, 300, 64)  1792                                         
____________________________________________________________________________________________________
conv1_2 (Conv2D)                 (None, 300, 300, 64)  36928                                        
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 150, 150, 64)  0                                            
____________________________________________________________________________________________________
conv2_1 (Conv2D)                 (None, 150, 150, 128) 73856                                        
____________________________________________________________________________________________________
conv2_2 (Conv2D)                 (None, 150, 150, 128) 147584                                       
____________________________________________________________________________________________________
pool2 (MaxPooling2D)             (None, 75, 75, 128)   0                                            
____________________________________________________________________________________________________
conv3_1 (Conv2D)                 (None, 75, 75, 256)   295168                                       
____________________________________________________________________________________________________
conv3_2 (Conv2D)                 (None, 75, 75, 256)   590080                                       
____________________________________________________________________________________________________
conv3_3 (Conv2D)                 (None, 75, 75, 256)   590080                                       
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 38, 38, 256)   0                                            
____________________________________________________________________________________________________
conv4_1 (Conv2D)                 (None, 38, 38, 512)   1180160                                      
____________________________________________________________________________________________________
conv4_2 (Conv2D)                 (None, 38, 38, 512)   2359808                                      
____________________________________________________________________________________________________
conv4_3 (Conv2D)                 (None, 38, 38, 512)   2359808                                      
____________________________________________________________________________________________________
pool4 (MaxPooling2D)             (None, 19, 19, 512)   0                                            
____________________________________________________________________________________________________
conv5_1 (Conv2D)                 (None, 19, 19, 512)   2359808                                      
____________________________________________________________________________________________________
conv5_2 (Conv2D)                 (None, 19, 19, 512)   2359808                                      
____________________________________________________________________________________________________
conv5_3 (Conv2D)                 (None, 19, 19, 512)   2359808                                      
____________________________________________________________________________________________________
pool5 (MaxPooling2D)             (None, 19, 19, 512)   0                                            
____________________________________________________________________________________________________
fc6 (Conv2D)                     (None, 19, 19, 1024)  4719616                                      
____________________________________________________________________________________________________
fc7 (Conv2D)                     (None, 19, 19, 1024)  1049600                                      
____________________________________________________________________________________________________
conv6_1 (Conv2D)                 (None, 19, 19, 256)   262400                                       
____________________________________________________________________________________________________
conv6_2 (Conv2D)                 (None, 10, 10, 512)   1180160                                      
____________________________________________________________________________________________________
conv7_1 (Conv2D)                 (None, 10, 10, 128)   65664                                        
____________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D) (None, 12, 12, 128)   0                                            
____________________________________________________________________________________________________
conv7_2 (Conv2D)                 (None, 5, 5, 256)     295168                                       
____________________________________________________________________________________________________
conv8_1 (Conv2D)                 (None, 5, 5, 128)     32896                                        
____________________________________________________________________________________________________
conv4_3_norm (Normalize)         (None, 38, 38, 512)   512                                          
____________________________________________________________________________________________________
conv8_2 (Conv2D)                 (None, 3, 3, 256)     295168                                       
____________________________________________________________________________________________________
pool6 (GlobalAveragePooling2D)   (None, 256)           0                                            
____________________________________________________________________________________________________
conv4_3_norm_mbox_conf (Conv2D)  (None, 38, 38, 63)    290367                                       
____________________________________________________________________________________________________
fc7_mbox_conf (Conv2D)           (None, 19, 19, 126)   1161342                                      
____________________________________________________________________________________________________
conv6_2_mbox_conf (Conv2D)       (None, 10, 10, 126)   580734                                       
____________________________________________________________________________________________________
conv7_2_mbox_conf (Conv2D)       (None, 5, 5, 126)     290430                                       
____________________________________________________________________________________________________
conv8_2_mbox_conf (Conv2D)       (None, 3, 3, 126)     290430                                       
____________________________________________________________________________________________________
conv4_3_norm_mbox_loc (Conv2D)   (None, 38, 38, 12)    55308                                        
____________________________________________________________________________________________________
fc7_mbox_loc (Conv2D)            (None, 19, 19, 24)    221208                                       
____________________________________________________________________________________________________
conv6_2_mbox_loc (Conv2D)        (None, 10, 10, 24)    110616                                       
____________________________________________________________________________________________________
conv7_2_mbox_loc (Conv2D)        (None, 5, 5, 24)      55320                                        
____________________________________________________________________________________________________
conv8_2_mbox_loc (Conv2D)        (None, 3, 3, 24)      55320                                        
____________________________________________________________________________________________________
conv4_3_norm_mbox_conf_flat (Fla (None, 90972)         0                                            
____________________________________________________________________________________________________
fc7_mbox_conf_flat (Flatten)     (None, 45486)         0                                            
____________________________________________________________________________________________________
conv6_2_mbox_conf_flat (Flatten) (None, 12600)         0                                            
____________________________________________________________________________________________________
conv7_2_mbox_conf_flat (Flatten) (None, 3150)          0                                            
____________________________________________________________________________________________________
conv8_2_mbox_conf_flat (Flatten) (None, 1134)          0                                            
____________________________________________________________________________________________________
pool6_mbox_conf_flat (Dense)     (None, 126)           32382                                        
____________________________________________________________________________________________________
conv4_3_norm_mbox_loc_flat (Flat (None, 17328)         0                                            
____________________________________________________________________________________________________
fc7_mbox_loc_flat (Flatten)      (None, 8664)          0                                            
____________________________________________________________________________________________________
conv6_2_mbox_loc_flat (Flatten)  (None, 2400)          0                                            
____________________________________________________________________________________________________
conv7_2_mbox_loc_flat (Flatten)  (None, 600)           0                                            
____________________________________________________________________________________________________
conv8_2_mbox_loc_flat (Flatten)  (None, 216)           0                                            
____________________________________________________________________________________________________
pool6_mbox_loc_flat (Dense)      (None, 24)            6168                                         
____________________________________________________________________________________________________
mbox_conf (Concatenate)          (None, 153468)        0                                            
____________________________________________________________________________________________________
pool6_reshaped (Reshape)         (None, 1, 1, 256)     0                                            
____________________________________________________________________________________________________
mbox_loc (Concatenate)           (None, 29232)         0                                            
____________________________________________________________________________________________________
mbox_conf_logits (Reshape)       (None, 7308, 21)      0                                            
____________________________________________________________________________________________________
conv4_3_norm_mbox_priorbox (Prio (None, 4332, 8)       0                                            
____________________________________________________________________________________________________
fc7_mbox_priorbox (PriorBox)     (None, 2166, 8)       0                                            
____________________________________________________________________________________________________
conv6_2_mbox_priorbox (PriorBox) (None, 600, 8)        0                                            
____________________________________________________________________________________________________
conv7_2_mbox_priorbox (PriorBox) (None, 150, 8)        0                                            
____________________________________________________________________________________________________
conv8_2_mbox_priorbox (PriorBox) (None, 54, 8)         0                                            
____________________________________________________________________________________________________
pool6_mbox_priorbox (PriorBox)   (None, 6, 8)          0                                            
____________________________________________________________________________________________________
mbox_loc_final (Reshape)         (None, 7308, 4)       0                                            
____________________________________________________________________________________________________
mbox_conf_final (Activation)     (None, 7308, 21)      0                                            
____________________________________________________________________________________________________
mbox_priorbox (Concatenate)      (None, 7308, 8)       0                                            
____________________________________________________________________________________________________
predictions (Concatenate)        (None, 7308, 33)      0                                            
====================================================================================================
Total params: 25,765,497.0
Trainable params: 25,765,497.0
Non-trainable params: 0.0
____________________________________________________________________________________________________
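
The PriorBox shapes in this summary can be checked against the restored configuration by hand. Each priorbox output length is side × side × boxes-per-location; for conv4_3_norm, 3 boxes per location matches the three aspect_ratios [1.0, 2, 0.5] stored in the JSON, while the 6 for the other sources is simply read off the shapes. Reading the trailing 8 as 4 box coordinates plus 4 variances is an assumption.

```python
# Feature-map side length and boxes per location, read off the summary:
# e.g. conv4_3_norm_mbox_priorbox has 4332 = 38 * 38 * 3 boxes.
prior_sources = {
    "conv4_3_norm": (38, 3),  # 3 matches aspect_ratios [1.0, 2, 0.5] in the JSON
    "fc7": (19, 6),
    "conv6_2": (10, 6),
    "conv7_2": (5, 6),
    "conv8_2": (3, 6),
    "pool6": (1, 6),
}

counts = {name: side * side * per_loc
          for name, (side, per_loc) in prior_sources.items()}
total = sum(counts.values())
print(counts)
print(total)  # 7308, the second dimension of mbox_priorbox and predictions

num_classes = 21  # last dimension of mbox_conf_logits
print(total * 4, total * num_classes)  # 29232 153468 (mbox_loc, mbox_conf)
print(4 + num_classes + 8)  # 33, the last dimension of predictions
```
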