「くずし字」識別チャレンジ

β版ProbSpaceコンペ第2弾!

賞金: 100,000 参加チーム数: 138 4ヶ月前に終了

DenseNetの実装

今回はDenseNet(https://arxiv.org/pdf/1608.06993.pdf) というアーキテクチャを実装しました。

DenseNetは

1 Dense Block  

2 Transition Layer

の2つの部分で構成されています。また、Dense Blockは

D-1 フィルター数128の1x1の畳み込み層

D-2 フィルター数k(*kは任意の実数)の3x3の畳み込み層

*kは成長率と呼ばれる指標で、どれくらいの情報を新たにモデルに取り込むかを表します。今回はk=16で実装しました。

D-3フィルター数n+kのConcanate層

の3つ、Transition Layerは

T-1 1x1の畳み込み層

tー2 2x2のAverage Pooling

で構成されています。

DenseNetは畳み込み層とプーリング層の間にDense BlockとTransition Layerを交互に挟み込むモデルです。 これらを挟み込み事で、最適化が必要なパラメータ数を大幅に削減することができ、これにより学習効率をあげることが出来ます。

では、実装のコードを見ていきましょう。

def dense_block(x, k, n_block):
  for i in range(n_block):
    main = x
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    #1x1
    x = Conv2D(filters = 64, kernel_size = (1, 1), padding = 'valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    #3x3
    x = Conv2D(filters = k, kernel_size = (3, 3), padding = 'same')(x)
    #concatenate
    x = Concatenate()([main, x])

  return x

def transition_layer(inputs, compression = 0.5):
  n_channel = int(inputs.shape[3])
  filters = int(n_channel * compression)
  x = Conv2D(filters = filters, kernel_size = (1, 1))(inputs)
  outputs = AveragePooling2D(pool_size = (2, 2))(x)

  return outputs

def DenseNet():
  inputs = Input(shape = (28, 28, 1))

  x = dense_block(inputs, k = 16, n_block = 1)
  x = transition_layer(x, compression = 0.5)
  x = dense_block(x, k = 16, n_block = 2)
  x = transition_layer(x, compression = 0.5)
  x = dense_block(x, k = 16, n_block = 4)
  x = transition_layer(x, compression = 0.5)
  x = dense_block(x, k = 16, n_block = 3)

  x = GlobalAveragePooling2D()(x)

  x = Dense(512)(x)
  x = BatchNormalization()(x)
  x = Activation('relu')(x)

  x = Dense(n_class)(x)

  outputs = Activation('softmax')(x)

  return Model(inputs, outputs)

隠れ層を通常の畳み込み層x4にした時の正答率は97.0%でしたが、DenseNetに変えることで97.5%に向上しました。

参考リンク(https://qiita.com/koshian2/items/01bd9f08444799625607) 貼っておくので、皆さんも試してみて下さい。

Favicon
new user
コメントするには 新規登録 もしくは ログイン が必要です。