クイックノート

ちょっとした発見・アイデアから知識の発掘を

【R】GANを使って学習してみる【keras】

ディープラーニングの中で、大きく注目を浴びているのが、
敵対的生成ネットワーク、いわゆる GAN です。

これは、学習の仕方を工夫したもので、
二つの人工知能が互いに競い合って、
学習を進めていくようになっています。

通常、学習結果をよくしようと思うと、
より多くのデータを持ってくる、
パラメータを見直すなどの方法が取られてきましたが、
競争しながら学習をする仕組みを取り入れたことで、
時間をかけて長く競わせることで、
より良い学習結果を手に入れるということも可能になってきました。

データを取ったり、どのパラメータが効果を持つかを調べることは、
面倒なことも多いですが、
単に時間をかけて勝手に学習させておくだけならそれほど苦になりません。

このように、GANを活用することで、
人工知能の精度向上のための比較的楽で効果的な道が開かれるのです。

今回は、そんなGANをR上でkerasを使って、
お試し実行してみたいと思います。

GAN の概要

学習の仕組み

GAN では、重要な登場人物として、
2つの人工知能が登場します。

一つ目は、generatorと呼ばれるもので、
これは、本物のデータに近い偽物を作ることを目的として、
学習を進めていきます。

もう一つは、discriminatorと呼ばれるもので、
generator の生み出す偽物を本物のデータと区別することを目的として、
学習を進めていきます。

f:id:u874072e:20181004154745p:plain

つまり、
- generator は discriminator の目を欺くように学習し
- discriminator は generator の生み出す偽物を見破るために学習します。

丁度、それぞれが相手と戦う形で学習することになるので、
負けたら、次は負けない様にやり返すという反撃を繰り返しているうちに、
お互いに腕に磨きをかけていくことができます。

何ができるか

学習した結果、generatorとdiscriminatorの両方が得られますが、
通常、generator の方だけが使われます。

generator は本物に近い偽物を生み出すように学習されているので、
まるで本物のような画像を作ったりすることができます。

ここで重要なのは、generator には本物のデータを直接見せているわけではなく、
また、本物に似せることを目指しているわけではないということです。
あくまで、discriminator の目を騙せるような偽物を作ることを目的としているので、
中には、本物のデータとは全く違うのにまるで本物のように見える偽物を作ることができます。
つまり、ある種のオリジナル性を持った偽物が生み出される可能性もあるということです。

R で実装してみる

今回は、GAN をRを使って実装してみます。
keras パッケージを利用することで、比較的実装が楽になりますが、
keras 自身に GAN は含まれていないので、ある程度は自前で実装することになります。

keras の導入方法に関しては、過去の記事をご参照ください。

clean-copy-of-onenote.hatenablog.com

パッケージをロードする

まずは使用するパッケージをロードします。
今回は、kerasabindを利用します。

library(keras)
library(abind)

モデルの生成

generator のモデルを生成

主役の一人 generator のモデルを記述します。
モデルはデータに合わせて適宜調整すると良いですが、
後の例で、時系列データを使うのでLSTMをベースにモデルを構築しています。

gen_generator = function(z_dim,x_dim){
  model = keras_model_sequential()
  
  hist_gen <<- NULL

  model %>% 
    layer_lstm(units = 32,input_shape = z_dim, return_sequences = T) %>%
    layer_dropout(rate=0.6) %>%
    layer_lstm(units = 16) %>%
    layer_dropout(rate=0.4) %>%
    layer_dense(units=prod(x_dim),activation = "linear") %>%
    layer_dropout(rate=0.2) %>%
    layer_reshape(x_dim)
  
  return(model)
}

generator の入力はノイズを与えることが多いのですが、
そのノイズの次元をz_dimで指定するようにしています。

また、generator が作り出す偽物のデータの次元をx_dimで与えています。

discriminator のモデルを生成

もう一人の主役 discriminator のモデルを記述します。
こちらもデータに合わせてモデルを調整しましょう。

gen_discriminator = function(x_dim){
  model = keras_model_sequential()
  
  hist_disc <<- NULL
  
  model %>%
    layer_lstm(units = 32, input_shape = x_dim, return_sequences = T) %>%
    layer_dropout(rate=0.6) %>%
    layer_lstm(units = 16) %>%
    layer_dropout(rate=0.4)%>%
    layer_dense(units = 32) %>%
    layer_dropout(rate=0.2) %>%
    layer_dense(units = 1,activation = "sigmoid")
  
  model %>% compile(loss="binary_crossentropy",
                    optimizer = optimizer_adam(),
                    metrics="accuracy")
  
  return(model)
}

x_dimは入力される本物と偽物のデータの次元を表しています。
出力は偽物か本物かを0,1で表すので、1次元としています。

generator 学習用のモデルを生成

discriminator はそれ単体で学習するのですが、
generator は discriminator と繋げた状態で、
discriminator の真偽の判断プロセスを見ながら学習していきます。

そのため、この2つを繋いで、generatorの学習のみを行うモデルを生成します。

gen_combined = function(generator,discriminator){
  combined = keras_model_sequential()
  discriminator$trainable = FALSE
  combined %>%
    generator %>%
    discriminator
  
  combined %>% compile(loss="binary_crossentropy",
                       optimizer = optimizer_adam(),
                       metrics="accuracy")
  return(combined)
}

kerasの便利なところで、 generator と discriminator を自然にそのまま繋げたモデルを生成しています。
discriminator$trainable=FALSE とすることで、
discriminatorの学習を無効にしています。

学習

generator の学習

上で述べたように、generator は discriminator と繋げた状態で学習します。

学習の際には、generator にノイズを入力して、
生み出された偽物のデータを使って、
discriminator がそれを本物だと間違うように、
学習を進めていきます。

train_generator = function(generator,discriminator,combined,N,z_dim){
  x =  rnorm(2*N*prod(z_dim)) %>% array_reshape(c(2*N,z_dim))
  y = rep(1,2*N)
  
  discriminator$trainable = FALSE
  combined %>% fit(x,y,
                   epochs=1)->tmp
  hist_gen <<- rbind(hist_gen,tmp)
  
  return(generator)
}

discriminator の学習

discriminator の学習は、discriminator 単体で行います。

ただし、入力には、generator が生成した偽物データと、
元々用意していた本物データを混ぜて入力し、
それを正しく本物と偽物を区別できるように学習を進めます。

train_discriminator = function(generator,discriminator,z_dim,real_data){
  N = dim(real_data)[1]
  x_dim = dim(real_data)[-1]
  
  noises = rnorm(N*prod(z_dim)) %>% array_reshape(c(N,z_dim))
  fake_data = generator %>% predict(noises)
  
  x = abind(real_data,fake_data,along=1)
  x = array_reshape(x,c(2*N,x_dim))
  y = c(rep(1,N),rep(0,N))
  
  discriminator$trainable = TRUE
  discriminator %>% fit(x,y,
                        epochs=1,batch_size = N,shuffle = F) -> tmp
  hist_disc <<- rbind(hist_disc,tmp)
  return(discriminator)
}

全体の学習

上の手順に沿った学習を、discriminato と generator が順に行うことで、
1回の学習が終わります。
これを何度も繰り返すことで、
お互いに学習が進んでいきます。

train_gan = function(generator,discriminator,z_dim,real_data,epochs){
  N = dim(real_data)[1]
  x_dim = dim(real_data)[-1]

  combined = gen_combined(generator,discriminator)
  for(i in 1:epochs){
    print(paste("epoch",i))
    discriminator = train_discriminator(generator,discriminator,z_dim,real_data)
    generator = train_generator(generator,discriminator,combined,N,z_dim)
  }
  return(list(discriminator=discriminator,generator=generator))
}

テスト実行

GAN での学習が実装できたので、テスト実行してみましょう。

画像などを使ってもいいのですが、
結果が分かりやすいグラフでテストしてみましょう。

データ

GAN を使う際には、偽物を作る元となる本物のデータが必要です。
generator に直接入力はされませんが、
discriminator の反応をうかがいながら、
間接的に本物のデータに近い偽物が作られていきます。

今回は、結果を分かりやすくしたいので、
データとしてサイン波を与えることにします。

N=20
x = t(sapply(0:19,function(i){sin(1:N/pi+pi/10*i)}))
x = array_reshape(x,c(20,N,1))

位相をずらしながら、サイン波を20本用意しました。

サイン波が本物だと知らないgeneratorが、
サイン波に近いような波形を生み出せるかがポイントになります。

学習の実行

それでは、generator と discriminator を生成して、
学習させます。

  z_dim = c(1,1)
  x_dim = dim(x)[-1]
  
  generator = gen_generator(z_dim,x_dim)
  discriminator = gen_discriminator(x_dim)
  
  res = train_gan(generator,discriminator,z_dim,x,10000)

とりあえず1万回の学習を行いました。

学習の進み具合

学習のログをhist_gen,hist_discに格納するようにしているので、
学習の進み具合を確認してみましょう。

  ts.plot(cbind(apply(hist_disc,1,function(x){x$metrics$acc}),
                apply(hist_gen,1,function(x){x$metrics$acc})),col=1:2)
  legend("left",c("disc","gen"),col=1:2,lty=1)

f:id:u874072e:20181004163812p:plain:w400

こちらは精度になります。
徐々に generator が discriminator を騙せるようになってきているようすが分かります。

  ts.plot(cbind(apply(hist_disc,1,function(x){x$metrics$loss}),
                apply(hist_gen,1,function(x){x$metrics$loss})),col=1:2)
  legend("left",c("disc","gen"),col=1:2,lty=1)

f:id:u874072e:20181004164056p:plain:w400

こちらはロス関数です。
大きな時間でみると徐々にロス関数が減少していて、
学習が進んでいる様子が分かりますね。

学習の結果

最後に学習の結果をプロットしてみましょう。

  noises = rnorm(prod(z_dim)) %>% array_reshape(c(1,z_dim))
  fake_data = res$generator %>% predict(noises)
  fake_data = fake_data[1,,]
  
  plot(fake_data,type="l")
  lines(x[1,,1],col=2)

f:id:u874072e:20181004164214p:plain:w400

黒がgeneratorによって生成された波形ですが、
なんとなく赤のサイン波に近い?でしょうか。

まだまだ学習期間を長くすると、
もう少しサイン波に近づけるかもしれませんね。

まとめ

GAN の R での実装とテスト実行をしてみました。

二つの人工知能を競い合わせて、
互いに精度を高め合うというアイデアは非常に面白いですよね。

すでに面白い応用が生まれつつある技術ですが、
より違った応用が見れるんじゃないかと期待させられます。

プライバシーポリシー