{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Our first convolutional neural network\n", "\n", "It's time to build our first convolutional neural network. We will use the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, a collection of images of handwritten digits.\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load the packages\n", "import numpy as np\n", "import random\n", "from tqdm import tqdm\n", "\n", "import keras\n", "from keras import backend as K\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Looking at the data\n", "\n", "Let's download and prepare the data. In about a minute we will have $60\\,000$ training images of size $28 \\times 28$ (plus $10\\,000$ test images)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from keras.datasets import mnist\n", "(x_tr, y_tr), (x_ts, y_ts) = mnist.load_data()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train samples: (60000, 28, 28) (60000,)\n", "Test samples: (10000, 28, 28) (10000,)\n" ] } ], "source": [ "print(\"Train samples:\", x_tr.shape, y_tr.shape)\n", "print(\"Test samples:\", x_ts.shape, y_ts.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot a few random images from the training set." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cols = 8\n", "rows = 2\n", "fig = plt.figure(figsize=(2 * cols - 1, 2.5 * rows - 1))\n", "for i in range(cols):\n", "    for j in range(rows):\n", "        random_index = np.random.randint(0, len(y_tr))\n", "        ax = fig.add_subplot(rows, cols, i * rows + j + 1)\n", "        ax.grid(False)\n", "        ax.axis('off')\n", "        ax.imshow(x_tr[random_index, :], cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great! As you may remember, a network trained on normalized data usually converges noticeably faster.\n", "\n", "Also recall from the previous notebooks that an image is just a tensor of numbers. Each number is the brightness of one pixel, measured on a scale from 0 to 255. So here normalization simply rescales every pixel into $[0, 1]$:\n", "\n", "$$\n", "x_{norm} = \\frac{x}{255}\n", "$$" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# step one: scale pixel values to [0, 1]\n", "x_train = x_tr / 255\n", "x_test = x_ts / 255\n", "\n", "# add a trailing axis for the number of channels (1 for grayscale images)\n", "x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))\n", "x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need to convert the class labels into dummy variables using one-hot encoding: 
 \n", "\n", "```\n", "0 → [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n", "1 → [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n", "2 → [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]\n", "3 → [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]\n", "etc.\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., ..., 1., 0., 0.],\n", " [0., 0., 1., ..., 0., 0., 0.],\n", " [0., 1., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# step two: one-hot encode the labels\n", "y_train = keras.utils.to_categorical(y_tr, 10)\n", "y_test = keras.utils.to_categorical(y_ts, 10)\n", "\n", "# after conversion the shape is (n_samples, NUM_CLASSES)\n", "y_test" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000, 28, 28, 1)\n", "(10000, 28, 28, 1)\n" ] } ], "source": [ "print(x_train.shape)\n", "print(x_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Choosing an architecture for our network\n", "\n", "A convolutional neural network is built from several types of layers:\n", "\n", "* [Conv2D](https://keras.io/layers/convolutional/#conv2d): convolution\n", "    - **filters**: number of output channels;\n", "    - **kernel_size**: size of the convolution window;\n", "    - **padding**: `padding='same'` pads the image border with zeros so that (with stride 1) the output keeps the spatial size of the input; `padding='valid'` adds no padding;\n", "    - **activation**: \"relu\", \"tanh\", etc.;\n", "    - **input_shape**: shape of the input;\n", "* [MaxPooling2D](https://keras.io/layers/pooling/#maxpooling2d): max pooling\n", "* [Flatten](https://keras.io/layers/core/#flatten): unrolls an image into a vector\n", "* [Dense](https://keras.io/layers/core/#dense): fully connected layer\n", "* [Activation](https://keras.io/layers/core/#activation): activation function\n", "* [LeakyReLU](https://keras.io/layers/advanced-activations/#leakyrelu): leaky ReLU activation\n", "* [Dropout](https://keras.io/layers/core/#dropout): dropout\n", "\n", "\n", "The models defined below take tensors of shape __(None, 28, 28, 1)__ as input and are trained against one-hot targets of shape __(None, 10)__. Their output is a __(None, 10)__ tensor holding the probability that the object belongs to each class. The __None__ dimension is reserved for the batch size. 
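\n", "\n", "The output sizes and parameter counts of these layers can be checked by hand. Below is a minimal sketch added for illustration. `conv_out`, `conv_params` and `dense_params` are hypothetical helper names, not Keras functions. The sketch assumes square inputs and kernels and uses the standard formulas. With `padding='same'` the output size is $\\lceil n / s \\rceil$, and with `padding='valid'` it is $\\lfloor (n - k) / s \\rfloor + 1$, where $n$ is the input size, $k$ the kernel size and $s$ the stride. `MaxPooling2D` follows the same rule, and by default it uses `padding='valid'` with `strides` equal to `pool_size`. The printed numbers should match the `summary()` tables further below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Hypothetical helpers (not part of Keras) that reproduce the shape and parameter arithmetic\n",
"def conv_out(n, k, padding='valid', stride=1):\n",
"    # 'same': output = ceil(n / stride); 'valid': no padding, the window must fit inside the image\n",
"    if padding == 'same':\n",
"        return -(-n // stride)\n",
"    return (n - k) // stride + 1\n",
"\n",
"def conv_params(k, c_in, filters):\n",
"    # each filter has k*k*c_in weights plus one bias\n",
"    return (k * k * c_in + 1) * filters\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # an n_in x n_out weight matrix plus n_out biases\n",
"    return (n_in + 1) * n_out\n",
"\n",
"print(conv_out(28, 5, 'same'))             # 28: 'same' keeps the size at stride 1\n",
"print(conv_out(28, 5, 'valid'))            # 24 = 28 - 5 + 1\n",
"print(conv_out(28, 2, 'valid', stride=2))  # 14: what MaxPooling2D((2, 2)) does\n",
"print(conv_params(3, 1, 16))               # 160: first Conv2D of the smaller model_2\n",
"print(dense_params(28 * 28, 64))           # 50240: first Dense of model_1\n",
"print(dense_params(7 * 7 * 8, 32))         # 12576: first Dense of the smaller model_2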
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# import the essential building blocks\n", "from keras.models import Sequential\n", "from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Activation, Dropout, InputLayer, LeakyReLU, Input" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Fully connected network\n", "\n", "Let's build a fully connected network with the following layers:\n", "\n", "* Input\n", "* Flatten the image into a vector\n", "* Dense layer with 64 neurons\n", "* ReLU\n", "* Dense layer with 32 neurons\n", "* Dropout with rate 0.5\n", "* ReLU\n", "* Dense layer with 16 neurons\n", "* Dropout with rate 0.5\n", "* ReLU\n", "* Output layer with 10 neurons, one per class, with a softmax activation" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model_1 = Sequential()\n", "\n", "# 1. Input layer with the shape of a single image (height, width, channels)\n", "model_1.add(Input(shape=(28, 28, 1)))\n", "\n", "# 2. First dense layer: 64 neurons with ReLU\n", "model_1.add(Flatten())  # unroll the input image into a vector of 28*28 = 784 values\n", "model_1.add(Dense(64))\n", "model_1.add(Activation('relu'))\n", "\n", "# 3. Second dense layer: 32 neurons with ReLU\n", "model_1.add(Dense(32))\n", "model_1.add(Dropout(0.5))  # dropout with rate 0.5 for regularization\n", "model_1.add(Activation('relu'))\n", "\n", "# 4. Third dense layer: 16 neurons with ReLU\n", "model_1.add(Dense(16))\n", "model_1.add(Dropout(0.5))  # dropout with rate 0.5 for regularization\n", "model_1.add(Activation('relu'))\n", "\n", "# 5. Output layer: 10 neurons (one per class) with softmax\n", "model_1.add(Dense(10))\n", "model_1.add(Activation('softmax'))\n", "\n", "model_1.compile(\"adam\", \"categorical_crossentropy\", metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                          Output Shape                         Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
       "│ flatten (Flatten)                    │ (None, 784)                 │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense (Dense)                        │ (None, 64)                  │          50,240 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation (Activation)              │ (None, 64)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_1 (Dense)                      │ (None, 32)                  │           2,080 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dropout (Dropout)                    │ (None, 32)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_1 (Activation)            │ (None, 32)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_2 (Dense)                      │ (None, 16)                  │             528 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dropout_1 (Dropout)                  │ (None, 16)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_2 (Activation)            │ (None, 16)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_3 (Dense)                      │ (None, 10)                  │             170 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_3 (Activation)            │ (None, 10)                  │               0 │\n",
       "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", "│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m784\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m50,240\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m2,080\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_1 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m528\u001b[0m │\n", 
"├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_2 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m170\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_3 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 53,018 (207.10 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m53,018\u001b[0m (207.10 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 53,018 (207.10 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m53,018\u001b[0m (207.10 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_1.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Можно визуализировать внутренними средствами keras сетку, которую мы собираем. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "G\n", "\n", "\n", "\n", "2429137156816\n", "\n", "\n", "\n", "\n", "flatten\n", " (Flatten)\n", "\n", "\n", "Input shape: \n", "(None, 28, 28)\n", "\n", "\n", "Output shape: \n", "(None, 784)\n", "\n", "\n", "\n", "2429137149568\n", "\n", "\n", "\n", "\n", "dense\n", " (Dense)\n", "\n", "\n", "Input shape: \n", "(None, 784)\n", "\n", "\n", "Output shape: \n", "(None, 64)\n", "\n", "\n", "\n", "2429137156816->2429137149568\n", "\n", "\n", "\n", "\n", "\n", "2429137149328\n", "\n", "\n", "\n", "\n", "activation\n", " (Activation)\n", "\n", "\n", "Input shape: \n", "(None, 64)\n", "\n", "\n", "Output shape: \n", "(None, 64)\n", "\n", "\n", "\n", "2429137149568->2429137149328\n", "\n", "\n", "\n", "\n", "\n", "2429137157632\n", "\n", "\n", "\n", "\n", "dense_1\n", " (Dense)\n", "\n", "\n", "Input shape: \n", "(None, 64)\n", "\n", "\n", "Output shape: \n", "(None, 32)\n", "\n", "\n", "\n", "2429137149328->2429137157632\n", "\n", "\n", "\n", "\n", "\n", "2429137123472\n", "\n", "\n", "\n", "\n", "dropout\n", " (Dropout)\n", "\n", "\n", "Input shape: \n", "(None, 32)\n", "\n", "\n", "Output shape: \n", "(None, 32)\n", "\n", "\n", "\n", "2429137157632->2429137123472\n", "\n", "\n", "\n", "\n", "\n", "2429137166288\n", "\n", "\n", "\n", "\n", "activation_1\n", " (Activation)\n", "\n", "\n", "Input shape: \n", "(None, 32)\n", "\n", "\n", "Output shape: \n", "(None, 32)\n", "\n", "\n", "\n", "2429137123472->2429137166288\n", "\n", "\n", "\n", "\n", "\n", "2429137172000\n", "\n", "\n", "\n", "\n", "dense_2\n", " (Dense)\n", "\n", 
"\n", "Input shape: \n", "(None, 32)\n", "\n", "\n", "Output shape: \n", "(None, 16)\n", "\n", "\n", "\n", "2429137166288->2429137172000\n", "\n", "\n", "\n", "\n", "\n", "2429137114592\n", "\n", "\n", "\n", "\n", "dropout_1\n", " (Dropout)\n", "\n", "\n", "Input shape: \n", "(None, 16)\n", "\n", "\n", "Output shape: \n", "(None, 16)\n", "\n", "\n", "\n", "2429137172000->2429137114592\n", "\n", "\n", "\n", "\n", "\n", "2429137118768\n", "\n", "\n", "\n", "\n", "activation_2\n", " (Activation)\n", "\n", "\n", "Input shape: \n", "(None, 16)\n", "\n", "\n", "Output shape: \n", "(None, 16)\n", "\n", "\n", "\n", "2429137114592->2429137118768\n", "\n", "\n", "\n", "\n", "\n", "2429137179888\n", "\n", "\n", "\n", "\n", "dense_3\n", " (Dense)\n", "\n", "\n", "Input shape: \n", "(None, 16)\n", "\n", "\n", "Output shape: \n", "(None, 10)\n", "\n", "\n", "\n", "2429137118768->2429137179888\n", "\n", "\n", "\n", "\n", "\n", "2429137187472\n", "\n", "\n", "\n", "\n", "activation_3\n", " (Activation)\n", "\n", "\n", "Input shape: \n", "(None, 10)\n", "\n", "\n", "Output shape: \n", "(None, 10)\n", "\n", "\n", "\n", "2429137179888->2429137187472\n", "\n", "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import SVG\n", "from keras.utils import model_to_dot\n", "\n", "SVG(model_to_dot(model_1, show_shapes=True, dpi=60).create(prog='dot', format='svg'))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 6ms/step - accuracy: 0.4364 - loss: 1.5832 - val_accuracy: 0.9202 - val_loss: 0.4038\n", "Epoch 2/5\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 5ms/step - accuracy: 0.7256 - loss: 0.8545 - val_accuracy: 
0.9391 - val_loss: 0.2817\n", "Epoch 3/5\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 5ms/step - accuracy: 0.7768 - loss: 0.7073 - val_accuracy: 0.9490 - val_loss: 0.2270\n", "Epoch 4/5\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 5ms/step - accuracy: 0.8016 - loss: 0.6395 - val_accuracy: 0.9555 - val_loss: 0.2025\n", "Epoch 5/5\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 5ms/step - accuracy: 0.8178 - loss: 0.5896 - val_accuracy: 0.9560 - val_loss: 0.2100\n" ] } ], "source": [ "# train for 5 epochs, holding out 20% of the training set for validation\n", "hist = model_1.fit(x_train, y_train, validation_split=0.2, epochs=5, verbose=1)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(hist.history['loss'])\n", "plt.plot(hist.history['val_loss'])\n", "plt.legend(['Train loss', 'Validation loss'])" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 4ms/step - accuracy: 0.9520 - loss: 0.2394\n", "\n", "Loss, Accuracy = [0.2147352546453476, 0.9567999839782715]\n" ] } ], "source": [ "print(\"\\nLoss, Accuracy = \", model_1.evaluate(x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Look at the quality of the resulting model. Then go back up, skip the step that scales the images to $[0, 1]$ (train on the raw pixel values from 0 to 255 instead), and retrain the network. What happens to the quality?\n", "* Now try a linear activation (`'linear'`) instead of ReLU in the hidden layers. What happens to the quality of the model?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at exactly where the network makes mistakes." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step\n" ] } ], "source": [ "y_pred = model_1.predict(x_test)\n", "y_pred_classes = y_pred.argmax(axis=1)\n", "\n", "# boolean mask of misclassified test images\n", "errors = y_pred_classes != y_ts\n", "\n", "x_err = x_ts[errors]\n", "y_err = y_ts[errors]\n", "# from here on y_pred holds the predicted classes of the misclassified images only\n", "y_pred = y_pred_classes[errors]" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cols = 8\n", "rows = 2\n", "fig = plt.figure(figsize=(2 * cols - 1, 2.5 * rows - 1))\n", "for i in range(cols):\n", "    for j in range(rows):\n", "        random_index = np.random.randint(0, len(y_err))\n", "        ax = fig.add_subplot(rows, cols, i * rows + j + 1)\n", "        ax.grid(False)\n", "        ax.axis('off')\n", "        ax.imshow(x_err[random_index, :], cmap='gray')\n", "        ax.set_title('real_class: {} \\n predict class: {}'.format(y_err[random_index], y_pred[random_index]))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Convolutional network\n", "\n", "Now let's build a convolutional network:\n", "\n", "* Convolution with a $5 \\times 5$ kernel, same padding and $32$ channels\n", "* ReLU\n", "* Max pooling of size $2 \\times 2$\n", "* Convolution with a $5 \\times 5$ kernel, $16$ channels and same padding\n", "* ReLU\n", "* Max pooling of size $2 \\times 2$ with stride $2$ along both axes\n", "* After that, reuse the dense part of the previous architecture" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "model_2 = Sequential()\n", "\n", "# input layer with the shape of a single image\n", "model_2.add(InputLayer(shape=(28, 28, 1)))\n", "\n", "# first convolution: 5x5 kernel, 32 channels, same padding, ReLU\n", "model_2.add(Conv2D(32, (5, 5), padding='same'))\n", "model_2.add(Activation('relu'))\n", "\n", "# 2x2 max pooling\n", "model_2.add(MaxPooling2D(pool_size=(2, 2)))\n", "\n", "# second convolution: 5x5 kernel, 16 channels, same padding, ReLU\n", "model_2.add(Conv2D(16, (5, 5), padding='same'))\n", "model_2.add(Activation('relu'))\n", "\n", "# 2x2 max pooling with stride 2 along both axes\n", "model_2.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n", "\n", "# flatten the feature maps into a vector for the dense layers\n", "model_2.add(Flatten())\n", "\n", "# dense layer with 64 neurons and ReLU\n", "model_2.add(Dense(64))\n", 
"model_2.add(Activation('relu'))\n", "\n", "# dense layer with 32 neurons and ReLU\n", "model_2.add(Dense(32))\n", "model_2.add(Activation('relu'))\n", "\n", "# dense layer with 16 neurons and ReLU\n", "model_2.add(Dense(16))\n", "model_2.add(Activation('relu'))\n", "\n", "# output layer for 10 classes with softmax\n", "model_2.add(Dense(10))\n", "model_2.add(Activation('softmax'))\n", "\n", "model_2.compile(\"adam\", \"categorical_crossentropy\", metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "# a lighter variant that replaces model_2 above: fewer filters and 3x3 kernels\n", "model_2 = Sequential()\n", "\n", "# input layer, then a first convolution with only 16 filters\n", "model_2.add(InputLayer(shape=(28, 28, 1)))\n", "model_2.add(Conv2D(16, (3, 3), padding='same'))\n", "model_2.add(Activation('relu'))\n", "model_2.add(MaxPooling2D(pool_size=(2, 2)))\n", "\n", "# fewer filters (8) in the second convolution\n", "model_2.add(Conv2D(8, (3, 3), padding='same'))\n", "model_2.add(Activation('relu'))\n", "model_2.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))\n", "\n", "# fewer neurons in the dense layers\n", "model_2.add(Flatten())\n", "model_2.add(Dense(32))\n", "model_2.add(Activation('relu'))\n", "model_2.add(Dense(16))\n", "model_2.add(Activation('relu'))\n", "\n", "# output layer for 10 classes\n", "model_2.add(Dense(10))\n", "model_2.add(Activation('softmax'))\n", "\n", "# compile the model\n", "model_2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_4\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_4\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                          Output Shape                         Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
       "│ conv2d_6 (Conv2D)                    │ (None, 28, 28, 16)          │             160 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_22 (Activation)           │ (None, 28, 28, 16)          │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ max_pooling2d_6 (MaxPooling2D)       │ (None, 14, 14, 16)          │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ conv2d_7 (Conv2D)                    │ (None, 14, 14, 8)           │           1,160 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_23 (Activation)           │ (None, 14, 14, 8)           │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ max_pooling2d_7 (MaxPooling2D)       │ (None, 7, 7, 8)             │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ flatten_4 (Flatten)                  │ (None, 392)                 │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_16 (Dense)                     │ (None, 32)                  │          12,576 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_24 (Activation)           │ (None, 32)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_17 (Dense)                     │ (None, 16)                  │             528 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_25 (Activation)           │ (None, 16)                  │               0 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_18 (Dense)                     │ (None, 10)                  │             170 │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ activation_26 (Activation)           │ (None, 10)                  │               0 │\n",
       "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", "│ conv2d_6 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m160\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_22 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m28\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ max_pooling2d_6 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ conv2d_7 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m1,160\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_23 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m14\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ max_pooling2d_7 (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m7\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ flatten_4 (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m392\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_16 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m12,576\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_24 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_17 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m528\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_25 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m16\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ dense_18 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m170\u001b[0m │\n", "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", "│ activation_26 (\u001b[38;5;33mActivation\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 14,594 (57.01 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m14,594\u001b[0m (57.01 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 14,594 (57.01 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m14,594\u001b[0m (57.01 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_2.summary()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 7ms/step - accuracy: 0.7611 - loss: 0.7332 - val_accuracy: 0.9593 - val_loss: 0.1340\n", "Epoch 2/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 6ms/step - accuracy: 0.9602 - loss: 0.1302 - val_accuracy: 0.9665 - val_loss: 0.1053\n", "Epoch 3/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 6ms/step - accuracy: 0.9716 - loss: 0.0955 - val_accuracy: 0.9769 - val_loss: 0.0775\n", "Epoch 4/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 7ms/step - accuracy: 0.9766 - loss: 0.0741 - val_accuracy: 0.9772 - val_loss: 0.0788\n", "Epoch 5/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9799 - loss: 0.0632 - val_accuracy: 0.9804 - val_loss: 0.0644\n", "Epoch 6/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 6ms/step - accuracy: 0.9823 - loss: 0.0578 - val_accuracy: 0.9790 - val_loss: 0.0741\n", "Epoch 7/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 6ms/step - accuracy: 0.9837 - loss: 0.0514 - val_accuracy: 0.9790 - val_loss: 0.0670\n", "Epoch 8/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9860 - loss: 0.0435 
- val_accuracy: 0.9791 - val_loss: 0.0725\n", "Epoch 9/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9880 - loss: 0.0386 - val_accuracy: 0.9844 - val_loss: 0.0597\n", "Epoch 10/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9886 - loss: 0.0359 - val_accuracy: 0.9855 - val_loss: 0.0504\n", "Epoch 11/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9905 - loss: 0.0312 - val_accuracy: 0.9823 - val_loss: 0.0624\n", "Epoch 12/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 7ms/step - accuracy: 0.9910 - loss: 0.0284 - val_accuracy: 0.9837 - val_loss: 0.0567\n", "Epoch 13/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9906 - loss: 0.0282 - val_accuracy: 0.9836 - val_loss: 0.0594\n", "Epoch 14/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9913 - loss: 0.0272 - val_accuracy: 0.9847 - val_loss: 0.0535\n", "Epoch 15/30\n", "\u001b[1m1500/1500\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 7ms/step - accuracy: 0.9909 - loss: 0.0261 - val_accuracy: 0.9834 - val_loss: 0.0632\n" ] } ], "source": [ "# train for 5 epochs\n", "# hist = model_2.fit(x_train, y_train, validation_split=0.2, epochs=5, verbose=1)\n", "\n", "from keras.callbacks import EarlyStopping\n", "\n", "# Train for more epochs with early stopping: halt once val_loss has not improved for 5 epochs\n", "# and roll back to the weights from the best epoch\n", "early_stopping_monitor = EarlyStopping(patience=5, restore_best_weights=True)\n", "hist = model_2.fit(x_train, y_train, epochs=30, batch_size=32, validation_split=0.2, 
callbacks=[early_stopping_monitor])" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(hist.history['loss'])\n", "plt.plot(hist.history['val_loss'])\n", "plt.legend(['Train loss', 'Validation loss'])" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step - accuracy: 0.9858 - loss: 0.0517\n", "\n", "Loss, Accuracy = [0.040652498602867126, 0.9890000224113464]\n" ] } ], "source": [ "print(\"\\nLoss, Accuracy = \", model_2.evaluate(x_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Как видите, точность довольно сильно подскочила. Попробуйте поиграться числом параметров и слоёв так, чтобы их стало меньше, а качество сетки стало лучше. Попробуйте обучать нейросетку большее количество эпох. \n", "\n", "Снова посмотрим на ошибки. " ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 3ms/step\n" ] } ], "source": [ "y_pred = model_2.predict(x_test)\n", "y_pred_classes = y_pred.argmax(axis=1)\n", "\n", "errors = y_pred_classes != y_ts\n", "\n", "x_err = x_ts[errors]\n", "y_err = y_ts[errors]\n", "y_pred = y_pred_classes[errors]" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cols = 8\n", "rows = 2\n", "fig = plt.figure(figsize=(2 * cols - 1, 2.5 * rows - 1))\n", "for i in range(cols):\n", " for j in range(rows):\n", " random_index = np.random.randint(0, len(y_err))\n", " ax = fig.add_subplot(rows, cols, i * rows + j + 1)\n", " ax.grid('off')\n", " ax.axis('off')\n", " ax.imshow(x_err[random_index, : ], cmap='gray')\n", " ax.set_title('real_class: {} \\n predict class: {}'.format(y_err[random_index], y_pred[random_index]))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 4 }