Cognitive_technologies/лр11/006-dogs-vs-cats.ipynb

327 lines
8.9 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dogs vs. Cats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"from tensorflow import keras\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.applications.vgg16 import VGG16\n",
"from tensorflow.keras.applications.vgg16 import preprocess_input\n",
"from tensorflow.keras.preprocessing.image import load_img, img_to_array"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"IMG_SIZE = (224, 224) # input image size of the network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Функции загрузки данных"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<>:24: SyntaxWarning: invalid escape sequence '\\.'\n",
"<>:24: SyntaxWarning: invalid escape sequence '\\.'\n",
"C:\\Users\\dimae\\AppData\\Local\\Temp\\ipykernel_8920\\2072963477.py:24: SyntaxWarning: invalid escape sequence '\\.'\n",
" y = np.array([1. if re.match('.*/dog\\.\\d', path) else 0. for path in files[i:j]])\n"
]
}
],
"source": [
"import re\n",
"from random import shuffle\n",
"from glob import glob\n",
"\n",
"train_files = glob('../input/train/*.jpg')\n",
"test_files = glob('../input/test/*.jpg')\n",
"\n",
"# read one image from disk and return it preprocessed for VGG16\n",
"def load_image(path, target_size=IMG_SIZE):\n",
"    # load_img rescales the picture to target_size while reading it\n",
"    pixels = img_to_array(load_img(path, target_size=target_size))\n",
"    return preprocess_input(pixels)  # VGG16-specific preprocessing\n",
"\n",
"# endless generator of shuffled training batches read from disk\n",
"#   files      - sequence of image paths; the label comes from the file name\n",
"#                (1.0 when the base name starts with 'dog.' plus a digit, else 0.0)\n",
"#   batch_size - maximum number of images per yielded batch\n",
"# yields (x, y): x stacks the load_image() results, y is the matching label vector\n",
"# raises ValueError for an empty file list or a non-positive batch size\n",
"def fit_generator(files, batch_size=32):\n",
"    files = list(files)  # work on a copy so the caller's list is not reordered\n",
"    if not files:\n",
"        raise ValueError('fit_generator: no files to read')\n",
"    if batch_size < 1:\n",
"        raise ValueError('fit_generator: batch_size must be positive')\n",
"    while True:\n",
"        shuffle(files)\n",
"        # walk the list in batch_size strides; the last batch of a pass may be\n",
"        # shorter instead of being silently dropped\n",
"        for start in range(0, len(files), batch_size):\n",
"            batch = files[start:start + batch_size]\n",
"            x = np.array([load_image(path) for path in batch])\n",
"            # match on the file name only, so any path separator works\n",
"            y = np.array([1. if re.match(r'dog\\.\\d', os.path.basename(path)) else 0.\n",
"                          for path in batch])\n",
"            yield (x, y)\n",
"\n",
"# endless generator of single-image test batches read from disk\n",
"# yields one array per file, each holding a single preprocessed image,\n",
"# and cycles over `files` forever\n",
"# raises ValueError for an empty list (the loop would otherwise spin without yielding)\n",
"def predict_generator(files):\n",
"    if len(files) == 0:\n",
"        raise ValueError('predict_generator: no files to read')\n",
"    while True:\n",
"        for path in files:\n",
"            yield np.array([load_image(path)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Визуализируем примеры для обучения"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# show the first 10 training images on one fixed 2x5 grid\n",
"fig = plt.figure(figsize=(20, 20))\n",
"for i, path in enumerate(train_files[:10], 1):\n",
"    # every subplot must use the same grid shape, otherwise the axes overlap\n",
"    subplot = fig.add_subplot(2, 5, i)\n",
"    subplot.imshow(plt.imread(path))\n",
"    # title with the bare file name, whatever path separator glob returned\n",
"    subplot.set_title(os.path.basename(path));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Загружаем предобученную модель"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# base_model is a keras.models.Model instance (Functional Model)\n",
"base_model = VGG16(\n",
"    weights='imagenet',\n",
"    include_top=False,\n",
"    input_shape=(*IMG_SIZE, 3),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# фиксируем все веса предобученной сети\n",
"for layer in base_model.layers:\n",
" layer.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"base_model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Добавляем полносвязный слой"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = base_model.layers[-5].output\n",
"x = keras.layers.Flatten()(x)\n",
"x = keras.layers.Dense(1, # один выход\n",
" activation='sigmoid', # функция активации \n",
" kernel_regularizer=keras.regularizers.l1(1e-4))(x)\n",
"model = Model(inputs=base_model.input, outputs=x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Выводим архитектуру модели"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Компилируем модель и запускаем обучение"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.compile(optimizer='adam', \n",
" loss='binary_crossentropy', # функция потерь binary_crossentropy (log loss\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"shuffle(train_files)  # shuffle the training set in place\n",
"\n",
"train_val_split = 100  # number of images in the validation set\n",
"\n",
"# a single generator call with batch_size equal to the split size\n",
"# returns the whole validation set as one (x, y) pair\n",
"validation_data = next(fit_generator(train_files[:train_val_split], train_val_split))\n",
"\n",
"# start training; Model.fit accepts generators directly\n",
"# (Model.fit_generator is deprecated in TF 2.x and absent from Keras 3)\n",
"model.fit(fit_generator(train_files[train_val_split:]),  # batches are read by the generator\n",
"          steps_per_epoch=10,  # generator calls per epoch\n",
"          epochs=100,  # number of training epochs\n",
"          validation_data=validation_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.save('cats-dogs-vgg16.hdf5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Предсказания на тестовой выборке"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model.predict accepts generators directly (Model.predict_generator is deprecated\n",
"# in TF 2.x and absent from Keras 3); `steps` limits the endless generator to one\n",
"# pass over the test files. max_queue_size is dropped: Keras 3 predict() has no such argument.\n",
"pred = model.predict(predict_generator(test_files), steps=len(test_files))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"\n",
"# show test images 80..89 on one fixed 2x5 grid, titled with the predicted score\n",
"fig = plt.figure(figsize=(20, 20))\n",
"for i, (path, score) in enumerate(zip(test_files[80:90], pred[80:90]), 1):\n",
"    # every subplot must use the same grid shape, otherwise the axes overlap\n",
"    subplot = fig.add_subplot(2, 5, i)\n",
"    subplot.imshow(plt.imread(path))\n",
"    # each prediction row holds a single sigmoid output; format it as a plain float\n",
"    subplot.set_title('%.3f' % float(np.ravel(score)[0]));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Готовим данные для сабмита"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# write the submission file: a header plus one 'id,label' row per test image\n",
"with open('submit.txt', 'w') as dst:\n",
"    dst.write('id,label\\n')\n",
"    for path, score in zip(test_files, pred):\n",
"        # take the id from the file name only, so digits in directory names are ignored\n",
"        image_id = re.search(r'\\d+', os.path.basename(path)).group(0)\n",
"        # each prediction row holds a single sigmoid output; write it as a plain float\n",
"        dst.write('%s,%f\\n' % (image_id, float(np.ravel(score)[0])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# LogLoss = 1.04979"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
},
"widgets": {
"state": {},
"version": "1.1.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}