{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "import tensorflow as tf\n",
    "import keras\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data info \n",
    "rootDir = '..your_data_path/'\n",
    "\n",
    "Dataset = 'NCLT'\n",
    "TrainOrTest = '/Train/'\n",
    "SequenceDate = '2012-01-15'\n",
    "\n",
    "SCImiddlePath = '/5. SCI_jet0to15_BackAug/'\n",
    "\n",
    "GridCellSize = '10'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/gskim/IRAP-ADV1/Data/ICRA2019/NCLT/Train/2012-01-15/5. SCI_jet0to15_BackAug/10/\n"
     ]
    }
   ],
   "source": [
    "DataPath = rootDir + Dataset + TrainOrTest + SequenceDate + SCImiddlePath + GridCellSize + '/'\n",
    "print(DataPath)"
   ]
  },
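  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional guard (a sketch, not part of the original notebook): fail\n",
    "# early with a readable message if the assembled path does not exist.\n",
    "assert os.path.isdir(DataPath), 'DataPath not found: ' + DataPath"
   ]
  },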
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getTrainingDataNCLT(DataPath, SequenceDate):    \n",
    "\n",
    "    # info\n",
    "    WholeData = os.listdir(DataPath)\n",
    "    nWholeData = len(WholeData)\n",
    "    print(str(nWholeData) + ' data exist in ' + SequenceDate)\n",
    "    \n",
    "    # read \n",
    "    X = []\n",
    "    y = []\n",
    "    for ii in range(nWholeData):\n",
    "        dataName = WholeData[ii]\n",
    "        dataPath = DataPath + dataName\n",
    "        \n",
    "        dataTrajNodeOrder = int(dataName[0:5])\n",
    "\n",
    "        SCI = plt.imread(dataPath)\n",
    "        dataPlaceIndex = int(dataName[6:11])\n",
    "        \n",
    "        X.append(SCI)\n",
    "        y.append(dataPlaceIndex)\n",
    "        \n",
    "        # progress message \n",
    "        if ii%1000==0:\n",
    "            print(str(format((ii/nWholeData)*100, '.1f')), '% loaded.')\n",
    "        \n",
    "        \n",
    "    dataShape = SCI.shape\n",
    "    \n",
    "    # X\n",
    "    X_nd = np.zeros(shape=(nWholeData, dataShape[0], dataShape[1], dataShape[2]))\n",
    "    for jj in range(len(X)):\n",
    "        X_nd[jj, :, :] = X[jj]\n",
    "    X_nd = X_nd.astype('float32')\n",
    "    \n",
    "    # y (one-hot encoded)\n",
    "    from sklearn.preprocessing import LabelEncoder\n",
    "    lbl_enc = LabelEncoder()\n",
    "    lbl_enc.fit(y)\n",
    "    \n",
    "    ClassesTheSequenceHave = lbl_enc.classes_\n",
    "    nClassesTheSequenceHave = len(ClassesTheSequenceHave)\n",
    "    \n",
    "    y = lbl_enc.transform(y)\n",
    "    y_nd = keras.utils.np_utils.to_categorical(y, num_classes=nClassesTheSequenceHave)\n",
    "\n",
    "    # log message \n",
    "    print('Data size: %s' % nWholeData)\n",
    "    print(' ')\n",
    "    print('Data shape:', X_nd.shape)\n",
    "    print('Label shape:', y_nd.shape)\n",
    "    \n",
    "    return X_nd, y_nd, lbl_enc\n",
    "    "
   ]
  },
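  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Naming-convention check (sketch). getTrainingDataNCLT above assumes\n",
    "# each file name stores the trajectory node order in characters 0-4\n",
    "# and the place index in characters 6-10. The name below is a made-up\n",
    "# example for illustration, not an actual file from the dataset.\n",
    "exampleName = '00012_00345.png'\n",
    "print('trajectory node order:', int(exampleName[0:5]))\n",
    "print('place index          :', int(exampleName[6:11]))"
   ]
  },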
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15078 data exist in 2012-01-15\n",
      "0.0 % loaded.\n",
      "6.6 % loaded.\n",
      "13.3 % loaded.\n",
      "19.9 % loaded.\n",
      "26.5 % loaded.\n",
      "33.2 % loaded.\n",
      "39.8 % loaded.\n",
      "46.4 % loaded.\n",
      "53.1 % loaded.\n",
      "59.7 % loaded.\n",
      "66.3 % loaded.\n",
      "73.0 % loaded.\n",
      "79.6 % loaded.\n",
      "86.2 % loaded.\n",
      "92.9 % loaded.\n",
      "99.5 % loaded.\n",
      "Data size: 15078\n",
      "Data shape: (15078, 40, 120, 3)\n",
      "Label shape: (15078, 579)\n"
     ]
    }
   ],
   "source": [
    "[X, y, lbl_enc] = getTrainingDataNCLT(DataPath, SequenceDate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(40, 120, 3)\n",
      "(579,)\n"
     ]
    }
   ],
   "source": [
    "dataShape = X[0].shape\n",
    "labelShape = y[0].shape\n",
    "\n",
    "print(dataShape)\n",
    "print(labelShape)"
   ]
  },
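  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional visual check (sketch): show one scan context image (SCI).\n",
    "# np.argmax recovers the encoded class from the one-hot label; note\n",
    "# this is the LabelEncoder class, not the raw place index.\n",
    "plt.imshow(X[0])\n",
    "plt.title('example SCI, encoded class %d' % np.argmax(y[0]))\n",
    "plt.show()"
   ]
  },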
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model \n",
    "from keras import backend as K\n",
    "K.clear_session()\n",
    "\n",
    "ModelName = 'my_model'\n",
    "\n",
    "Drop1 = 0.7\n",
    "Drop2 = 0.7\n",
    "\n",
    "KernelSize = 5\n",
    "\n",
    "nConv1Filter = 64\n",
    "nConv2Filter = 128\n",
    "nConv3Filter = 256\n",
    "\n",
    "nFCN1 = 64\n",
    "\n",
    "inputs = keras.layers.Input(shape=(dataShape[0], dataShape[1], dataShape[2]))\n",
    "x = keras.layers.Conv2D(filters=nConv1Filter, kernel_size=KernelSize, activation='relu', padding='same')(inputs)\n",
    "x = keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')(x)\n",
    "x = keras.layers.BatchNormalization()(x)\n",
    "x = keras.layers.Conv2D(filters=nConv2Filter, kernel_size=KernelSize, activation='relu', padding='same')(x)\n",
    "x = keras.layers.MaxPool2D()(x)\n",
    "x = keras.layers.BatchNormalization()(x)\n",
    "x = keras.layers.Conv2D(filters=nConv3Filter, kernel_size=KernelSize, activation='relu', padding='same')(x)\n",
    "x = keras.layers.MaxPool2D()(x)\n",
    "x = keras.layers.Flatten()(x)\n",
    "x = keras.layers.Dropout(rate=Drop1)(x)\n",
    "x = keras.layers.Dense(units=nFCN1)(x)\n",
    "x = keras.layers.Dropout(rate=Drop2)(x)\n",
    "outputs = keras.layers.Dense(units=labelShape[0], activation='softmax')(x)\n",
    "\n",
    "model = keras.models.Model(inputs=inputs, outputs=outputs)"
   ]
  },
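  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the layer stack and parameter counts of the model above.\n",
    "model.summary()"
   ]
  },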
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Compile \n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])\n",
    "model.build(None,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#Train\n",
    "X_train = X\n",
    "y_train = y\n",
    "\n",
    "nEpoch = 200\n",
    "\n",
    "model.fit(X_train,\n",
    "          y_train,\n",
    "          epochs=nEpoch,\n",
    "          batch_size=64,\n",
    "          verbose=1\n",
    "         )"
   ]
  },
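  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick convergence check (sketch): accuracy on the training set.\n",
    "# This is the same data the model was fit on, so it only verifies that\n",
    "# training converged; a held-out sequence is needed for a real test.\n",
    "loss, acc = model.evaluate(X_train, y_train, batch_size=64, verbose=0)\n",
    "print('train loss: %.4f, train acc: %.4f' % (loss, acc))"
   ]
  },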
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model save \n",
    "modelName = 'ModelWS_' + GridCellSize + 'm_' + ModelName + '.h5'\n",
    "model.save(modelName)\n"
   ]
  },
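  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference sketch (not in the original notebook): reload the saved\n",
    "# model and classify a single SCI. lbl_enc (from getTrainingDataNCLT)\n",
    "# maps the softmax argmax back to the original place index. Querying a\n",
    "# training image here is for illustration only; a real test would use\n",
    "# a different sequence date.\n",
    "reloaded = keras.models.load_model(modelName)\n",
    "\n",
    "querySCI = X[0:1]  # a batch of one image, shape (1, 40, 120, 3)\n",
    "pred = reloaded.predict(querySCI)\n",
    "predClass = np.argmax(pred, axis=1)\n",
    "predPlaceIndex = lbl_enc.inverse_transform(predClass)\n",
    "print('predicted place index:', predPlaceIndex[0])"
   ]
  }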
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
