main/04_spectral_clustering/image_segmentation.py

43 lines
1.6 KiB
Python

import time
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from skimage.data import coins
from skimage.transform import rescale
from sklearn.cluster import spectral_clustering
from sklearn.feature_extraction import image
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logging.info("Загрузка и препроцессинг изображения...")
orig = coins()
smoothed = gaussian_filter(orig, sigma=2)
rescaled = rescale(smoothed, 0.2, mode="reflect", anti_aliasing=False)
logging.info("Преобразование в граф...")
graph = image.img_to_graph(rescaled)
beta, eps = 10, 1e-6
graph.data = np.exp(-beta * graph.data / graph.data.std()) + eps
n_regions, n_plus = 26, 3
logging.info(f"Запуск спектральной кластеризации ({n_regions} регионов)...")
results = {}
for method in ("kmeans", "discretize", "cluster_qr"):
logging.info(f"Метод: {method}")
t0 = time.time()
labels = spectral_clustering(graph, n_clusters=n_regions+n_plus,
eigen_tol=1e-7, assign_labels=method, random_state=42)
labels = labels.reshape(rescaled.shape)
results[method] = (labels, time.time()-t0)
logging.info("Визуализация результатов...")
for method, (labels, dur) in results.items():
plt.figure(figsize=(6,6))
plt.imshow(rescaled, cmap="gray")
plt.title(f"{method}, {dur:.2f} сек.")
plt.axis("off")
for l in range(n_regions):
plt.contour(labels == l, colors=[plt.cm.nipy_spectral((l+4)/float(n_regions+4))])
plt.show()