chetyre/plot_digits_classification.ipynb

369 lines
59 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# Recognizing hand-written digits\n",
"\n",
"This example shows how scikit-learn can be used to recognize images of\n",
"hand-written digits, from 0-9.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Authors: The scikit-learn developers\n",
"# SPDX-License-Identifier: BSD-3-Clause\n",
"\n",
"# Standard scientific Python imports\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Import datasets, classifiers and performance metrics\n",
"from sklearn import datasets, metrics, svm\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Digits dataset\n",
"\n",
"The digits dataset consists of 8x8\n",
"pixel images of digits. The ``images`` attribute of the dataset stores\n",
"8x8 arrays of grayscale values for each image. We will use these arrays to\n",
"visualize the first 4 images. The ``target`` attribute of the dataset stores\n",
"the digit each image represents and this is included in the title of the 4\n",
"plots below.\n",
"\n",
"Note: if we were working from image files (e.g., 'png' files), we would load\n",
"them using `matplotlib.pyplot.imread`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAEHdJREFUeJzt3X+QVWX9B/AH3dQsk9VSLBJYdSa1yRVwJioVpsXfymqB/eHoOjZQ0iBlM8v0C7QkSJtRSxH9x9+O4g9Ix1KZXB2nf2JzLfsxg7nalDqiLFSaQHW/85z5LgPs6l3gfLrce1+vmR245+797LmX+2HP+z7Pec6oSqVSSQAAACXbq+yCAAAAmbABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAghLCxC7q6utL48eN36bGLFi1Ko0aNKn2fYE+jT6A6fQLV6ZP61lBhI7+ZRvLV09NT613d4/zqV79Kn/vc59L++++fxowZk+bNm5f++c9/1nq3CKBPds3jjz+eLrnkkvTJT34y7b333rv8i4/6oE923ttvv51uuOGGdMopp6TDDjssHXDAAen4449Py5YtS//5z39qvXsE0Ce7ZvHixenTn/50+shHPpL222+/dNRRR6X58+endevWpUY0qlKpVFKDuPPOO7e7ffvtt6cnnngi3XHHHdttnz59ejr00EN3+eds2bIl/fe//0377rvvTj/23//+d/GV31x7ir6+vjRlypR09NFHp9mzZ6e//vWv6ZprrknTpk1LP//5z2u9e5RMn+z6J2v33ntvmjhxYvrLX/5SBI6XXnqp1rtFEH2y855//vn0qU99Kn3+858vAseHPvSh9Nhjj6WHHnooXXjhhem2226r9S5SMn2ya77whS8UQeMTn/hEEcr/+Mc/pltuuSUdcsghxTHZBz7wgdRQKg1s7ty5OUhV/b633nqr0sxOP/30ymGHHVbZuHHj1m233HJL8do99thjNd034umTkfnb3/5W2bx5c/H3M888szJu3Lha7xL/Q/qkunXr1lWef/75Idsvvvji4rVbu3ZtTfaL/x19suvuv//+4rW75557Ko2moaZRjcTUqVOLaRC9vb3ppJNOKqYNfetb3yruW7VqVTrzzDPTRz/60SI9H3HEEen73//+kOHfHecO5k838zBhHg24+eabi8flx59wwgnp17/+ddW5g/n21772tbRy5cpi3/Jjjz322PSLX/xiyP7nocjJkycXCT3/nOXLlw9b84033kh/+tOfimHt9/L3v/+9+BTiggsuKD6FGpQ/hfrgBz+Y7rvvvhG9rjQWfTJUfr7ve9/7RvgK0gz0yfY+/OEPFz9rR+eee27xZ/70luajT0Zm8Plt2LAhNZqW1ITefPPNdPrpp6cvfelLxUH24NDerbfeWhxgf+Mb3yj+/OUvf5m+973vFQfkV199ddW6d999d/rHP/6R5syZU7wJf/SjH6Xzzjsvvfjii1UPUp555pn04IMPpksvvbQYUrv++uuLYbY8XePggw8uvufZZ59Np512WjEX9oorriia8corryyG4nb005/+tPieJ598smj0d/O73/2uGF7MjbStffbZJ7W3txc/k+akT6A6fVLda6+9tjWM0Jz0yVD5LIb8uuRjsLVr16YFCxYU03Mb8ndRpcmG804++eRi20033TTk+99+++0h2+bMmVPZf//9K++8887WbRdddNF2Uyj6+/uLmgcffHBl/fr1W7evWrWq2P7www9v3bZw4cIh+5Rv77PPPpUXXnhh67bnnnuu2P6Tn/xk67azzz672Jc8nWNQHpZuaWkZUnPw5zz55JPv+RqtWLGi+L6nn356yH0zZ86sjBkz5j0fT/3TJ9X7ZEemUTUffbLzfZJt2rSpcswxx1QmTJhQ2bJly04/nvqiT0beJ6+++mrx/YNfY8eOrdx7772VRtR006iyPFx28cUXD9n+/ve/f+
vfc1LOQ2InnnhiMSSWh8aqOf/881Nra+vW2/mxWU7Y1XR0dBTDc4PySXZ5WtPgY3OaXr16ders7CyGGwcdeeSRxacFO8pDfLmfqiXkf/3rX8Wfw510lYcMB++n+egTqE6fvLc8VeUPf/hD8alvS0tTTqZAnwzroIMOKqaxP/zww8VoSR75a9RVQJuy8z/2sY8V04R29Pvf/z595zvfKYbx8hDetjZu3Fi17uGHH77d7cEGGBgY2OnHDj5+8LGvv/56ceCf3+Q7Gm7bSA02+qZNm4bc984772z3HwHNRZ9Adfrk3eVpMHmFnTwH/4wzziitLvVHnwyVX48ceLKzzjqrWMXts5/9bLEiVb7dSJoybAx3AJ1PyDn55JOLVJsTZk67+ZP93/zmN6m7u7tYcq2aPNduOCNZXXh3Hrs78jzE7NVXXx1yX962bZqnuegTqE6fDC/Pxc/P9Stf+UpxMElz0yfVfeYznymOye666y5ho1Hl1QbyiTr5ZKG8WsKg/v7+tCfISTc34QsvvDDkvuG2jVRehSEPba9ZsybNmjVr6/bNmzcXaz1vuw2atU9gZzR7n+QVhr785S8XJ+rmi/zBcJq9T4aTZ5SMZESn3jTlORvvlXC3TbT5gPvGG29Me8r+5eG2vEzbK6+8st0bfrgL7410CbYDDzywqJsvzJPnSw7KF+TJcwdnzpxZ8jOhnjVrn8DOaOY+efrpp4sVh/LBY/6Edq+9HGYwvGbtk7feemvY73nggQeKKVw7rg7aCIxsbDN8lefqXXTRRWnevHnFEmr5gHtPmp6RTz56/PHHizl9X/3qV4uTl/JJd3l0Io9C7OoSbFdddVXx/PNw5uAVxH/84x8XV4DNS77BoGbuk9/+9rfpZz/72dZfNvnTpx/84AfF7eOOOy6dffbZgc+KetKsffLyyy+nc845p3i+X/ziF9OKFSu2uz+fgJu/oJn7ZO3atUWIySe35yuI50CeZ5fkD33ztTYuu+yy1GiEjf+X11R+5JFH0uWXX17ML80NkNeCzifsnHrqqWlPMGnSpCJNf/Ob30zf/e5308c//vFinmO+UNJIVm14NxMnTixWXMhzJL/+9a8X601fcskl6Yc//GGp+0/9a+Y+yfOIc71tDd7OvyyFDZq9T/L0l8EpIHPnzh1y/8KFC4UNUrP3ydixY4vreeST4m+77ba0ZcuWNG7cuGLltm9/+9tbr/HRSEbl9W9rvRPsnrwsW17RIadlYHj6BKrTJ1CdPtk5JlPWmR2ve5Hf6I8++qjrBMA29AlUp0+gOn2y+4xs1Jm8LFpXV1dqa2sr5scuW7asuEbGs88+m4466qha7x7sEfQJVKdPoDp9svucs1Fn8gnb99xzT3rttdeKK3JOmTIlLV682BsetqFPoDp9AtXpk91nZAMAAAjhnA0AACCEsAEAAIQQNgAAgBANd4L4jlcsLUO+2F3Zpk+fniIsWbKk9Jr5QjtQTcQygBs2bEhRV4WNWHcdqunp6amb9157e3tdPH9qb+nSpaXXXLBgQek1J0yYkCL09vaWXrO1gY69jGwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIRoSQ2mu7u79Jr9/f2l1xwYGEgRDjrooNJr3nfffaXXnDlzZuk1qa3Ro0eXXvOpp55KEXp6ekqv2dnZWXpNaquvr6/0mtOmTSu95oEHHpgivPTSSyF1qa0FCxbUxXHC8uXLS685Z86cFKG3t7f0mh0dHalRGNkAAABCCBsAAEAIYQMAAAghbAAAACGEDQAAIISwAQAAhBA2AACAEMIGAAAQQtgAAABCCBsAAEAIYQMAAAghbAAAACGEDQAAIISwAQAAhBA2AACAEMIGAAAQQtgAAABCCBsAAEAIYQMAAAghbAAAAC
FaUg319vaWXrO/v7/0mn/+859Lr9nW1pYiTJ8+vS7+nWbOnFl6TUaur6+v9Jo9PT2pXrS3t9d6F6gDK1euLL3mcccdV3rNzs7OFOGKK64IqUttzZ49u/Sa3d3dpdecNGlS6TUnTJiQInR0dITUbRRGNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABAiJZUQwMDA6XXnDhxYuk129raUr2YNGlSrXeBkl177bWl11y0aFHpNTdu3JjqxdSpU2u9C9SB+fPnl15z/PjxdbGf2YwZM0LqUlsRxzQvvvhi6TX7+/tLr9nR0ZHq5Xi2tbU1NQojGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAgREuqoYGBgdJrTp8+PTWziNe0tbW19JqM3Pz580uv2dXV1dTvkw0bNtR6F6iDf9Nrr7229JorV65M9eLWW2+t9S5QJ9ra2kqvuX79+tJrdnR0lF4zqu7q1asb5ve0kQ0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQggbAABACGEDAAAIIWwAAAAhhA0AACCEsAEAAIQQNgAAgBDCBgAAEELYAAAAQrSkGmptbS29Zm9vb6oHAwMDIXXXrFlTes1Zs2aVXhNqqa+vr/Sa7e3tpddk5BYtWlR6zeuuuy7Vg4ceeiik7ujRo0PqQq2OEVevXp0izJkzp/SaS5cuLb3mkiVLUi0Y2QAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIVpSDbW1tZVec82aNaXXXLFiRV3UjNLd3V3rXQB4T11dXaXX7OnpKb3mc889V3rNc889N0WYMWNGXfw7dXZ2ll6TnbNgwYLSa3Z0dJRec2BgIEV44oknSq85a9as1CiMbAAAACGEDQAAIISwAQAAhBA2AACAEMIGAAAQQtgAAABCCBsAAEAIYQMAAAghbAAAACGEDQAAIISwAQAAhBA2AACAEMIGAAAQQtgAAABCCBsAAEAIYQMAAAghbAAAACGEDQAAIISwAQAAhBA2AACAEC2phtra2kqvuXTp0tJrdnd3l15z8uTJKUJvb29IXRrL6NGjS685Y8aM0muuWrUqRejp6Sm9ZldXV+k1Gbn29vbSa/b19dVFzUWLFqUIEf03fvz40mt2dnaWXpOd09raWnrN2bNnp3oxa9as0msuX748NQojGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAghLABAACEEDYAAIAQwgYAABBC2AAAAEIIGwAAQAhhAwAACCFsAAAAIYQNAAAgxKhKpVKJKQ0AADQzIxsAAEAIYQMAAAghbAAAACGEDQAAIISwAQAAhBA2AACAEMIGAAAQQtgAAABCCBsAAECK8H9LBMQg22J/wgAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 1000x300 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"digits = datasets.load_digits()\n",
"\n",
"# Name the figure `fig` rather than `_`: assigning `_` makes IPython stop\n",
"# updating its `_` / `__` / `___` output-history variables.\n",
"fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
"for ax, image, label in zip(axes, digits.images, digits.target):\n",
"    ax.set_axis_off()\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
"    ax.set_title(f\"Training: {label}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification\n",
"\n",
"To apply a classifier on this data, we need to flatten the images, turning\n",
"each 2-D array of grayscale values from shape ``(8, 8)`` into shape\n",
"``(64,)``. Subsequently, the entire dataset will be of shape\n",
"``(n_samples, n_features)``, where ``n_samples`` is the number of images and\n",
"``n_features`` is the total number of pixels in each image.\n",
"\n",
"We can then split the data into train and test subsets and fit a support\n",
"vector classifier on the train samples. The fitted classifier can\n",
"subsequently be used to predict the value of the digit for the samples\n",
"in the test subset.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Flatten each 8x8 image into a single row of 64 pixel features\n",
"n_samples = len(digits.images)\n",
"data = digits.images.reshape(n_samples, -1)\n",
"\n",
"# Use the leading 50% of samples for training and the trailing 50% for\n",
"# testing; shuffle=False keeps the original sample order\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    data, digits.target, test_size=0.5, shuffle=False\n",
")\n",
"\n",
"# Fit a support vector classifier on the training subset, then predict\n",
"# the digit for every held-out sample\n",
"clf = svm.SVC(gamma=0.001)\n",
"clf.fit(X_train, y_train)\n",
"predicted = clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below we visualize the first 4 test samples and show their predicted\n",
"digit value in the title.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAADSCAYAAAAi0d0oAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAEepJREFUeJzt3QlsZHUdB/D/wopyuVVAxWuLZ7xgFEUTlS0CEkXjaLzwyLbGeKPVKKImbhXjHbMGj0iMWzUYPAjdeKESdxvwDtBGvKK4XSEoYEJX48nxzO+Zqe3S3bbr+3V2Zj+fpJRO3/zmzez7dd73/d/7z5qqqqoCAADQsIOaLggAABCEDQAAIIWwAQAApBA2AACAFMIGAACQQtgAAABSCBsAAEAKYQMAAEghbAAAACmEjRUaHBwsw8PDcz9v3769rFmzpv7elKg3NjbWWD1YbfoElkevwNL0SW/rqbAxPj5ebwydr7vd7W7lYQ97WHnDG95QbrzxxtJLvvWtb/XURv2Vr3ylPOlJTyoDAwPlqKOOKhs2bCjf/OY3u71aLEKfdI8+6S16pftuvfXW8shHPrJ+/T/60Y92e3VYhD7pnk984hPlEY94RLnrXe9a7ne/+5W3vOUt5W9/+1vpNWtLD3rve99bjjvuuPLPf/6zXHHFFeXTn/50vQFdc8015bDDDlvVdTn55JPLP/7xj3LIIYes6H6xvp/85CcX3eij3tq1+88/zfnnn1/e+MY3ljPPPLN88IMfrF/3+OPzrGc9q1x88cXlec97XrdXkUXok9WlT3qXXulu3/zhD3/o9mqwDPpkdb397W8vH/7wh8vzn//88qY3van88pe/rPvlF7/4RfnOd75Tesn+86quwDOe8Yzy+Mc/vv7/V77ylfURxI997GNl69at5ayzzlr0PpEEDz/88MbX5aCDDqpTfpOarvf/io37CU94Qvn6179eH9UIr3jFK+qU/fnPf95O1H5Kn6wufdK79Ep33HTTTfUObOxUvfvd7+726rAEfbJ6/vjHP9av7ctf/vLyhS98Ye72GFE6++yz6/eZZz/72aVX9NRpVHvytKc9rf6+Y8eO+nuc13fEEUeUa6+9tjzzmc8sRx55ZHnpS19a/+6OO+4omzdvLo961KPqDeve9753efWrX11uueWWBTWrqirve9/7yv3vf/86sZ9yyil1mtzdns4b/MlPflI/9j3ucY+60Y4//vjy8Y9/fG79IlmH+UOTeztv8Oqrr64b/e53v3v93E499dTy4x//eNGhzh/84Af1UNsxxxxTP/Zzn/vccvPNNy9YdteuXeXXv/51/X0pf/nLX8q97nWvBevYWY9DDz10yfuzf9An/6VPWIpeye2VjnPPPbc8/OEPLy972cuWfR/2H/okr09+9KMfldtuu628+MUvXnB75+eLLrqo9JKeHNnYXWzYIVJ2R/wjnXHGGeUpT3lKfR5oZ4gvNu7YMEZGRupTHqJJ4py42KBiQ7nLXe5SLxdHWWKDj402vq666qry9Kc/vfz73/9ecn2+973v1adOHHvssfXQ133uc5/yq1/9qnzjG9+of451uOGGG+rlvvjFLy5ZLxrtqU99ar2xn3POOfU6fuYznylDQ0NlcnKyPPGJT1ywfKTeaLRNmzaVmZmZusHj3Movf/nLc8tccskl9WuwZcuWBRddLSYe52tf+1p95DaSdAyhxv9Hs8TzoTfoE33C8uiV3F4JP/3pT+sRvzgdZ/4OH71Dn+T1yb/+9a/6++4Hqjqv55VXXll6StVDtmzZUsUqX3bZZdXNN99cXXfdddVFF11UHXXUUdWhhx5aXX/99fVyGzdurJc799xzF9z/8ssvr2+/8MILF9x+6aWXLrj9pptuqg455JDqzDPPrO6444655d75znfWy0X9jm3bttW3xfdw2223Vccdd1y1fv366pZbblnwOPNrvf71r6/vt5i4fdOmTXM/t9vten2uvfbaudtuuO
GG6sgjj6xOPvnkO70+p5122oLHevOb31wdfPDB1ezs7J2Wje9LufHGG6tTTz21Xr7zdfTRR1c//OEPl7wvq0+f6BOWR690p1ei1kknnVSdddZZ9c87duyo7/uRj3xkyfuy+vTJ6vfJlVdeWS933nnnLfqaHXHEEVUv6cnTqE477bR6mOoBD3hAPaQUQ1uRFuPc6Ple+9rXLvj5q1/9alm3bl05/fTTy5///Oe5rxNPPLGusW3btnq5yy67rE7RkVLnH3EZHR1dct0ipUdij2VjRpr59uXoze23316++93vlna7XR70oAfN3R7J/SUveUl9VChO35jvVa961YLHimQedXbu3Dl3WyTq6K3lHIGKJB1D3Rs3bqxfw8997nP148c56L/73e9W/JxYHfpEn7A8emV1eyWOcP/85z8vH/rQh1a8/nSPPlm9Pnnc4x5Xj5xEj8QoSIyUfPvb365HZ2KEJS5m7yU9eRpVnHMXF8nErAFx3l+8wcfFQvPF7+Kcv/l++9vf1qc0xHnVe7pYLXQ2jIc+9KELfh9NFkNkyxlWfPSjH12aEOf7/f3vf6+f4+5iOrQ4D/K6666rz4PseOADH7hguc46735u5HK94AUvqF/PuCCp4znPeU79+rzrXe9aMETI/kOf/Jc+YSl6ZfV6JXbQ3vGOd5S3ve1t9U4rvUOfrO57ysUXX1xe9KIX1RONhIMPPri+JiRO4frNb35TeklPho2TTjppbkaEPYk5iXdvgtg4YmO/8MILF71PbND9IDbIxfx3lHBlfv/735dLL720XHDBBQtuv+c971mfkxnnWrJ/0id7p0/o0Cur1ytxHn8cvY6dqDhaG66//vq5nbK47b73ve+KpzQlnz5ZvT4JMWIUIygR1v70pz/VISyuQ4n+iNDXS3oybOyrBz/4wfUw3ZOf/OS9zg6zfv36+nv8A88fPouku1RCjccIMe90DDnuyXKH9aIJ4/SMxVJszGgQTZ15dKjzgT0xFLjYhzHFxWD0F32ycvrkwKRXVi4+UyOe8/wjwh3vf//76684JabVaqWtA6tLn/x/ImR0RnviszZiWtzlnK64P+nJazb21Qtf+MJ6Z+C888670+9iZ2B2drb+/9hQ45y4mElmfiKNmQWWEufZxYfexLKdeh3za3Xmnd59mcWScszEEPNYd44CdXZuvvSlL9VHTWOmhJVa7vRrD3nIQ+qmilNA5q9/HIm6/PLLy2Mf+9gVPzb7N33yP/qEvdErK++VmIkozvOf/xUz/ITYgYqf4/nSP/TJ/zdF9PwRopgVK0LQa17zmtJLDqiRjQ0bNtQX13zgAx8oU1NT9YYUG3ak6LiAKeZijk9qjET71re+tV4uplGL6dfiSEtcnHP00Ufv9TFihyM+VTOmvowjMzHFWVxQFBvX/E99jAujOn94Y5q42LB3n0+5I6aBi6naYuN+3eteV58TGX+cY2q0+HTJfbHc6dfitYjzBT/72c/W80vHxa5//etfy6c+9an6AqU495b+ok/+R5+wN3pl5b0SO4XxNV9nZy5GO+KCXPqLPin7NEV0TNcbU6jH84kR8gg5nSmjd78+ZL9X9ZDOlGE/+9nP9rpcTI92+OGH7/H3F1xwQXXiiSfWU7bFFGaPecxjqnPOOaee0qzj9ttvr97znvdUxx57bL3c0NBQdc0119TTqu1t+rWOK664ojr99NPr+rEuxx9/fHX++efP/T6maTv77LOrY445plqzZs2Cqdh2n34tXHXVVdUZZ5xRT3d22GGHVaeccsqdptTc0+uz2DquZJrCW2+9tV73VqtVP358xeN///vfX/K+rD59ok9YHr3SnV7Znalv92/6pDt9smXLluqEE06on0c8n5havVffT9bEf7odeAAAgP5zQF2zAQAArB5hAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJCi7z5BfKmPoN8XS33K476IT9Hslee/ffv2xmvGJ2LSPePj443XHBsba7
zmzp07S4b4FNem+eRjuvX3NGvb27x5c0+8n9J9GfseGe8pGe99YWhoqCeef6tL+15GNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUqwtXTQ7O9t4zaGhocZrTk9PN15zw4YNJcPk5GTjNScmJhqv2Wq1Gq/Zr2ZmZhqvOTIyUg5kGa8pLMfo6GjjNQcHB0uGdrudUpf+k7GtZOwnZP3tHx4ebrzm1NRU3+x7GdkAAABSCBsAAEAKYQMAAEghbAAAACmEDQAAIIWwAQAApBA2AACAFMIGAACQQtgAAABSCBsAAEAKYQMAAEghbAAAACmEDQAAIIWwAQAApBA2AACAFMIGAACQQtgAAABSCBsAAEAKYQMAAEghbAAAACnWli7avHlz4zWnp6cbr7lt27bGa87MzJQMk5OTjddstVqN16S71q1b13jNXbt29cR6hna7nVKX/tIr71E7duwoGQYGBlLq0n9mZ2cbrzk4ONh4zYmJiZJh69atjdds9dG+l5ENAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBibemiVqvVeM1169Y1XnPz5s2N15yZmSkZ1q9f33jNdrvdeE2Wb3BwsCe26ZGRkdIrJiYmGq85OjraeE2Wb/v27Y3XHBsba7zmpk2beuJvRFafeD/pTxnvKePj4z2z75Wx7zk0NFT6hZENAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBiTVVVVekjMzMzjdccHh5uvObk5GTJcMIJJzRec2pqqvGadNfg4GDjNYeGhnqiZhgZGWm85tVXX914zVar1XjNftVut3vib19GzYmJidIrfXLJJZf0xL89dPu9ajhh3zOj5nIY2QAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAgxdrSZwYHBxuvOTs7W3rF9PR04zXHx8cbrzk8PNx4zX6Vsf3t3Lmz8Zqjo6ON12y1WiXDyMhI4zW3b9/eM8+/H7fprVu3Nl5z/fr1jddst9uN15ycnCwH8ns0KzM2NtZ4zYGBgZ54T8kyNTXVE69ptxjZAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAp1uaU7S/T09PlQDY7O9vtVTigDQwMNF5z48aNjdccGxsrvWLdunWN1xwaGmq8Zr/qlW16Zmam8ZqDg4ON15ycnCwZMl7TVqvVeE1WZnR0tPGa7Xa78ZpTU1ON1xweHi4Zdu3a1RN/K7rFyAYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASCFsAAAAKYQNAAAghbABAACkEDYAAIAUwgYAAJBC2AAAAFIIGwAAQAphAwAASLGmqqoqp3T/aLfbjdecmZkpGQYGBhqvOTEx0RPryfJNTU31RJ/s3LmzZNiyZUvjNYeHhx
uvSf8ZHx9vvObIyEjJsGPHjsZrDg4ONl6T/tRqtRqvOT09XTJs2rSp8ZpjY2OlXxjZAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAp1lRVVeWUBgAADmRGNgAAgBTCBgAAkELYAAAAUggbAABACmEDAABIIWwAAAAphA0AACCFsAEAAKQQNgAAgJLhP/LI/xgQyRqSAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 1000x300 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Name the figure `fig` rather than `_`: assigning `_` makes IPython stop\n",
"# updating its `_` / `__` / `___` output-history variables.\n",
"fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
"for ax, image, prediction in zip(axes, X_test, predicted):\n",
"    ax.set_axis_off()\n",
"    # Test rows are flattened 64-pixel vectors; restore the 8x8 grid for display\n",
"    image = image.reshape(8, 8)\n",
"    ax.imshow(image, cmap=plt.cm.gray_r, interpolation=\"nearest\")\n",
"    ax.set_title(f\"Prediction: {prediction}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`sklearn.metrics.classification_report` builds a text report showing\n",
"the main classification metrics.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report for classifier SVC(gamma=0.001):\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.99 0.99 88\n",
" 1 0.99 0.97 0.98 91\n",
" 2 0.99 0.99 0.99 86\n",
" 3 0.98 0.87 0.92 91\n",
" 4 0.99 0.96 0.97 92\n",
" 5 0.95 0.97 0.96 91\n",
" 6 0.99 0.99 0.99 91\n",
" 7 0.96 0.99 0.97 89\n",
" 8 0.94 1.00 0.97 88\n",
" 9 0.93 0.98 0.95 92\n",
"\n",
" accuracy 0.97 899\n",
" macro avg 0.97 0.97 0.97 899\n",
"weighted avg 0.97 0.97 0.97 899\n",
"\n",
"\n"
]
}
],
"source": [
"# Per-class precision / recall / F1 on the held-out subset\n",
"report = metrics.classification_report(y_test, predicted)\n",
"print(f\"Classification report for classifier {clf}:\\n{report}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also plot a confusion matrix of the\n",
"true digit values and the predicted digit values.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Confusion matrix:\n",
"[[87 0 0 0 1 0 0 0 0 0]\n",
" [ 0 88 1 0 0 0 0 0 1 1]\n",
" [ 0 0 85 1 0 0 0 0 0 0]\n",
" [ 0 0 0 79 0 3 0 4 5 0]\n",
" [ 0 0 0 0 88 0 0 0 0 4]\n",
" [ 0 0 0 0 0 88 1 0 0 2]\n",
" [ 0 1 0 0 0 0 90 0 0 0]\n",
" [ 0 0 0 0 0 1 0 88 0 0]\n",
" [ 0 0 0 0 0 0 0 0 88 0]\n",
" [ 0 0 0 1 0 1 0 0 0 90]]\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot true vs. predicted digits; `disp` is reused by the next cell\n",
"disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)\n",
"cm_figure = disp.figure_\n",
"cm_figure.suptitle(\"Confusion Matrix\")\n",
"\n",
"# Also emit the raw counts as text, then render the figure\n",
"print(f\"Confusion matrix:\\n{disp.confusion_matrix}\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the results from evaluating a classifier are stored in the form of a\n",
"confusion matrix and not in terms of `y_true` and\n",
"`y_pred`, one can still build a `sklearn.metrics.classification_report`\n",
"as follows:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification report rebuilt from confusion matrix:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.99 0.99 88\n",
" 1 0.99 0.97 0.98 91\n",
" 2 0.99 0.99 0.99 86\n",
" 3 0.98 0.87 0.92 91\n",
" 4 0.99 0.96 0.97 92\n",
" 5 0.95 0.97 0.96 91\n",
" 6 0.99 0.99 0.99 91\n",
" 7 0.96 0.99 0.97 89\n",
" 8 0.94 1.00 0.97 88\n",
" 9 0.93 0.98 0.95 92\n",
"\n",
" accuracy 0.97 899\n",
" macro avg 0.97 0.97 0.97 899\n",
"weighted avg 0.97 0.97 0.97 899\n",
"\n",
"\n"
]
}
],
"source": [
"# Expand the confusion matrix back into flat label sequences: entry\n",
"# cm[gt][pred] contributes that many (gt, pred) pairs to y_true / y_pred\n",
"cm = disp.confusion_matrix\n",
"y_true = []\n",
"y_pred = []\n",
"for gt, row in enumerate(cm):\n",
"    for pred, count in enumerate(row):\n",
"        y_true += [gt] * count\n",
"        y_pred += [pred] * count\n",
"\n",
"print(\n",
"    \"Classification report rebuilt from confusion matrix:\\n\"\n",
"    f\"{metrics.classification_report(y_true, y_pred)}\\n\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}