When I convert Keras SSD to JSON and then try to load it back, I get an error.
https://github.com/rykov8/ssd_keras
Traceback (most recent call last):
  File "/home/natu/proj/myproj/cocytus/cocytus/cocytus.py", line 70, in <module>
    main(sys.argv)
  File "/home/natu/proj/myproj/cocytus/cocytus/cocytus.py", line 54, in main
    compiler = CocytusCompiler(config)
  File "/media/natu/data/proj/myproj/cocytus/cocytus/compiler/compiler.py", line 33, in __init__
    self.model = model_from_json(json_string, custom_objects={"Normalize": Normalize, "PriorBox": PriorBox})
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 325, in model_from_json
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 46, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 140, in deserialize_keras_object
    list(custom_objects.items())))
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2370, in from_config
    process_layer(layer_data)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 2339, in process_layer
    custom_objects=custom_objects)
  File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 46, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 1202, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'scale'
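Cause
When the model is written out with model.to_json(), the JSON does not include the custom layers' constructor arguments, so restoring the model fails. Let's look at the JSON that model.to_json() produced.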
{"class_name": "Normalize", "inbound_nodes": [[["conv4_3", 0, 0, {}]]], "name": "conv4_3_norm", "config": {"name": "conv4_3_norm", "trainable": true}}
The "config" entry here needs to contain scale.
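The failure can be reproduced without Keras. This sketch uses a hypothetical minimal Layer base class (an assumption, not the real Keras class) whose from_config performs the same cls(**config) call as the last frame of the traceback. Because the stand-in Normalize does not override get_config, only name and trainable are serialized, and rebuilding fails for lack of scale.

```python
import json


class Layer:
    """Hypothetical minimal base class; NOT the real keras Layer."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # By default only the base attributes are serialized
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Same call as the last frame of the traceback: cls(**config)
        return cls(**config)


class Normalize(Layer):
    """Stand-in custom layer that does NOT override get_config."""

    def __init__(self, scale, **kwargs):
        self.scale = scale
        super(Normalize, self).__init__(**kwargs)


layer = Normalize(20, name="conv4_3_norm")
config_json = json.dumps(layer.get_config())
print(config_json)  # {"name": "conv4_3_norm", "trainable": true}

try:
    Normalize.from_config(json.loads(config_json))
    error_message = None
except TypeError as exc:
    error_message = str(exc)
# Ends with: missing 1 required positional argument: 'scale'
print(error_message)
```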
I couldn't find any documentation, so I used
https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py
as a reference.
Fix
Add a get_config method to each custom layer to tell Keras what information it needs for serialization.
Normalize
You can tell what needs to be saved by looking at the arguments of __init__.
def __init__(self, scale, **kwargs):
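For Normalize, scale is all that is needed to restore the layer.
def get_config(self):
    # __init__ arguments needed to rebuild this layer
    config = {
        'scale': self.scale,
    }
    # Merge with the base Layer config (name, trainable, ...)
    base_config = super(Normalize, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))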
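To check that this get_config is enough, here is a round trip without Keras. The Layer base class is again a hypothetical minimal stand-in (an assumption, not the real Keras class) whose from_config calls cls(**config). Normalize.get_config follows the same pattern as the method above.

```python
import json


class Layer:
    """Hypothetical minimal base class; NOT the real keras Layer."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer from its config, like Keras' cls(**config)
        return cls(**config)


class Normalize(Layer):
    def __init__(self, scale, **kwargs):
        self.scale = scale
        super(Normalize, self).__init__(**kwargs)

    def get_config(self):
        # Add the __init__ arguments on top of the base config
        config = {'scale': self.scale}
        base_config = super(Normalize, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


original = Normalize(20, name="conv4_3_norm")
config_json = json.dumps(original.get_config())
print(config_json)  # {"name": "conv4_3_norm", "trainable": true, "scale": 20}

restored = Normalize.from_config(json.loads(config_json))
print(restored.name, restored.scale)  # conv4_3_norm 20
```

The merged config has the same shape as the Normalize entry in the JSON output after the fix.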
PriorBox
PriorBox is a bit more complicated.
def __init__(self, img_size, min_size, max_size=None, aspect_ratios=None, flip=True, variances=[0.1], clip=True, **kwargs):
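What you need is the information that appears in these __init__ arguments. Looking at the source, you might also want to save attributes such as self.waxis, but __init__ recomputes them when the layer is rebuilt, so they don't need to be saved.
def get_config(self):
    # __init__ arguments needed to rebuild this layer
    config = {
        'img_size': self.img_size,
        'min_size': self.min_size,
        'max_size': self.max_size,
        'aspect_ratios': self.aspect_ratios,
        'variances': list(self.variances),  # array -> list for JSON
        'clip': self.clip
    }
    base_config = super(PriorBox, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))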
Adding this is all it takes. self.variances is a NumPy array, so it is converted to a list.
The fixed source is here:
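This sketch checks why saving the already expanded self.aspect_ratios works even though flip is not in the config. expand_aspect_ratios is a hypothetical helper that re-implements what I assume PriorBox.__init__ in ssd_layers.py does with max_size, aspect_ratios and flip. Check it against your copy of the source before relying on it. The conv4_3 layer is also assumed to be built with aspect_ratios=[2], which is consistent with the [1.0, 2, 0.5] in the JSON output.

```python
def expand_aspect_ratios(max_size=None, aspect_ratios=None, flip=True):
    """Hypothetical re-implementation of how PriorBox.__init__ is ASSUMED
    to build self.aspect_ratios; verify against ssd_layers.py."""
    ratios = [1.0]
    if max_size:
        ratios.append(1.0)
    if aspect_ratios:
        for ar in aspect_ratios:
            if ar in ratios:
                continue  # ratios already present are skipped
            ratios.append(ar)
            if flip:
                ratios.append(1.0 / ar)
    return ratios


# Assumed construction of conv4_3_norm_mbox_priorbox
saved = expand_aspect_ratios(aspect_ratios=[2])
print(saved)  # [1.0, 2, 0.5]

# Rebuilding from the saved (already expanded) list gives the same list,
# because 1.0 and 0.5 are skipped as duplicates.
reloaded = expand_aspect_ratios(aspect_ratios=saved)
print(reloaded == saved)  # True

# Caveat under the same assumption: flip is not saved, so a layer built
# with flip=False would come back with the default flip=True.
no_flip = expand_aspect_ratios(aspect_ratios=[2], flip=False)
print(no_flip, expand_aspect_ratios(aspect_ratios=no_flip))  # [1.0, 2] [1.0, 2, 0.5]
```

For this SSD model every PriorBox uses the default flip, so the round trip holds. If you use flip=False, you would need to save flip as well.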
https://gist.github.com/natsutan/ba0abee68b23bb042602ddc5a35217f4
Result
With the fix above applied, run model.to_json() again. You can see that the information is now included in the JSON.
{"name": "conv4_3_norm", "class_name": "Normalize", "config": {"name": "conv4_3_norm", "trainable": true, "scale": 20}, "inbound_nodes": [[["conv4_3", 0, 0, {}]]]}
{"name": "conv4_3_norm_mbox_priorbox", "class_name": "PriorBox", "config": {"name": "conv4_3_norm_mbox_priorbox", "max_size": null, "img_size": [300, 300], "trainable": true, "clip": true, "min_size": 30.0, "variances": [0.1, 0.1, 0.2, 0.2], "aspect_ratios": [1.0, 2, 0.5]}, "inbound_nodes": [[["conv4_3_norm", 0, 0, {}]]]}
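As a cross-check, a hypothetical helper, missing_init_args (not part of Keras), compares a layer's JSON config with the required parameters of its __init__, using Python's inspect.signature. Normalize here is a stand-in with the same signature as in ssd_layers, and the JSON entries are the Normalize entries shown in this article, trimmed to class_name, name and config.

```python
import inspect
import json


class Normalize:
    """Hypothetical stand-in with the same __init__ signature as ssd_layers.Normalize."""

    def __init__(self, scale, **kwargs):
        self.scale = scale


def missing_init_args(cls, config):
    """List __init__ parameters without a default that are absent from config."""
    missing = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # skip self, *args and **kwargs
        if param.default is param.empty and name not in config:
            missing.append(name)
    return missing


# Layer entries from the JSON output before and after the fix (trimmed)
before = json.loads('{"class_name": "Normalize", "name": "conv4_3_norm", '
                    '"config": {"name": "conv4_3_norm", "trainable": true}}')
after = json.loads('{"name": "conv4_3_norm", "class_name": "Normalize", '
                   '"config": {"name": "conv4_3_norm", "trainable": true, "scale": 20}}')

print(missing_init_args(Normalize, before["config"]))  # ['scale']
print(missing_init_args(Normalize, after["config"]))   # []
```

It only flags parameters without defaults. An argument with a default, such as PriorBox's flip or max_size, silently falls back to that default when missing, so those still have to be checked by hand.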
Loading from JSON
First, import the custom layers (along with model_from_json).
from keras.models import model_from_json
from ssd_layers import Normalize
from ssd_layers import PriorBox
When loading, don't forget to specify custom_objects.
model = model_from_json(json_string, custom_objects={"Normalize": Normalize, "PriorBox": PriorBox})
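custom_objects is how the loader maps a "class_name" in the JSON to a Python class. The sketch below is a simplified model of that lookup, not Keras's actual code. deserialize_layer, BUILTIN_LAYERS and the stand-in Normalize are hypothetical, and the exact exception Keras raises may differ. It shows that the layer is found only when custom_objects supplies the class.

```python
import json


class Normalize:
    """Hypothetical stand-in for ssd_layers.Normalize (no Keras needed)."""

    def __init__(self, scale, name=None, trainable=True):
        self.scale = scale
        self.name = name
        self.trainable = trainable


# Hypothetical registry standing in for the layer classes Keras knows about
BUILTIN_LAYERS = {}


def deserialize_layer(layer_json, custom_objects=None):
    """Simplified model: resolve class_name, then call cls(**config)."""
    entry = json.loads(layer_json)
    registry = dict(BUILTIN_LAYERS)
    registry.update(custom_objects or {})  # custom_objects extend the lookup
    cls = registry.get(entry["class_name"])
    if cls is None:
        raise ValueError("Unknown layer: " + entry["class_name"])
    return cls(**entry["config"])


layer_json = ('{"name": "conv4_3_norm", "class_name": "Normalize", '
              '"config": {"name": "conv4_3_norm", "trainable": true, "scale": 20}}')

layer = deserialize_layer(layer_json, custom_objects={"Normalize": Normalize})
print(layer.scale)  # 20

try:
    deserialize_layer(layer_json)  # custom_objects forgotten
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)  # Unknown layer: Normalize
```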
Result
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 300, 300, 3)   0
conv1_1 (Conv2D)                 (None, 300, 300, 64)  1792
conv1_2 (Conv2D)                 (None, 300, 300, 64)  36928
pool1 (MaxPooling2D)             (None, 150, 150, 64)  0
conv2_1 (Conv2D)                 (None, 150, 150, 128) 73856
conv2_2 (Conv2D)                 (None, 150, 150, 128) 147584
pool2 (MaxPooling2D)             (None, 75, 75, 128)   0
conv3_1 (Conv2D)                 (None, 75, 75, 256)   295168
conv3_2 (Conv2D)                 (None, 75, 75, 256)   590080
conv3_3 (Conv2D)                 (None, 75, 75, 256)   590080
pool3 (MaxPooling2D)             (None, 38, 38, 256)   0
conv4_1 (Conv2D)                 (None, 38, 38, 512)   1180160
conv4_2 (Conv2D)                 (None, 38, 38, 512)   2359808
conv4_3 (Conv2D)                 (None, 38, 38, 512)   2359808
pool4 (MaxPooling2D)             (None, 19, 19, 512)   0
conv5_1 (Conv2D)                 (None, 19, 19, 512)   2359808
conv5_2 (Conv2D)                 (None, 19, 19, 512)   2359808
conv5_3 (Conv2D)                 (None, 19, 19, 512)   2359808
pool5 (MaxPooling2D)             (None, 19, 19, 512)   0
fc6 (Conv2D)                     (None, 19, 19, 1024)  4719616
fc7 (Conv2D)                     (None, 19, 19, 1024)  1049600
conv6_1 (Conv2D)                 (None, 19, 19, 256)   262400
conv6_2 (Conv2D)                 (None, 10, 10, 512)   1180160
conv7_1 (Conv2D)                 (None, 10, 10, 128)   65664
zero_padding2d_1 (ZeroPadding2D) (None, 12, 12, 128)   0
conv7_2 (Conv2D)                 (None, 5, 5, 256)     295168
conv8_1 (Conv2D)                 (None, 5, 5, 128)     32896
conv4_3_norm (Normalize)         (None, 38, 38, 512)   512
conv8_2 (Conv2D)                 (None, 3, 3, 256)     295168
pool6 (GlobalAveragePooling2D)   (None, 256)           0
conv4_3_norm_mbox_conf (Conv2D)  (None, 38, 38, 63)    290367
fc7_mbox_conf (Conv2D)           (None, 19, 19, 126)   1161342
conv6_2_mbox_conf (Conv2D)       (None, 10, 10, 126)   580734
conv7_2_mbox_conf (Conv2D)       (None, 5, 5, 126)     290430
conv8_2_mbox_conf (Conv2D)       (None, 3, 3, 126)     290430
conv4_3_norm_mbox_loc (Conv2D)   (None, 38, 38, 12)    55308
fc7_mbox_loc (Conv2D)            (None, 19, 19, 24)    221208
conv6_2_mbox_loc (Conv2D)        (None, 10, 10, 24)    110616
conv7_2_mbox_loc (Conv2D)        (None, 5, 5, 24)      55320
conv8_2_mbox_loc (Conv2D)        (None, 3, 3, 24)      55320
conv4_3_norm_mbox_conf_flat (Fla (None, 90972)         0
fc7_mbox_conf_flat (Flatten)     (None, 45486)         0
conv6_2_mbox_conf_flat (Flatten) (None, 12600)         0
conv7_2_mbox_conf_flat (Flatten) (None, 3150)          0
conv8_2_mbox_conf_flat (Flatten) (None, 1134)          0
pool6_mbox_conf_flat (Dense)     (None, 126)           32382
conv4_3_norm_mbox_loc_flat (Flat (None, 17328)         0
fc7_mbox_loc_flat (Flatten)      (None, 8664)          0
conv6_2_mbox_loc_flat (Flatten)  (None, 2400)          0
conv7_2_mbox_loc_flat (Flatten)  (None, 600)           0
conv8_2_mbox_loc_flat (Flatten)  (None, 216)           0
pool6_mbox_loc_flat (Dense)      (None, 24)            6168
mbox_conf (Concatenate)          (None, 153468)        0
pool6_reshaped (Reshape)         (None, 1, 1, 256)     0
mbox_loc (Concatenate)           (None, 29232)         0
mbox_conf_logits (Reshape)       (None, 7308, 21)      0
conv4_3_norm_mbox_priorbox (Prio (None, 4332, 8)       0
fc7_mbox_priorbox (PriorBox)     (None, 2166, 8)       0
conv6_2_mbox_priorbox (PriorBox) (None, 600, 8)        0
conv7_2_mbox_priorbox (PriorBox) (None, 150, 8)        0
conv8_2_mbox_priorbox (PriorBox) (None, 54, 8)         0
pool6_mbox_priorbox (PriorBox)   (None, 6, 8)          0
mbox_loc_final (Reshape)         (None, 7308, 4)       0
mbox_conf_final (Activation)     (None, 7308, 21)      0
mbox_priorbox (Concatenate)      (None, 7308, 8)       0
predictions (Concatenate)        (None, 7308, 33)      0
====================================================================================================
Total params: 25,765,497.0
Trainable params: 25,765,497.0
Non-trainable params: 0.0
____________________________________________________________________________________________________