готово

This commit is contained in:
Максим Максимов 2025-05-17 09:30:25 +03:00
parent 1c9c0f69ea
commit acbbda37d1
2 changed files with 318 additions and 60 deletions

File diff suppressed because one or more lines are too long

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"id": "93b8c292-c33b-40e8-ac81-830e67dc8f9d",
"metadata": {},
"outputs": [
@ -12,13 +12,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 10\n",
" 1 1.00 1.00 1.00 10\n",
" 2 1.00 1.00 1.00 10\n",
" 0 1.00 1.00 1.00 14\n",
" 1 1.00 0.88 0.93 8\n",
" 2 0.89 1.00 0.94 8\n",
"\n",
" accuracy 1.00 30\n",
" macro avg 1.00 1.00 1.00 30\n",
"weighted avg 1.00 1.00 1.00 30\n",
" accuracy 0.97 30\n",
" macro avg 0.96 0.96 0.96 30\n",
"weighted avg 0.97 0.97 0.97 30\n",
"\n"
]
}
@ -34,16 +34,35 @@
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n",
"\n",
"# Модель MLP — многослойный перцептрон\n",
"clf = MLPClassifier(hidden_layer_sizes=(10,), activation='relu', max_iter=2500)\n",
"clf = MLPClassifier(hidden_layer_sizes=(10,), activation='relu', max_iter=2000)\n",
"clf.fit(X_train, y_train)\n",
"\n",
"# Отчёт о точности\n",
"print(classification_report(y_test, clf.predict(X_test)))"
]
},
{
"cell_type": "markdown",
"id": "84090ff4-7812-4562-b490-b183f4bceebf",
"metadata": {},
"source": [
"Классификация текстовых документов с помощью SVM\n",
"\n",
"\n",
"Цель\n",
"Продемонстрировать применение метода опорных векторов (SVM) для классификации текстовых данных. \n",
"\n",
"План \n",
"1. Загрузка датасета `20newsgroups` (встроенный в scikit-learn) \n",
"2. Векторизация текста с помощью `TfidfVectorizer` \n",
"3. Обучение модели SVM \n",
"4. Оценка точности и визуализация результатов \n",
"5. Повторение эксперимента на внешнем датасете (спам-письма) "
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 1,
"id": "a3b4078a-2b82-460f-80d8-56472b379641",
"metadata": {},
"outputs": [
@ -110,20 +129,6 @@
}
],
"source": [
"# %% [markdown]\n",
"# # Классификация текстовых документов с помощью SVM\n",
"\n",
"# %% [markdown]\n",
"# **Цель**: \n",
"# Продемонстрировать применение метода опорных векторов (SVM) для классификации текстовых данных. \n",
"\n",
"# **План**: \n",
"# 1. Загрузка датасета `20newsgroups` (встроенный в scikit-learn) \n",
"# 2. Векторизация текста с помощью `TfidfVectorizer` \n",
"# 3. Обучение модели SVM \n",
"# 4. Оценка точности и визуализация результатов \n",
"# 5. Повторение эксперимента на внешнем датасете (спам-письма) \n",
"\n",
"# %%\n",
"# Импорт необходимых библиотек\n",
"import numpy as np\n",
@ -137,10 +142,9 @@
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# %% [markdown]\n",
"# ## Часть 1: Работа со встроенным датасетом (20newsgroups)\n",
"\n",
"# %%\n",
"# Часть 1: Работа со встроенным датасетом (20newsgroups)\n",
"\n",
"# Загрузка данных (2 категории для упрощения)\n",
"categories = ['sci.space', 'rec.sport.baseball']\n",
"newsgroups = fetch_20newsgroups(subset='all', \n",
@ -148,7 +152,7 @@
" shuffle=True, \n",
" random_state=42)\n",
"\n",
"# %%\n",
"\n",
"# Векторизация текста\n",
"vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)\n",
"X = vectorizer.fit_transform(newsgroups.data)\n",
@ -159,7 +163,7 @@
" test_size=0.2, \n",
" random_state=42)\n",
"\n",
"# %%\n",
"\n",
"# Создание и обучение модели SVM\n",
"svm = SVC(kernel='linear', C=1.0, random_state=42)\n",
"svm.fit(X_train, y_train)\n",
@ -168,7 +172,6 @@
"y_pred = svm.predict(X_test)\n",
"print(classification_report(y_test, y_pred))\n",
"\n",
"# %%\n",
"# Визуализация матрицы ошибок\n",
"ConfusionMatrixDisplay.from_estimator(svm, X_test, y_test, \n",
" display_labels=categories)\n",
@ -178,10 +181,6 @@
"# %% [markdown]\n",
"# ## Часть 2: Работа с внешним датасетом (спам-письма)\n",
"\n",
"# %%\n",
"# Загрузка данных из CSV\n",
"# Пример датасета: https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset\n",
"# Перед выполнением загрузите файл spam.csv в ту же директорию\n",
"try:\n",
" data = pd.read_csv('spam.csv', encoding='latin-1')[['v1', 'v2']]\n",
" data.columns = ['label', 'text']\n",
@ -231,35 +230,25 @@
" grid.fit(X_train, y_train)\n",
" \n",
" print(\"\\nЛучшие параметры:\", grid.best_params_)\n",
" print(\"Лучшая точность:\", grid.best_score_)\n",
"\n",
"# %% [markdown]\n",
"# ## Выводы\n",
"\n",
"# %% [markdown]\n",
"# **Результаты**:\n",
"# 1. На встроенном датасете (20newsgroups):\n",
"# - Точность: 98%\n",
"# - SVM успешно разделяет категории \"космос\" и \"бейсбол\"\n",
"#\n",
"# 2. На внешнем датасете (спам-письма):\n",
"# - Точность: 97-99%\n",
"# - Основные ошибки: некоторые спам-письма с коротким текстом классифицируются как \"ham\"\n",
"#\n",
"# **Рекомендации**:\n",
"# - Для улучшения результатов можно:\n",
"# - Добавить лемматизацию/стемминг\n",
"# - Использовать более сложные методы векторизации (Word2Vec, BERT)\n",
"# - Настроить гиперпараметры с помощью GridSearchCV"
" print(\"Лучшая точность:\", grid.best_score_)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01945ae9-37f5-4758-96ab-7ab90a3cc4fd",
"cell_type": "markdown",
"id": "f5d6ed2d-292d-4b14-91dc-d78f8f096c37",
"metadata": {},
"outputs": [],
"source": []
"source": [
"Выводы\n",
"\n",
"Результаты\n",
"На встроенном датасете (20newsgroups):\n",
" - Точность: 98%\n",
"- SVM успешно разделяет категории \"космос\" и \"бейсбол\"\n",
"\n",
"2. На внешнем датасете (спам-письма):\n",
" - Точность: 97-99%\n",
" - Основные ошибки: некоторые спам-письма с коротким текстом классифицируются как \"ham\""
]
}
],
"metadata": {