scan-context/src/example/longterm_localization/NCLT/2012-01-15/2_Train/.ipynb_checkpoints/train-checkpoint.ipynb

303 lines
7.5 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np \n",
"import pandas as pd \n",
"import tensorflow as tf\n",
"import keras\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"# Data info\n",
"# DataPath is built from these fragments by plain string concatenation, so the\n",
"# slashes embedded in each value are significant.\n",
"rootDir = '..your_data_path/'  # placeholder: replace with the local data root\n",
"\n",
"Dataset, SequenceDate = 'NCLT', '2012-01-15'\n",
"TrainOrTest = '/Train/'\n",
"SCImiddlePath = '/5. SCI_jet0to15_BackAug/'\n",
"\n",
"# Kept as a string: it is used in the data path and in the saved model's file name.\n",
"GridCellSize = '10'"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/media/gskim/IRAP-ADV1/Data/ICRA2019/NCLT/Train/2012-01-15/5. SCI_jet0to15_BackAug/10/\n"
]
}
],
"source": [
"# Glue the configured fragments together in order; each one carries its own slashes.\n",
"DataPath = ''.join([rootDir, Dataset, TrainOrTest, SequenceDate, SCImiddlePath, GridCellSize, '/'])\n",
"print(DataPath)"
]
},
{
"cell_type": "code",
"execution_count": 181,
"metadata": {},
"outputs": [],
"source": [
"def getTrainingDataNCLT(DataPath, SequenceDate):\n",
"    \"\"\"Load every image file in DataPath and one-hot encode its place label.\n",
"\n",
"    The place index of each image is parsed from fixed character positions\n",
"    of its file name (``dataName[6:11]``).\n",
"\n",
"    Args:\n",
"        DataPath: directory holding the image files (trailing slash optional).\n",
"        SequenceDate: sequence identifier; only used in the log message.\n",
"\n",
"    Returns:\n",
"        X_nd: float32 array with one image per entry along the first axis.\n",
"        y_nd: one-hot label matrix with one column per distinct place index.\n",
"        lbl_enc: fitted LabelEncoder mapping place indices to columns of y_nd.\n",
"\n",
"    Raises:\n",
"        ValueError: if DataPath contains no files.\n",
"    \"\"\"\n",
"    # info\n",
"    # os.listdir returns entries in arbitrary order, so sort for a reproducible\n",
"    # row order; sub-directories (e.g. .ipynb_checkpoints) are not images and\n",
"    # are skipped.\n",
"    WholeData = sorted(\n",
"        name for name in os.listdir(DataPath)\n",
"        if os.path.isfile(os.path.join(DataPath, name))\n",
"    )\n",
"    nWholeData = len(WholeData)\n",
"    print(str(nWholeData) + ' data exist in ' + SequenceDate)\n",
"    if nWholeData == 0:\n",
"        # Fail with a clear message instead of an unbound-variable error later on.\n",
"        raise ValueError('No data files found in ' + DataPath)\n",
"\n",
"    # read\n",
"    X = []\n",
"    y = []\n",
"    for ii, dataName in enumerate(WholeData):\n",
"        # os.path.join works whether or not DataPath ends with a slash.\n",
"        X.append(plt.imread(os.path.join(DataPath, dataName)))\n",
"        y.append(int(dataName[6:11]))\n",
"\n",
"        # progress message\n",
"        if ii % 1000 == 0:\n",
"            print(str(format((ii/nWholeData)*100, '.1f')), '% loaded.')\n",
"\n",
"    # X: stack all images in one step, directly as float32 (no float64 scratch copy)\n",
"    X_nd = np.asarray(X, dtype='float32')\n",
"\n",
"    # y (one-hot encoded)\n",
"    from sklearn.preprocessing import LabelEncoder\n",
"    lbl_enc = LabelEncoder()\n",
"    y = lbl_enc.fit_transform(y)\n",
"    nClassesTheSequenceHave = len(lbl_enc.classes_)\n",
"    # keras.utils.to_categorical is the public entry point; the np_utils\n",
"    # submodule path is not available in newer Keras releases.\n",
"    y_nd = keras.utils.to_categorical(y, num_classes=nClassesTheSequenceHave)\n",
"\n",
"    # log message\n",
"    print('Data size: %s' % nWholeData)\n",
"    print(' ')\n",
"    print('Data shape:', X_nd.shape)\n",
"    print('Label shape:', y_nd.shape)\n",
"\n",
"    return X_nd, y_nd, lbl_enc"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15078 data exist in 2012-01-15\n",
"0.0 % loaded.\n",
"6.6 % loaded.\n",
"13.3 % loaded.\n",
"19.9 % loaded.\n",
"26.5 % loaded.\n",
"33.2 % loaded.\n",
"39.8 % loaded.\n",
"46.4 % loaded.\n",
"53.1 % loaded.\n",
"59.7 % loaded.\n",
"66.3 % loaded.\n",
"73.0 % loaded.\n",
"79.6 % loaded.\n",
"86.2 % loaded.\n",
"92.9 % loaded.\n",
"99.5 % loaded.\n",
"Data size: 15078\n",
"Data shape: (15078, 40, 120, 3)\n",
"Label shape: (15078, 579)\n"
]
}
],
"source": [
"X, y, lbl_enc = getTrainingDataNCLT(DataPath, SequenceDate)"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(40, 120, 3)\n",
"(579,)\n"
]
}
],
"source": [
"# Per-sample shapes; these size the network's input and output layers below.\n",
"dataShape, labelShape = X[0].shape, y[0].shape\n",
"\n",
"print(dataShape, labelShape, sep='\\n')"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [],
"source": [
"# Model\n",
"from keras import backend as K\n",
"K.clear_session()  # reset Keras backend state before building the model\n",
"\n",
"ModelName = 'my_model'\n",
"\n",
"# Hyper-parameters\n",
"KernelSize = 5\n",
"\n",
"nConv1Filter = 64\n",
"nConv2Filter = 128\n",
"nConv3Filter = 256\n",
"\n",
"nFCN1 = 64\n",
"\n",
"Drop1 = 0.7\n",
"Drop2 = 0.7\n",
"\n",
"\n",
"def _conv_block(tensor, n_filters, batch_norm):\n",
"    \"\"\"Conv2D (relu, same padding) -> 2x2 max-pool, optionally followed by BatchNormalization.\"\"\"\n",
"    tensor = keras.layers.Conv2D(filters=n_filters, kernel_size=KernelSize, activation='relu', padding='same')(tensor)\n",
"    tensor = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(tensor)\n",
"    if batch_norm:\n",
"        tensor = keras.layers.BatchNormalization()(tensor)\n",
"    return tensor\n",
"\n",
"\n",
"# Feature extractor: three conv blocks; only the first two are batch-normalized.\n",
"inputs = keras.layers.Input(shape=(dataShape[0], dataShape[1], dataShape[2]))\n",
"x = _conv_block(inputs, nConv1Filter, batch_norm=True)\n",
"x = _conv_block(x, nConv2Filter, batch_norm=True)\n",
"x = _conv_block(x, nConv3Filter, batch_norm=False)\n",
"\n",
"# Classifier head: dropout around a Dense layer (no activation given), then a\n",
"# softmax with one unit per label column.\n",
"x = keras.layers.Flatten()(x)\n",
"x = keras.layers.Dropout(rate=Drop1)(x)\n",
"x = keras.layers.Dense(units=nFCN1)(x)\n",
"x = keras.layers.Dropout(rate=Drop2)(x)\n",
"outputs = keras.layers.Dense(units=labelShape[0], activation='softmax')(x)\n",
"\n",
"model = keras.models.Model(inputs=inputs, outputs=outputs)"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [],
"source": [
"# Compile: Adam optimizer, categorical cross-entropy on the one-hot labels, accuracy metric.\n",
"model.compile(\n",
"    optimizer='adam',\n",
"    loss='categorical_crossentropy',\n",
"    metrics=['acc'],\n",
")\n",
"model.build(None)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Train on everything that was loaded (no validation data is passed to fit).\n",
"X_train, y_train = X, y\n",
"\n",
"nEpoch = 200\n",
"\n",
"model.fit(x=X_train, y=y_train, batch_size=64, epochs=nEpoch, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Train\n",
"# NOTE(review): this cell repeats the previous training cell; running both keeps\n",
"# fitting the same model for another nEpoch epochs -- confirm this is intended.\n",
"X_train, y_train = X, y\n",
"\n",
"nEpoch = 200\n",
"\n",
"model.fit(x=X_train, y=y_train, batch_size=64, epochs=nEpoch, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 186,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained model under a file name that encodes the grid cell size and\n",
"# model name; the path is relative, so it lands in the current working directory.\n",
"modelName = ''.join(['ModelWS_', GridCellSize, 'm_', ModelName, '.h5'])\n",
"model.save(modelName)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}