{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Place classification CNN on NCLT SCI images\n",
    "\n",
    "Loads the SCI images of one NCLT sequence, one-hot encodes the place index parsed from each file name, trains a small CNN classifier and saves it as an `.h5` file.\n",
    "\n",
    "Set `rootDir` in the configuration cell before running (Restart Kernel -> Run All)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import keras\n",
    "import matplotlib.pyplot as plt\n",
    "from keras import backend as K\n",
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data info\n",
    "# rootDir is a placeholder: replace it with the directory that holds the datasets (keep the trailing '/').\n",
    "rootDir = '..your_data_path/'\n",
    "\n",
    "Dataset = 'NCLT'\n",
    "TrainOrTest = '/Train/'\n",
    "SequenceDate = '2012-01-15'\n",
    "\n",
    "SCImiddlePath = '/5. SCI_jet0to15_BackAug/'\n",
    "\n",
    "# Kept as a string because it is used both as a path component and in the saved model's file name.\n",
    "GridCellSize = '10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DataPath = rootDir + Dataset + TrainOrTest + SequenceDate + SCImiddlePath + GridCellSize + '/'\n",
    "print(DataPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "A previous run on sequence `2012-01-15` reported 15078 images of shape `(40, 120, 3)` spread over 579 place classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getTrainingDataNCLT(DataPath, SequenceDate):\n",
    "    \"\"\"Load every image in DataPath and one-hot encode its place index.\n",
    "\n",
    "    File names are parsed positionally: ``dataName[6:11]`` must parse as the\n",
    "    integer place index of the image.\n",
    "\n",
    "    Args:\n",
    "        DataPath: directory that contains the image files of one sequence.\n",
    "        SequenceDate: sequence identifier; only used in the log message.\n",
    "\n",
    "    Returns:\n",
    "        X_nd: float32 array of shape (nData, *imageShape).\n",
    "        y_nd: one-hot label array of shape (nData, nClasses).\n",
    "        lbl_enc: LabelEncoder fitted on the place indices found in DataPath.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if DataPath contains no files, if a file name does not\n",
    "            carry an integer at positions 6-10, or if the images do not all\n",
    "            share one shape.\n",
    "    \"\"\"\n",
    "    # info\n",
    "    # sorted() makes the sample order reproducible; os.listdir order is arbitrary.\n",
    "    WholeData = sorted(os.listdir(DataPath))\n",
    "    nWholeData = len(WholeData)\n",
    "    print(str(nWholeData) + ' data exist in ' + SequenceDate)\n",
    "    if nWholeData == 0:\n",
    "        raise ValueError('No data found in ' + DataPath)\n",
    "\n",
    "    # read\n",
    "    X = []\n",
    "    y = []\n",
    "    for ii, dataName in enumerate(WholeData):\n",
    "        X.append(plt.imread(os.path.join(DataPath, dataName)))\n",
    "        y.append(int(dataName[6:11]))\n",
    "\n",
    "        # progress message\n",
    "        if ii % 1000 == 0:\n",
    "            print(format((ii / nWholeData) * 100, '.1f'), '% loaded.')\n",
    "\n",
    "    # X: stack once and cast to float32 (no intermediate float64 buffer)\n",
    "    X_nd = np.stack(X).astype('float32', copy=False)\n",
    "\n",
    "    # y (one-hot encoded): the encoder maps the place indices to 0..nClasses-1\n",
    "    lbl_enc = LabelEncoder()\n",
    "    y = lbl_enc.fit_transform(y)\n",
    "    nClassesTheSequenceHave = len(lbl_enc.classes_)\n",
    "    y_nd = keras.utils.to_categorical(y, num_classes=nClassesTheSequenceHave)\n",
    "\n",
    "    # log message\n",
    "    print('Data size: %s' % nWholeData)\n",
    "    print(' ')\n",
    "    print('Data shape:', X_nd.shape)\n",
    "    print('Label shape:', y_nd.shape)\n",
    "\n",
    "    return X_nd, y_nd, lbl_enc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y, lbl_enc = getTrainingDataNCLT(DataPath, SequenceDate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample shapes; they size the model's input and output layers below.\n",
    "dataShape = X[0].shape\n",
    "labelShape = y[0].shape\n",
    "\n",
    "print(dataShape)\n",
    "print(labelShape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model\n",
    "# Reset the backend session so re-running this cell starts from a fresh graph.\n",
    "K.clear_session()\n",
    "\n",
    "ModelName = 'my_model'\n",
    "\n",
    "Drop1 = 0.7\n",
    "Drop2 = 0.7\n",
    "\n",
    "KernelSize = 5\n",
    "\n",
    "nConv1Filter = 64\n",
    "nConv2Filter = 128\n",
    "nConv3Filter = 256\n",
    "\n",
    "nFCN1 = 64\n",
    "\n",
    "inputs = keras.layers.Input(shape=(dataShape[0], dataShape[1], dataShape[2]))\n",
    "x = keras.layers.Conv2D(filters=nConv1Filter, kernel_size=KernelSize, activation='relu', padding='same')(inputs)\n",
    "x = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(x)\n",
    "x = keras.layers.BatchNormalization()(x)\n",
    "x = keras.layers.Conv2D(filters=nConv2Filter, kernel_size=KernelSize, activation='relu', padding='same')(x)\n",
    "x = keras.layers.MaxPooling2D()(x)\n",
    "x = keras.layers.BatchNormalization()(x)\n",
    "x = keras.layers.Conv2D(filters=nConv3Filter, kernel_size=KernelSize, activation='relu', padding='same')(x)\n",
    "x = keras.layers.MaxPooling2D()(x)\n",
    "x = keras.layers.Flatten()(x)\n",
    "x = keras.layers.Dropout(rate=Drop1)(x)\n",
    "# NOTE(review): this Dense layer has no activation argument - confirm that is intended.\n",
    "x = keras.layers.Dense(units=nFCN1)(x)\n",
    "x = keras.layers.Dropout(rate=Drop2)(x)\n",
    "outputs = keras.layers.Dense(units=labelShape[0], activation='softmax')(x)\n",
    "\n",
    "model = keras.models.Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Compile\n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train\n",
    "X_train = X\n",
    "y_train = y\n",
    "\n",
    "nEpoch = 200\n",
    "\n",
    "model.fit(X_train,\n",
    "          y_train,\n",
    "          epochs=nEpoch,\n",
    "          batch_size=64,\n",
    "          verbose=1\n",
    "         )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model save (written to the current working directory)\n",
    "modelName = 'ModelWS_' + GridCellSize + 'm_' + ModelName + '.h5'\n",
    "model.save(modelName)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}