Add spectral clustering image segmentation
This commit is contained in:
parent
dca62325b0
commit
6af81cbf92
67
04_spectral_clustering/image_segmentation.py
Normal file
67
04_spectral_clustering/image_segmentation.py
Normal file
@ -0,0 +1,67 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import logging
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from skimage.data import coins
|
||||
from skimage.transform import rescale
|
||||
from sklearn.cluster import spectral_clustering
|
||||
from sklearn.feature_extraction import image
|
||||
|
||||
# Настройка логирования
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s [%(levelname)s] %(message)s')
|
||||
|
||||
def main():
|
||||
logging.info("Загрузка и обработка изображения...")
|
||||
|
||||
# Загрузка изображения
|
||||
orig_coins = coins()
|
||||
smoothed = gaussian_filter(orig_coins, sigma=2)
|
||||
rescaled = rescale(smoothed, 0.2, mode="reflect", anti_aliasing=False)
|
||||
|
||||
logging.info(f"Изображение: {rescaled.shape}")
|
||||
|
||||
# Построение графа
|
||||
logging.info("Построение графа...")
|
||||
graph = image.img_to_graph(rescaled)
|
||||
beta = 10
|
||||
eps = 1e-6
|
||||
graph.data = np.exp(-beta * graph.data / graph.data.std()) + eps
|
||||
|
||||
# Кластеризация
|
||||
n_regions = 26
|
||||
methods = ('kmeans', 'discretize', 'cluster_qr')
|
||||
|
||||
for method in methods:
|
||||
logging.info(f"Метод: {method}")
|
||||
t0 = time.time()
|
||||
|
||||
labels = spectral_clustering(
|
||||
graph,
|
||||
n_clusters=n_regions + 3,
|
||||
eigen_tol=1e-7,
|
||||
assign_labels=method,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
duration = time.time() - t0
|
||||
labels = labels.reshape(rescaled.shape)
|
||||
|
||||
logging.info(f"Завершено за {duration:.2f} сек.")
|
||||
|
||||
# Визуализация
|
||||
plt.figure(figsize=(6, 6))
|
||||
plt.imshow(rescaled, cmap=plt.cm.gray)
|
||||
plt.title(f"{method}, {duration:.2f} сек.")
|
||||
plt.axis('off')
|
||||
|
||||
for l in range(n_regions):
|
||||
color = plt.cm.nipy_spectral((l + 4) / float(n_regions + 4))
|
||||
plt.contour(labels == l, colors=[color], linewidths=0.5)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user