pyGCO
pyGCO copied to clipboard
Can't segment multisegment image using `gco.cut_general_graph()`
Problem
I'm trying to use pyGCO to segment color images with 3 or more ground truth segments using gco.cut_general_graph()
rather than gco.cut_grid_graph_simple()
so that I can later pass custom weights for the grid edges. For now, however, I am trying to segment the image in the simplest case where the unaries are appropriate segmentations with little noise and the grid edges have a weight of 0. In this case, if I'm not mistaken, the optimal partition should be the argmax of the unaries. Instead, gco.cut_general_graph()
is simply returning the first term of the unary.
I'm not sure if this is a edge case/bug or if I'm misunderstanding the expected behavior. Any ideas would be appreciated.
I've included my complete code below as well as its output.
Code:
import numpy as np
import gco
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def get_uniform_smoothness_pw_single_image(img_shape):
H, W = img_shape
E = (H - 1) * W + H * (W - 1)
edges = np.empty((E, 2), dtype=int)
edge_weights = np.ones(E, dtype=np.single)
idx = 0
# horizontal edges
for row in range(H):
edges[idx:idx+W-1,0] = np.arange(W-1) + row * W
edges[idx:idx+W-1,1] = np.arange(W-1) + row * W + 1
idx += W-1
# vertical edges
for col in range(W):
edges[idx:idx+H-1,0] = np.arange(0, (H-1)*W, W) + col
edges[idx:idx+H-1,1] = np.arange(W, H*W, W) + col
idx += H-1
return [edges, edge_weights]
def graph_segment(img, num_seg = 2, lamb_grid = 2, unary_noise = 0.1, unary_scale = 3):
pairwise_pot = (1 - np.eye(num_seg)) * lamb_grid
grid_edges, grid_edge_weights = get_uniform_smoothness_pw_single_image((img.shape[0], img.shape[1]))
# run kmeans on img to get unaries
kmeans = KMeans(n_clusters=num_seg, random_state=0).fit(img.reshape(-1, 3))
kmeans_labels = kmeans.labels_
# create one hot encoding
kmeans_labels_onehot = np.zeros((len(kmeans_labels), num_seg))
for seg_i in range(num_seg):
kmeans_labels_onehot[:, seg_i] = (kmeans_labels.astype(int) == seg_i).flatten()
# Use onehot as unary
unary_vec = kmeans_labels_onehot
# Add noise to unary
unary_vec += unary_noise * np.random.randn(*unary_vec.shape)
unary_vec *= unary_scale
# display unaries
fig, axarr = plt.subplots(1, num_seg, figsize=(2 * num_seg, 2))
for i in range(num_seg):
unary_i = unary_vec[:, i].reshape(*img.shape[:2])
axarr[i].imshow(unary_i)
axarr[i].set_title("Unary term #%i" % i)
plt.show()
labels = gco.cut_general_graph(grid_edges,
grid_edge_weights,
unary_vec,
pairwise_pot,
algorithm="swap")
return labels.reshape(*img.shape[:2])
image_path = "https://hips.hearstapps.com/hmg-prod/images/dog-puppy-on-garden-royalty-free-image-1586966191.jpg"
if "http" in image_path:
!wget -nc {image_path}
image_path = image_path.split("/")[-1]
img = plt.imread(image_path)
plt.imshow(img)
plt.title("Input Image")
plt.show()
seg = graph_segment(img, num_seg = 3, lamb_grid = 0, unary_noise = 0, unary_scale = 1)
plt.imshow(seg, interpolation="nearest")
plt.title("Partition")
plt.show()