準備

Googleドライブのマウント

In [1]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

sys.pathの設定

以下では,Googleドライブのマイドライブ直下にDNN_codeフォルダを置くことを仮定しています.必要に応じて,パスを変更してください.

In [2]:
import sys
sys.path.append('/content/drive/My Drive/DNN_code_colab_ver200425')

predict sin


[try]

  • iters_numを100にしよう
  • maxlenを5, iters_numを500, 3000(※時間がかかる)にしよう

In [7]:
import numpy as np
from common import functions
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

np.random.seed(0)

# sin曲線
round_num = 10
div_num = 500
ts = np.linspace(0, round_num * np.pi, div_num)
f = np.sin(ts)

def d_tanh(x):
    return 1/(np.cosh(x)**2 + 1e-4)

# ひとつの時系列データの長さ
maxlen = 5

# sin波予測の入力データ
test_head = [[f[k]] for k in range(0, maxlen)]

data = []
target = []

for i in range(div_num - maxlen):
    data.append(f[i: i + maxlen])
    target.append(f[i + maxlen])
    
X = np.array(data).reshape(len(data), maxlen, 1)
D = np.array(target).reshape(len(data), 1)

# データ設定
N_train = int(len(data) * 0.8)
N_validation = len(data) - N_train

x_train, x_test, d_train, d_test = train_test_split(X, D, test_size=N_validation)

input_layer_size = 1
hidden_layer_size = 5
output_layer_size = 1

weight_init_std = 0.01
learning_rate = 0.1

iters_num = 3000

# ウェイト初期化 (バイアスは簡単のため省略)
W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)

# 勾配
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

us = []
zs = []

u = np.zeros(hidden_layer_size)
z = np.zeros(hidden_layer_size)
y = np.zeros(output_layer_size)

delta_out = np.zeros(output_layer_size)
delta = np.zeros(hidden_layer_size)

losses = []

# トレーニング
for i in range(iters_num):
    for s in range(x_train.shape[0]):
        us.clear()
        zs.clear()
        z *= 0
        
        # sにおける正解データ
        d = d_train[s]

        xs = x_train[s]        
        
        # 時系列ループ
        for t in range(maxlen):
            
            # 入力値
            x = xs[t]
            u = np.dot(x, W_in) + np.dot(z, W)
            us.append(u)
            z = np.tanh(u)
            zs.append(z)

        y = np.dot(z, W_out)
        
        #誤差
        loss = functions.mean_squared_error(d, y)
        
        delta_out = functions.d_mean_squared_error(d, y)
        
        delta *= 0
        for t in range(maxlen)[::-1]:
            
            delta = (np.dot(delta, W.T) + np.dot(delta_out, W_out.T)) * d_tanh(us[t])
            
            # 勾配更新
            W_grad += np.dot(zs[t].reshape(-1,1), delta.reshape(1,-1))
            W_in_grad += np.dot(xs[t], delta.reshape(1,-1))
        W_out_grad = np.dot(z.reshape(-1,1), delta_out)
        
        # 勾配適用
        W -= learning_rate * W_grad
        W_in -= learning_rate * W_in_grad
        W_out -= learning_rate * W_out_grad.reshape(-1,1)
            
        W_in_grad *= 0
        W_out_grad *= 0
        W_grad *= 0

# テスト        
for s in range(x_test.shape[0]):
    z *= 0

    # sにおける正解データ
    d = d_test[s]

    xs = x_test[s]

    # 時系列ループ
    for t in range(maxlen):

        # 入力値
        x = xs[t]
        u = np.dot(x, W_in) + np.dot(z, W)
        z = np.tanh(u)

    y = np.dot(z, W_out)

    #誤差
    loss = functions.mean_squared_error(d, y)
    print('loss:', loss, '   d:', d, '   y:', y)
        
        
        
original = np.full(maxlen, None)
pred_num = 200

xs = test_head

# sin波予測
for s in range(0, pred_num):
    z *= 0
    for t in range(maxlen):
        
        # 入力値
        x = xs[t]
        u = np.dot(x, W_in) + np.dot(z, W)
        z = np.tanh(u)

    y = np.dot(z, W_out)
    original = np.append(original, y)
    xs = np.delete(xs, 0)
    xs = np.append(xs, y)

plt.figure()
plt.ylim([-1.5, 1.5])
plt.plot(np.sin(np.linspace(0, round_num* pred_num / div_num * np.pi, pred_num)), linestyle='dotted', color='#aaaaaa')
plt.plot(original, linestyle='dashed', color='black')
plt.show()
loss: 1.0231756688240315e-07    d: [-0.29761864]    y: [-0.29716628]
loss: 1.2201090155920323e-08    d: [-0.56307233]    y: [-0.56322854]
loss: 4.038245319603949e-11    d: [-0.65766776]    y: [-0.65765877]
loss: 1.0126131714852026e-08    d: [0.13182648]    y: [0.13168417]
loss: 6.953992391227437e-08    d: [0.49909101]    y: [0.49871807]
loss: 5.5838333193818506e-08    d: [0.9518317]    y: [0.95149752]
loss: 1.006541644073702e-07    d: [0.97784112]    y: [0.97739245]
loss: 1.7033534573059554e-08    d: [-0.58880346]    y: [-0.58861889]
loss: 5.2622812685420113e-08    d: [-0.78351093]    y: [-0.78383534]
loss: 2.609182825988997e-10    d: [-0.49909101]    y: [-0.49906816]
loss: 5.7128497213479914e-08    d: [0.21857331]    y: [0.21823529]
loss: 9.126545094844397e-08    d: [-0.33938943]    y: [-0.3389622]
loss: 1.0304339135521172e-07    d: [-0.43793098]    y: [-0.43747701]
loss: 1.1445752972811905e-07    d: [-0.33346065]    y: [-0.3329822]
loss: 4.4870288964406624e-08    d: [-0.99639027]    y: [-0.9960907]
loss: 4.5188451220534447e-08    d: [0.88624247]    y: [0.88654309]
loss: 2.316626931563649e-08    d: [-0.92833248]    y: [-0.92811723]
loss: 7.877136644752068e-10    d: [-0.52075286]    y: [-0.52079255]
loss: 8.254988700900168e-09    d: [-0.55262221]    y: [-0.5527507]
loss: 8.291678208254147e-08    d: [0.47711265]    y: [0.47670543]
loss: 3.515702303100396e-08    d: [0.60896952]    y: [0.60923469]
loss: 4.6349673171986056e-08    d: [-0.94587102]    y: [-0.94556656]
loss: 9.366188781095447e-08    d: [0.27953518]    y: [0.27910237]
loss: 7.187670231816063e-08    d: [0.73863456]    y: [0.73901371]
loss: 2.517527202369438e-08    d: [-0.00629574]    y: [-0.00607135]
loss: 4.238282860557903e-08    d: [0.54208448]    y: [0.54179333]
loss: 5.364268767195752e-08    d: [0.99781582]    y: [0.99748827]
loss: 7.761284716240689e-08    d: [-0.96441607]    y: [-0.96402208]
loss: 7.299464765639473e-08    d: [0.07547747]    y: [0.07509539]
loss: 5.3747292183674635e-12    d: [0.66239735]    y: [0.66240063]
loss: 3.9176486316405196e-08    d: [-0.54736419]    y: [-0.54708428]
loss: 1.1396823999825883e-07    d: [0.99393675]    y: [0.99345932]
loss: 1.8916287871150376e-10    d: [0.96441607]    y: [0.96443552]
loss: 6.104853811881267e-08    d: [-0.22471249]    y: [-0.22436306]
loss: 3.3747121524395504e-08    d: [-0.99393675]    y: [-0.99367695]
loss: 7.459208770344294e-08    d: [-0.71705202]    y: [-0.71743827]
loss: 3.3346599615646084e-09    d: [0.8649742]    y: [0.86505587]
loss: 1.0502769561856412e-07    d: [-0.11933469]    y: [-0.11887638]
loss: 3.740289231639668e-08    d: [-0.40941891]    y: [-0.4091454]
loss: 1.1571745806193722e-07    d: [0.39789889]    y: [0.39741782]
loss: 1.3219146513176019e-08    d: [0.98611478]    y: [0.98595218]
loss: 6.824737983587242e-08    d: [-0.0691982]    y: [-0.06882875]
loss: 7.747335301932406e-08    d: [0.35709413]    y: [0.3567005]
loss: 4.201317088874832e-08    d: [0.99583607]    y: [0.9955462]
loss: 2.1526965553383208e-08    d: [-0.92597363]    y: [-0.92618113]
loss: 6.817128658026986e-08    d: [0.36882689]    y: [0.36845764]
loss: 1.1810524119005898e-08    d: [0.91617219]    y: [0.9160185]
loss: 1.5115146954905911e-09    d: [0.52611726]    y: [0.52617224]
loss: 8.054071771452032e-08    d: [0.96606148]    y: [0.96566013]
loss: 1.9560182092068755e-08    d: [0.43793098]    y: [0.43773319]
loss: 2.3651142877388884e-09    d: [-0.86811636]    y: [-0.86818514]
loss: 7.426592869201099e-08    d: [-0.99975723]    y: [-0.99937184]
loss: 3.369004516432079e-08    d: [0.77562491]    y: [0.77588449]
loss: 4.2428732557018664e-09    d: [0.04405617]    y: [0.04414829]
loss: 9.775107383103442e-09    d: [0.71705202]    y: [0.71719184]
loss: 4.3883218080675413e-10    d: [0.8773359]    y: [0.87736552]
loss: 3.008800038081094e-08    d: [0.91363079]    y: [0.9138761]
loss: 3.4353135858512383e-09    d: [-0.97784112]    y: [-0.97775823]
loss: 8.641444917436445e-10    d: [0.96101064]    y: [0.96105221]
loss: 8.711809080410903e-08    d: [0.26742375]    y: [0.26700633]
loss: 5.0207765974941775e-08    d: [-0.87122411]    y: [-0.871541]
loss: 2.838017375641239e-08    d: [-0.91617219]    y: [-0.91641043]
loss: 1.556722685502756e-09    d: [0.87122411]    y: [0.87127991]
loss: 9.803777837639783e-09    d: [0.98394564]    y: [0.98380562]
loss: 4.1480616132135435e-08    d: [0.4036669]    y: [0.40337887]
loss: 8.246547424839542e-08    d: [0.0880268]    y: [0.08762068]
loss: 3.519939115308866e-08    d: [0.81012572]    y: [0.81039105]
loss: 3.347592195082292e-08    d: [0.41515469]    y: [0.41489593]
loss: 3.9202132662385375e-08    d: [-0.99524241]    y: [-0.9949624]
loss: 6.676591329542326e-09    d: [-0.94789551]    y: [-0.94801107]
loss: 2.7259396135958284e-08    d: [0.16916853]    y: [0.16893503]
loss: 5.901954968723951e-08    d: [-0.95374324]    y: [-0.95339967]
loss: 7.404264207102875e-08    d: [-0.72577151]    y: [-0.72615633]
loss: 1.950089841511606e-08    d: [0.74286391]    y: [0.7430614]
loss: 9.1706563452403e-09    d: [-0.94380904]    y: [-0.94394447]
loss: 1.8342443304424838e-08    d: [-0.83516734]    y: [-0.83535887]
loss: 4.0168629527546835e-08    d: [-0.94170965]    y: [-0.94142621]
loss: 1.0438172705832991e-07    d: [0.32156366]    y: [0.32110676]
loss: 7.965196525172432e-08    d: [-0.48263615]    y: [-0.48223703]
loss: 4.522391104104177e-08    d: [0.03776568]    y: [0.03746493]
loss: 8.671646127140551e-08    d: [0.34530476]    y: [0.34488831]
loss: 3.0057469610566484e-08    d: [0.56307233]    y: [0.56282714]
loss: 3.3448244853359074e-08    d: [0.90843947]    y: [0.90869812]
loss: 9.740854716817453e-08    d: [0.99940055]    y: [0.99895917]
loss: 5.2790241458954756e-08    d: [0.85534252]    y: [0.85566745]
loss: 1.3438734027154209e-08    d: [0.93739898]    y: [0.93756292]
loss: 1.0825441262471587e-07    d: [0.99738016]    y: [0.99691486]
loss: 7.346796784236424e-08    d: [-0.6992734]    y: [-0.69965672]
loss: 9.629719356207507e-08    d: [-0.10682399]    y: [-0.10638513]
loss: 4.974432822871989e-09    d: [-0.54208448]    y: [-0.54218422]
loss: 7.137974552098255e-08    d: [0.9995987]    y: [0.99922087]
loss: 1.199131974055454e-07    d: [0.29761864]    y: [0.29712892]
loss: 3.111699813940709e-08    d: [0.99322482]    y: [0.99297535]
loss: 9.5739219591903e-08    d: [0.33346065]    y: [0.33302307]
loss: 5.181836972617623e-08    d: [-0.83168816]    y: [-0.83201008]
loss: 1.1452214926570487e-08    d: [-0.98504973]    y: [-0.98489839]
loss: 5.381547067735827e-08    d: [-0.64332332]    y: [-0.6436514]
loss: 1.0539082479537175e-07    d: [0.43226238]    y: [0.43180327]
loss: 3.266434009950708e-08    d: [-0.81380058]    y: [-0.81405617]
In [ ]: