pyGCO copied to clipboard
Can't segment multisegment image using `gco.cut_general_graph()`
I'm trying to use pyGCO to segment color images with 3 or more ground-truth segments using `gco.cut_general_graph()` rather than `gco.cut_grid_graph_simple()`, so that I can later pass custom weights for the grid edges. For now, however, I am trying to segment the image in the simplest case, where the unaries are appropriate segmentations with little noise and the grid edges have a weight of 0. In this case, if I'm not mistaken, the optimal partition should be the argmax of the unaries. Instead, `gco.cut_general_graph()` is simply returning the first term of the unary.
I'm not sure if this is an edge case/bug or if I'm misunderstanding the expected behavior. Any ideas would be appreciated.
I've included my complete code below as well as its output.
import numpy as np
import gco
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
def get_uniform_smoothness_pw_single_image(img_shape):
    """Build the 4-connected grid edges of an image, all with unit weight.

    Pixels are numbered in row-major order (``index = row * W + col``).
    Horizontal edges come first, grouped row by row; vertical edges follow,
    grouped column by column.

    Args:
        img_shape: ``(H, W)`` pair giving the image height and width in
            pixels.

    Returns:
        A two-element list ``[edges, edge_weights]``. ``edges`` is an integer
        array of shape ``(E, 2)`` holding the two pixel indices joined by
        each edge; ``edge_weights`` is a float32 array of ``E`` ones. For a
        non-empty image ``E = (H - 1) * W + H * (W - 1)``; an image with a
        zero-sized dimension yields ``E = 0``.

    Raises:
        ValueError: If either dimension is negative.
    """
    H, W = img_shape
    if H < 0 or W < 0:
        raise ValueError(
            f"image dimensions must be non-negative, got {(H, W)!r}"
        )

    # Lay the flat pixel indices out as an image so that neighbours can be
    # paired by slicing instead of looping over rows and columns.
    pixel_ids = np.arange(H * W, dtype=int).reshape(H, W)

    # Each pixel paired with its right-hand neighbour, row by row.
    horizontal = np.column_stack(
        (pixel_ids[:, :-1].ravel(), pixel_ids[:, 1:].ravel())
    )
    # Each pixel paired with the one below it. Transposing before flattening
    # groups the edges column by column.
    vertical = np.column_stack(
        (pixel_ids[:-1, :].T.ravel(), pixel_ids[1:, :].T.ravel())
    )

    edges = np.concatenate((horizontal, vertical), axis=0)
    edge_weights = np.ones(len(edges), dtype=np.single)
    return [edges, edge_weights]
# Segment `img` into `num_seg` labels: k-means on the pixel colors supplies
# the unary terms, a uniform 4-connected grid supplies the pairwise edges,
# and the labelling is obtained from gco.cut_general_graph().
#   img         -- image array; reshaped to (-1, 3) below, so 3 color
#                  channels are assumed
#   num_seg     -- number of labels / k-means clusters
#   lamb_grid   -- pairwise cost charged when two neighbours take different labels
#   unary_noise -- multiplier on the Gaussian noise added to the unaries
#   unary_scale -- multiplier applied to the unaries after the noise is added
# Returns the label array reshaped to the first two dimensions of `img`.
# NOTE(review): the body appears to have lost lines and indentation in this
# copy (see the truncated call near the end).
def graph_segment(img, num_seg = 2, lamb_grid = 2, unary_noise = 0.1, unary_scale = 3):
# Label-by-label pairwise cost: zero on the diagonal, lamb_grid elsewhere.
pairwise_pot = (1 - np.eye(num_seg)) * lamb_grid
# Only the height and width of the image are needed to build the grid edges.
grid_edges, grid_edge_weights = get_uniform_smoothness_pw_single_image((img.shape[0], img.shape[1]))
# run kmeans on img to get unaries
kmeans = KMeans(n_clusters=num_seg, random_state=0).fit(img.reshape(-1, 3))
kmeans_labels = kmeans.labels_
# create one hot encoding
kmeans_labels_onehot = np.zeros((len(kmeans_labels), num_seg))
for seg_i in range(num_seg):
kmeans_labels_onehot[:, seg_i] = (kmeans_labels.astype(int) == seg_i).flatten()
# Use onehot as unary
# This binds a second name to the same array (no copy), so the in-place
# updates below also modify kmeans_labels_onehot.
# NOTE(review): graph-cut solvers minimize an energy, so gco presumably
# treats these values as costs; a one-hot holding 1 at the k-means label would
# then make that label the most expensive one -- confirm against the pyGCO docs.
unary_vec = kmeans_labels_onehot
# Add noise to unary
unary_vec += unary_noise * np.random.randn(*unary_vec.shape)
unary_vec *= unary_scale
# display unaries
fig, axarr = plt.subplots(1, num_seg, figsize=(2 * num_seg, 2))
for i in range(num_seg):
# Column i of the unaries, viewed as an H x W image.
# NOTE(review): unary_i is never drawn in the visible code; the plotting call
# was presumably lost from this copy.
unary_i = unary_vec[:, i].reshape(*img.shape[:2])
axarr[i].set_title("Unary term #%i" % i)
# NOTE(review): the call below is cut off after its first argument in this
# copy; the remaining arguments are not visible here.
labels = gco.cut_general_graph(grid_edges,
# Reshape the per-pixel labels back to the image's height and width.
return labels.reshape(*img.shape[:2])
# Driver script: load an image, segment it with graph_segment() and display
# the resulting label map.
# NOTE(review): image_path is an empty string in this copy; the original
# path or URL was presumably stripped.
image_path = ""
# For a URL, download the file and continue with its local file name.
# `!wget` is an IPython/Jupyter shell escape, so this only runs in a
# notebook; `-nc` tells wget not to re-download a file that already exists.
if "http" in image_path:
!wget -nc {image_path}
image_path = image_path.split("/")[-1]
img = plt.imread(image_path)
plt.title("Input Image")
# Simplest configuration: 3 labels, zero pairwise weight, no unary noise and
# unit unary scale.
seg = graph_segment(img, num_seg = 3, lamb_grid = 0, unary_noise = 0, unary_scale = 1)
plt.imshow(seg, interpolation="nearest")