pyGCO icon indicating copy to clipboard operation
pyGCO copied to clipboard

Can't segment multisegment image using `gco.cut_general_graph()`

Open isaacwasserman opened this issue 1 year ago • 0 comments

Problem

I'm trying to use pyGCO to segment color images with 3 or more ground truth segments using gco.cut_general_graph() rather than gco.cut_grid_graph_simple() so that I can later pass custom weights for the grid edges. For now, however, I am trying to segment the image in the simplest case where the unaries are appropriate segmentations with little noise and the grid edges have a weight of 0. In this case, if I'm not mistaken, the optimal partition should be the argmax of the unaries. Instead, gco.cut_general_graph() is simply returning the first term of the unary.

I'm not sure if this is a edge case/bug or if I'm misunderstanding the expected behavior. Any ideas would be appreciated.

I've included my complete code below as well as its output.

Code:

import numpy as np
import gco  
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

def get_uniform_smoothness_pw_single_image(img_shape):
    H, W = img_shape
    E = (H - 1) * W + H * (W - 1)

    edges = np.empty((E, 2), dtype=int)
    edge_weights = np.ones(E, dtype=np.single)
    idx = 0

    # horizontal edges
    for row in range(H):
        edges[idx:idx+W-1,0] = np.arange(W-1) + row * W
        edges[idx:idx+W-1,1] = np.arange(W-1) + row * W + 1
        idx += W-1

    # vertical edges
    for col in range(W):
        edges[idx:idx+H-1,0] = np.arange(0, (H-1)*W, W) + col
        edges[idx:idx+H-1,1] = np.arange(W, H*W, W) + col
        idx += H-1

    return [edges, edge_weights]

def graph_segment(img, num_seg = 2, lamb_grid = 2, unary_noise = 0.1, unary_scale = 3):
    pairwise_pot = (1 - np.eye(num_seg)) * lamb_grid
    grid_edges, grid_edge_weights = get_uniform_smoothness_pw_single_image((img.shape[0], img.shape[1]))
    # run kmeans on img to get unaries
    kmeans = KMeans(n_clusters=num_seg, random_state=0).fit(img.reshape(-1, 3))
    kmeans_labels = kmeans.labels_
    # create one hot encoding
    kmeans_labels_onehot = np.zeros((len(kmeans_labels), num_seg))
    for seg_i in range(num_seg):
        kmeans_labels_onehot[:, seg_i] = (kmeans_labels.astype(int) == seg_i).flatten()
    # Use onehot as unary
    unary_vec = kmeans_labels_onehot
    # Add noise to unary    
    unary_vec += unary_noise * np.random.randn(*unary_vec.shape)
    unary_vec *= unary_scale

    # display unaries
    fig, axarr = plt.subplots(1, num_seg, figsize=(2 * num_seg, 2))
    for i in range(num_seg):    
        unary_i = unary_vec[:, i].reshape(*img.shape[:2])
        axarr[i].imshow(unary_i)
        axarr[i].set_title("Unary term #%i" % i)
    plt.show()

    labels = gco.cut_general_graph(grid_edges, 
                               grid_edge_weights, 
                               unary_vec, 
                               pairwise_pot, 
                               algorithm="swap")

    return labels.reshape(*img.shape[:2])

image_path = "https://hips.hearstapps.com/hmg-prod/images/dog-puppy-on-garden-royalty-free-image-1586966191.jpg"
if "http" in image_path:
    !wget -nc {image_path}
    image_path = image_path.split("/")[-1]
img = plt.imread(image_path)

plt.imshow(img)
plt.title("Input Image")
plt.show()
seg = graph_segment(img, num_seg = 3, lamb_grid = 0, unary_noise = 0, unary_scale = 1)
plt.imshow(seg, interpolation="nearest")
plt.title("Partition")
plt.show()
Screenshot 2023-07-19 at 1 32 07 PM

isaacwasserman avatar Jul 19 '23 17:07 isaacwasserman