SSUN

Training on other datasets

aldiak opened this issue 4 years ago • 10 comments

Hi, I would like to get the parameters used for the KSC and Indian Pines datasets if possible.

aldiak, Aug 16 '20 19:08

As mentioned in the paper, we use the same hyper-parameter settings for all three datasets. For experiments on the Indian Pines and KSC datasets, just change `dataID=1` in SSUN.py to `dataID=2` (for Indian Pines) or `dataID=6` (for KSC).
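A minimal sketch, assuming you want each `dataID` paired with its class count so that any array sized by the number of classes can be derived rather than hard-coded. `DATASETS` and `dataset_info` are hypothetical names, not part of SSUN.py; the ID mapping is the one given above, and the class counts are the standard ones for these benchmarks (9 for Pavia, 16 for Indian Pines, 13 for KSC).

```python
# Hypothetical lookup table (not in SSUN.py): dataID -> (dataset name, number of classes).
# IDs follow this thread (1 = Pavia, 2 = Indian Pines, 6 = KSC).
DATASETS = {
    1: ("Pavia", 9),
    2: ("Indian Pines", 16),
    6: ("KSC", 13),
}


def dataset_info(data_id):
    """Return (name, number of classes) for a supported dataID."""
    if data_id not in DATASETS:
        raise ValueError("unsupported dataID: %r" % (data_id,))
    return DATASETS[data_id]


dataID = 2
name, n_classes = dataset_info(dataID)
print(name, n_classes)  # Indian Pines 16
```

With `n_classes` available, every class-sized quantity in the script can be computed from it instead of being edited by hand for each dataset.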

YonghaoXu, Aug 17 '20 08:08

Alright, thanks.

aldiak, Aug 17 '20 11:08

But when I change only the dataID, I get the following error:

```
OASpectral_IP[0:9,r] = ProducerA
ValueError: could not broadcast input array from shape (16) into shape (9)
```

aldiak, Aug 17 '20 12:08

This error can be fixed by initializing OASpectral_IP as `OASpectral_IP = np.zeros((16+2,randtime))`, where 16 is the number of categories in the Indian Pines dataset.
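The `16+2` layout stores one producer's accuracy per class, followed by OA and Kappa in the last two rows, with one column per run. A minimal sketch that derives the row count from the class count instead of hard-coding it; `allocate_results` is a hypothetical helper, and `randtime=10` is only an example value for the number of repeated runs.

```python
import numpy as np


def allocate_results(n_classes, randtime):
    """Hypothetical helper: rows 0..n_classes-1 hold producer's accuracy per class,
    row -2 holds OA, row -1 holds Kappa; one column per run."""
    return np.zeros((n_classes + 2, randtime))


OASpectral_IP = allocate_results(n_classes=16, randtime=10)
print(OASpectral_IP.shape)  # (18, 10)
```

For KSC the same call with `n_classes=13` gives a 15-row matrix, so only the class count changes between datasets.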

YonghaoXu, Aug 18 '20 01:08

I changed it to:
```python
OASpectral_IP = np.zeros((16+2,randtime))
s1s2=1
OASpectral_Pavia1 = 'spec1'
time_step = 3
```

But I am still getting this error:

```
File "SSUN.py", line 240, in <module>
    OASpectral_IP[0:9,r] = ProducerA
ValueError: could not broadcast input array from shape (16) into shape (9)
```

aldiak, Aug 18 '20 07:08

The ProducerA vector contains the producer's accuracy for each category of the input data. Since there are 16 categories in the Indian Pines dataset, ProducerA has 16 elements, so you also need to replace the 9 in the slice `OASpectral_IP[0:9,r]` with 16. A similar modification can be made for the KSC dataset.
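To see why ProducerA has one entry per class, here is a sketch of how OA, Kappa, and producer's accuracy are commonly derived from a confusion matrix. `accuracy_metrics` is a hypothetical stand-in that only mirrors what CalAccuracy is described as returning here; it is not the repository's implementation. Writing into `[:n_classes, r]` rather than a literal `0:9` or `0:16` keeps the slice length equal to `len(ProducerA)` for any dataset.

```python
import numpy as np


def accuracy_metrics(pred, true, n_classes):
    """Hypothetical stand-in for CalAccuracy: returns (OA, Kappa, producer's accuracy)."""
    pred = np.asarray(pred)
    true = np.asarray(true)
    # Confusion matrix: rows = reference class, columns = predicted class.
    cm = np.zeros((n_classes, n_classes))
    np.add.at(cm, (true, pred), 1)
    total = cm.sum()
    oa = np.trace(cm) / total
    # Chance agreement computed from the row and column marginals.
    pe = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / total ** 2
    kappa = (oa - pe) / (1 - pe)
    # Producer's accuracy: correctly classified / reference samples, per class.
    producer = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)
    return oa, kappa, producer


n_classes, randtime, r = 4, 1, 0
results = np.zeros((n_classes + 2, randtime))
OA, Kappa, ProducerA = accuracy_metrics([0, 1, 2, 3, 3, 1], [0, 1, 2, 3, 2, 1], n_classes)
results[:n_classes, r] = ProducerA  # slice length always equals len(ProducerA)
results[-2, r] = OA
results[-1, r] = Kappa

# A hard-coded slice that is too short reproduces the broadcast error from this thread.
try:
    results[0:2, r] = ProducerA
except ValueError as err:
    print(err)
```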

YonghaoXu, Aug 18 '20 08:08

These are all my modifications:

```python
OASpectral_IP = np.zeros((16+2, randtime))  # 16 Indian Pines classes, then OA and Kappa
s1s2 = 1
OASpectral_Pavia1 = 'spec1'
time_step = 3

for r in range(0, randtime):

    ################# Indian Pines (dataID=2) #################
    dataID = 2
    data = HyperspectralSamples(dataID=dataID, timestep=time_step, w=w, num_PC=num_PC, israndom=israndom, s1s2=s1s2)
    X = data[0]
    X_train = data[1]
    X_test = data[2]
    XP = data[3]
    XP_train = data[4]
    XP_test = data[5]
    Y = data[6]-1
    Y_train = data[7]-1
    Y_test = data[8]-1

    batch_size = 128

    nb_classes = Y_train.max()+1
    nb_epoch = 50
    nb_features = X.shape[-1]

    img_rows, img_cols = XP.shape[1], XP.shape[1]
    # convert class vectors to binary class matrices
    y_train = np_utils.to_categorical(Y_train, nb_classes)
    y_test = np_utils.to_categorical(Y_test, nb_classes)

    model = LSTM_RS(time_step=time_step, nb_features=nb_features)
    tic1 = time.clock()
    histloss = model.fit([X_train], [y_train], nb_epoch=nb_epoch, batch_size=batch_size, verbose=1, shuffle=True)
    losses = histloss.history
    toc1 = time.clock()

    tic2 = time.clock()

    PredictLabel = model.predict([X_test], verbose=1).argmax(axis=-1)
    toc2 = time.clock()

    OA, Kappa, ProducerA = CalAccuracy(PredictLabel, Y_test[:, 0])
    OASpectral_IP[0:16, r] = ProducerA  # one producer's accuracy per class
    OASpectral_IP[-2, r] = OA
    OASpectral_IP[-1, r] = Kappa
```

But I am still getting an error:

```
Traceback (most recent call last):
  File "SSUN.py", line 254, in <module>
    X_result = DrawResult(Spectral,1)
  File "/home/alou/Desktop/ssun/train.py", line 117, in DrawResult
    X_result[np.where(labels==i),0] = palette[i,0]
IndexError: index 9 is out of bounds for axis 0 with size 9
```

aldiak, Aug 18 '20 09:08

The train.py here is my copy of your helper functions.

aldiak, Aug 18 '20 09:08

You need to change the imageID argument of the DrawResult function to generate the corresponding classification map. Please refer to DrawResult in HyperFunctions.py for details.
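The traceback shows the color palette has only 9 rows, while Indian Pines labels run from 0 to 15, so any label of 9 or above indexes past the palette. A minimal sketch of colorizing a label map with an explicit palette-size check; `labels_to_rgb` and the placeholder palette are hypothetical and do not reflect how DrawResult actually chooses its colors.

```python
import numpy as np


def labels_to_rgb(labels, palette):
    """Hypothetical helper: map a 2-D array of 0-based class labels to an RGB image."""
    labels = np.asarray(labels)
    n_needed = int(labels.max()) + 1
    if palette.shape[0] < n_needed:
        # Same situation as the IndexError above: fewer colors than classes.
        raise ValueError("palette has %d colors but labels need %d"
                         % (palette.shape[0], n_needed))
    # Integer-array indexing looks up one RGB triple per pixel.
    return palette[labels]


# Placeholder colors, one row per Indian Pines class (16 x 3, values 0..235).
palette16 = np.arange(48, dtype=np.uint8).reshape(16, 3) * 5
label_map = np.array([[0, 15], [7, 3]])
rgb = labels_to_rgb(label_map, palette16)
print(rgb.shape)  # (2, 2, 3)
```

Passing a 9-row palette with this label map raises the `ValueError` up front instead of failing partway through drawing.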

YonghaoXu, Aug 20 '20 01:08

OK, thanks.

aldiak, Aug 20 '20 05:08