kerasのネットワーク情報と学習データを保存する方法はこちら
http://d.hatena.ne.jp/natsutan/20170212
この続きで保存した学習データをCで読み込んでみます。
hd5形式をネットワーク名に対応させてnumpy形式で書き出す。
ファイル構成を調べても、どのデータがどの重みに対応するのかがよくわかりませんでした。結局、kerasのコードから必要なところを引き抜いてきました。
参考:https://github.com/fchollet/keras/blob/master/keras/engine/topology.py
# https://github.com/fchollet/keras/blob/master/keras/engine/topology.py # https://github.com/fchollet/keras/blob/master/LICENSE # Kerasの学習データ抜き出しスクリプト import h5py import numpy as np input_hd5 = 'cnn.h5' def save_weights(filepath, by_name=False): f = h5py.File(filepath, mode='r') if 'layer_names' not in f.attrs and 'model_weights' in f: f = f['model_weights'] save_weights_from_hdf5_group(f) if hasattr(f, 'close'): f.close() def save_weights_from_hdf5_group(f): layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] print(layer_names) filtered_layer_names = [] for name in layer_names: g = f[name] weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] if len(weight_names): filtered_layer_names.append(name) layer_names = filtered_layer_names print(layer_names) for k, name in enumerate(layer_names): g = f[name] weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] weight_values = [g[weight_name] for weight_name in weight_names] print('') print(name) print(weight_names) for weight_name in weight_names: data = g[weight_name].value print(data.shape) filename = weight_name.replace(':0', '_z') + '.npy' np.save(filename, data, allow_pickle=False) print("save %s to %s" % (weight_name, filename)) save_weights(input_hd5)
これをhd5のあるディレクトリで実行すると、kerasのネットワークに対応した名前のnpy形式で保存します。ポイントは、np.save()を使ってnumpyを保存するときに、allow_pickleにFalseを指定することです。こうすることで、Pickle形式ではなく、連続したCの配列になります。
以下、実行結果(抜粋)。kerasのネットワークに対応したファイル名で保存しています。
convolution2d_1 ['convolution2d_1_W:0', 'convolution2d_1_b:0'] (3, 3, 1, 32) save convolution2d_1_W:0 to convolution2d_1_W_z.npy (32,) save convolution2d_1_b:0 to convolution2d_1_b_z.npy
numpy形式をCで読み出す。
numpyのフォーマットはこちらを参照。
https://docs.scipy.org/doc/numpy/neps/npy-format.html
先頭から、8byte, 9byteにヘッダー長があるので、その長さ+8以降が実際のデータです。先にpythonで値を確認します。
import numpy as np data = np.load('dense_1_b_z.npy') print(data)
これを実行すると、このように出力されます。
[ 0.01447727 -0.00560358 -0.03917709 -0.00773129 0.01669481 0.02331016 -0.02222279 -0.00773809 0.0162074 -0.00154851 -0.04478939 -0.00506137 0.01831469 -0.02160945 0.01513562 -0.00983882 -0.02593831 -0.02709689 0.00020504 -0.00142339 -0.00577859 -0.00883115 0.00912447 0.03082606 0.01143611 -0.00278293 -0.00554265 -0.02509919 -0.00772283 -0.02985987 -0.02390517 -0.00474763 -0.01730801 0.00120485 -0.01142981 -0.00061552 -0.02196323 0.00577068 -0.00911395 0.01392506 0.02327628 0.00369775 -0.00224362 0.0179715 -0.04719299 0.00875815 -0.00452969 -0.03214104 -0.0069066 -0.01714456 0.01036448 0.0048428 -0.01533855 -0.00589305 0.00368132 0.00494579 0.03383888 -0.01036347 -0.01587362 0.00544504 0.00351622 0.00747294 0.04023237 -0.00225322 0.03317111 -0.04557553 -0.05552437 -0.00583258 0.01372788 0.00907361 0.02187419 0.0295448 -0.00099583 0.00802352 0.00391109 0.00583726 -0.03709769 0.00090473 -0.01646602 -0.01568777 -0.01240169 -0.00157236 -0.01959496 -0.00049798 -0.03411916 0.02637957 0.02369151 -0.01003799 -0.01855751 0.00707838 0.04559748 -0.01478013 0.00100191 0.01219569 -0.00181059 -0.01096813 0.01167649 -0.00252321 -0.02898367 0.01881598 0.01765522 -0.0138432 -0.00060348 -0.00303171 0.00660704 0.00888676 0.0260018 -0.01976014 -0.01559355 -0.00344168 0.02282669 -0.0136997 0.00279327 -0.02522321 -0.00378602 -0.02168743 -0.01284252 -0.00368761 0.0025697 0.00357005 0.01638989 0.00504998 -0.00064822 0.00188594 -0.02876503 -0.02664481 0.02615403 -0.00238846]
読みだすCのソースです。
#include <stdio.h> #include <stdlib.h> #define DATA_NUM (128) float data[DATA_NUM]; int main(void) { FILE *fp; const char *fname = "dense_1_b_z.npy"; int hsize; int i; int cnt; fp = fopen(fname, "rb"); if(fp == NULL) { printf("error can't open %s\n", fname); exit(1); } //get header size fseek(fp, 8, 0); fread(&hsize, 2, 1, fp); //skip header printf("hsize = 0x%x(%d)\n", hsize, hsize); fseek(fp, 8+2+hsize, 0); //read data cnt = fread(&data, 4, DATA_NUM, fp); printf("read cnt = %d\n", cnt); //print for(i=0;i<DATA_NUM;i++) { printf("%f\n", data[i]); } fclose(fp); return 0; }
実行するとこのように表示されて、先頭からcの配列に浮動小数点数として読み込めているのがわかります。
hsize = 0x46(70) read cnt = 128 0.014477 -0.005604 -0.039177 -0.007731 0.016695 0.023310 -0.022223 -0.007738 0.016207 -0.001549 -0.044789 -0.005061 0.018315 -0.021609 0.015136 -0.009839 -0.025938 -0.027097 0.000205 -0.001423 -0.005779 -0.008831 0.009124 0.030826 0.011436 -0.002783 -0.005543 -0.025099 -0.007723 -0.029860 -0.023905 -0.004748 -0.017308 0.001205 -0.011430 -0.000616 -0.021963 0.005771 -0.009114 0.013925 0.023276 0.003698 -0.002244 0.017971 -0.047193 0.008758 -0.004530 -0.032141 -0.006907 -0.017145 0.010364 0.004843 -0.015339 -0.005893 0.003681 0.004946 0.033839 -0.010363 -0.015874 0.005445 0.003516 0.007473 0.040232 -0.002253 0.033171 -0.045576 -0.055524 -0.005833 0.013728 0.009074 0.021874 0.029545 -0.000996 0.008024 0.003911 0.005837 -0.037098 0.000905 -0.016466 -0.015688 -0.012402 -0.001572 -0.019595 -0.000498 -0.034119 0.026380 0.023692 -0.010038 -0.018558 0.007078 0.045597 -0.014780 0.001002 0.012196 -0.001811 -0.010968 0.011676 -0.002523 -0.028984 0.018816 0.017655 -0.013843 -0.000603 -0.003032 0.006607 0.008887 0.026002 -0.019760 -0.015594 -0.003442 0.022827 -0.013700 0.002793 -0.025223 -0.003786 -0.021687 -0.012843 -0.003688 0.002570 0.003570 0.016390 0.005050 -0.000648 0.001886 -0.028765 -0.026645 0.026154 -0.002388