готово
This commit is contained in:
parent
1c9c0f69ea
commit
acbbda37d1
File diff suppressed because one or more lines are too long
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 6,
|
||||
"id": "93b8c292-c33b-40e8-ac81-830e67dc8f9d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -12,13 +12,13 @@
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 1.00 1.00 1.00 10\n",
|
||||
" 1 1.00 1.00 1.00 10\n",
|
||||
" 2 1.00 1.00 1.00 10\n",
|
||||
" 0 1.00 1.00 1.00 14\n",
|
||||
" 1 1.00 0.88 0.93 8\n",
|
||||
" 2 0.89 1.00 0.94 8\n",
|
||||
"\n",
|
||||
" accuracy 1.00 30\n",
|
||||
" macro avg 1.00 1.00 1.00 30\n",
|
||||
"weighted avg 1.00 1.00 1.00 30\n",
|
||||
" accuracy 0.97 30\n",
|
||||
" macro avg 0.96 0.96 0.96 30\n",
|
||||
"weighted avg 0.97 0.97 0.97 30\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
@ -34,16 +34,35 @@
|
||||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
"\n",
|
||||
"# Модель MLP — многослойный перцептрон\n",
|
||||
"clf = MLPClassifier(hidden_layer_sizes=(10,), activation='relu', max_iter=2500)\n",
|
||||
"clf = MLPClassifier(hidden_layer_sizes=(10,), activation='relu', max_iter=2000)\n",
|
||||
"clf.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"# Отчёт о точности\n",
|
||||
"print(classification_report(y_test, clf.predict(X_test)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84090ff4-7812-4562-b490-b183f4bceebf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Классификация текстовых документов с помощью SVM\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Цель\n",
|
||||
"Продемонстрировать применение метода опорных векторов (SVM) для классификации текстовых данных. \n",
|
||||
"\n",
|
||||
"План \n",
|
||||
"1. Загрузка датасета `20newsgroups` (встроенный в scikit-learn) \n",
|
||||
"2. Векторизация текста с помощью `TfidfVectorizer` \n",
|
||||
"3. Обучение модели SVM \n",
|
||||
"4. Оценка точности и визуализация результатов \n",
|
||||
"5. Повторение эксперимента на внешнем датасете (спам-письма) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 1,
|
||||
"id": "a3b4078a-2b82-460f-80d8-56472b379641",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -110,20 +129,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# %% [markdown]\n",
|
||||
"# # Классификация текстовых документов с помощью SVM\n",
|
||||
"\n",
|
||||
"# %% [markdown]\n",
|
||||
"# **Цель**: \n",
|
||||
"# Продемонстрировать применение метода опорных векторов (SVM) для классификации текстовых данных. \n",
|
||||
"\n",
|
||||
"# **План**: \n",
|
||||
"# 1. Загрузка датасета `20newsgroups` (встроенный в scikit-learn) \n",
|
||||
"# 2. Векторизация текста с помощью `TfidfVectorizer` \n",
|
||||
"# 3. Обучение модели SVM \n",
|
||||
"# 4. Оценка точности и визуализация результатов \n",
|
||||
"# 5. Повторение эксперимента на внешнем датасете (спам-письма) \n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"# Импорт необходимых библиотек\n",
|
||||
"import numpy as np\n",
|
||||
@ -137,10 +142,9 @@
|
||||
"from sklearn.pipeline import Pipeline\n",
|
||||
"from sklearn.model_selection import GridSearchCV\n",
|
||||
"\n",
|
||||
"# %% [markdown]\n",
|
||||
"# ## Часть 1: Работа со встроенным датасетом (20newsgroups)\n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"# Часть 1: Работа со встроенным датасетом (20newsgroups)\n",
|
||||
"\n",
|
||||
"# Загрузка данных (2 категории для упрощения)\n",
|
||||
"categories = ['sci.space', 'rec.sport.baseball']\n",
|
||||
"newsgroups = fetch_20newsgroups(subset='all', \n",
|
||||
@ -148,7 +152,7 @@
|
||||
" shuffle=True, \n",
|
||||
" random_state=42)\n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"\n",
|
||||
"# Векторизация текста\n",
|
||||
"vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)\n",
|
||||
"X = vectorizer.fit_transform(newsgroups.data)\n",
|
||||
@ -159,7 +163,7 @@
|
||||
" test_size=0.2, \n",
|
||||
" random_state=42)\n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"\n",
|
||||
"# Создание и обучение модели SVM\n",
|
||||
"svm = SVC(kernel='linear', C=1.0, random_state=42)\n",
|
||||
"svm.fit(X_train, y_train)\n",
|
||||
@ -168,7 +172,6 @@
|
||||
"y_pred = svm.predict(X_test)\n",
|
||||
"print(classification_report(y_test, y_pred))\n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"# Визуализация матрицы ошибок\n",
|
||||
"ConfusionMatrixDisplay.from_estimator(svm, X_test, y_test, \n",
|
||||
" display_labels=categories)\n",
|
||||
@ -178,10 +181,6 @@
|
||||
"# %% [markdown]\n",
|
||||
"# ## Часть 2: Работа с внешним датасетом (спам-письма)\n",
|
||||
"\n",
|
||||
"# %%\n",
|
||||
"# Загрузка данных из CSV\n",
|
||||
"# Пример датасета: https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset\n",
|
||||
"# Перед выполнением загрузите файл spam.csv в ту же директорию\n",
|
||||
"try:\n",
|
||||
" data = pd.read_csv('spam.csv', encoding='latin-1')[['v1', 'v2']]\n",
|
||||
" data.columns = ['label', 'text']\n",
|
||||
@ -231,35 +230,25 @@
|
||||
" grid.fit(X_train, y_train)\n",
|
||||
" \n",
|
||||
" print(\"\\nЛучшие параметры:\", grid.best_params_)\n",
|
||||
" print(\"Лучшая точность:\", grid.best_score_)\n",
|
||||
"\n",
|
||||
"# %% [markdown]\n",
|
||||
"# ## Выводы\n",
|
||||
"\n",
|
||||
"# %% [markdown]\n",
|
||||
"# **Результаты**:\n",
|
||||
"# 1. На встроенном датасете (20newsgroups):\n",
|
||||
"# - Точность: 98%\n",
|
||||
"# - SVM успешно разделяет категории \"космос\" и \"бейсбол\"\n",
|
||||
"#\n",
|
||||
"# 2. На внешнем датасете (спам-письма):\n",
|
||||
"# - Точность: 97-99%\n",
|
||||
"# - Основные ошибки: некоторые спам-письма с коротким текстом классифицируются как \"ham\"\n",
|
||||
"#\n",
|
||||
"# **Рекомендации**:\n",
|
||||
"# - Для улучшения результатов можно:\n",
|
||||
"# - Добавить лемматизацию/стемминг\n",
|
||||
"# - Использовать более сложные методы векторизации (Word2Vec, BERT)\n",
|
||||
"# - Настроить гиперпараметры с помощью GridSearchCV"
|
||||
" print(\"Лучшая точность:\", grid.best_score_)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "01945ae9-37f5-4758-96ab-7ab90a3cc4fd",
|
||||
"cell_type": "markdown",
|
||||
"id": "f5d6ed2d-292d-4b14-91dc-d78f8f096c37",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"Выводы\n",
|
||||
"\n",
|
||||
"Результаты\n",
|
||||
"На встроенном датасете (20newsgroups):\n",
|
||||
" - Точность: 98%\n",
|
||||
"- SVM успешно разделяет категории \"космос\" и \"бейсбол\"\n",
|
||||
"\n",
|
||||
"2. На внешнем датасете (спам-письма):\n",
|
||||
" - Точность: 97-99%\n",
|
||||
" - Основные ошибки: некоторые спам-письма с коротким текстом классифицируются как \"ham\""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
Reference in New Issue
Block a user