trajnetplusplusbaselines

Question about dataset conversion and best hyperparameters for sgan

Open kangbeenlee opened this issue 1 year ago • 0 comments

Hello!

Thank you for sharing your open-source prediction models. I have a few questions about your code.

1. Dataset. Most prediction papers use the ETH and UCY datasets, split into ETH, HOTEL, UNIV, ZARA1, and ZARA2 (as in the Social GAN paper), so I want to train on ETH and UCY only. I found several compressed archives in the data folder of the trajnetplusplusdataset repository:

  • ewap_dataset_light.tgz

    seq_eth, seq_hotel

  • data_zara.rar

    crowds_zara01, crowds_zara02, crowds_zara03

  • data_university_students.rar

    students001, students003, uni_examples

Is it correct to use seq_eth and seq_hotel for ETH and HOTEL, crowds_zara01 for ZARA1, crowds_zara02 for ZARA2, and students001, students003, and uni_examples for UNIV? If not, please let me know which files are the proper ones.
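The mapping asked about above, and the leave-one-out folds it would imply (train on four scenes, test on the fifth), can be sketched as follows. This is only an illustration of the question: the dictionary encodes the guess above and is not a confirmed answer, and `leave_one_out_folds` is a hypothetical helper that is not part of either repository.

```python
# Assumed (unconfirmed) mapping from benchmark scene names to raw sequences.
SCENES = {
    "ETH": ["seq_eth"],
    "HOTEL": ["seq_hotel"],
    "UNIV": ["students001", "students003", "uni_examples"],
    "ZARA1": ["crowds_zara01"],
    "ZARA2": ["crowds_zara02"],
}


def leave_one_out_folds(scenes):
    """Yield (held_out_name, train_files, test_files) for each scene."""
    for held_out in scenes:
        # Training files are all sequences from every scene except the held-out one.
        train = [
            f
            for name, files in scenes.items()
            if name != held_out
            for f in files
        ]
        yield held_out, train, list(scenes[held_out])


if __name__ == "__main__":
    for held_out, train, test in leave_one_out_folds(SCENES):
        print(held_out, "train:", train, "test:", test)
```

Each fold would then correspond to one conversion and training run, with the held-out sequences converted as test data only.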

2. Data conversion for leave-one-out training. Most papers (e.g. Social LSTM, Social GAN) use a leave-one-out approach: train on four sets and test on the remaining one. So I trained your Social LSTM model on seq_hotel (HOTEL), crowds_zara01 (ZARA1), crowds_zara02 (ZARA2), and students001, students003, uni_examples (UNIV), and then tested it on seq_eth (ETH). Testing produced the test_pred folder, but when I tried to visualize the result with visualize_predictions.py, I got the following error:

    python -m evaluator.visualize_predictions DATA_BLOCK/trajdata/test_private/biwi_eth.ndjson DATA_BLOCK/trajdata/test_pred/lstm_social_ETH_modes1/biwi_eth.ndjson --n 10
    Scene ID: 714
    Traceback (most recent call last):
      File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 100, in <module>
        main()
      File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 90, in main
        full_predicted_paths = add_gt_observation_to_prediction(paths, predicted_paths)
      File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in add_gt_observation_to_prediction
        full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
      File "/home/kblee/Study/trajnet++/trajnetplusplusbaselines/evaluator/visualize_predictions.py", line 18, in <listcomp>
        full_predicted_paths = [gt_observation[ped_id][:obs_length] + pred for ped_id, pred in enumerate(model_prediction)]
    IndexError: list index out of range

The cause is that the number of tracks (pedestrians) in a scene differs between the ground-truth observation and the model prediction. Could you suggest a way to solve this? The commands I used for training and testing are as follows:

training:

    python -m trajnetbaselines.lstm.trainer --type social --augment --epochs 25 --step_size 10 --n 16 --cell_side 0.6 --embedding_arch two_layer --layer_dims 1024 --batch_size 8 --loss pred --output ETH

test (evaluate):

    python -m trajnetbaselines.lstm.trajnet_evaluator --path trajdata --output OUTPUT_BLOCK/trajdata/lstm_social_ETH.pkl

The test data was built from the seq_eth dataset only, using python -m trajnetdataset.convert --train_fraction 0.0 --val_fraction 0.0. (I also changed some code related to floating point, such as test_fraction = 1 - args.train_fraction - args.val_fraction.)
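The failing list comprehension in the traceback indexes the ground-truth list with positions taken from the prediction list, so it raises IndexError whenever a scene has more predicted tracks than ground-truth tracks. A minimal sketch of a defensive variant is shown below. It assumes each track is a plain Python list of points and that tracks at the same index belong to the same pedestrian, which may not hold for this data. The name `merge_observation_and_prediction` and the default `obs_length=9` are assumptions for illustration, not the repository's code.

```python
import warnings


def merge_observation_and_prediction(gt_observation, model_prediction, obs_length=9):
    """Prepend the observed ground-truth steps to each predicted track.

    Only tracks present in BOTH lists are merged; surplus tracks on either
    side are skipped with a warning instead of raising IndexError.
    obs_length=9 is an assumed default, not taken from the repository.
    """
    n_gt, n_pred = len(gt_observation), len(model_prediction)
    if n_gt != n_pred:
        warnings.warn(
            f"track count mismatch: ground truth has {n_gt}, prediction has {n_pred}"
        )
    n_common = min(n_gt, n_pred)
    return [
        gt_observation[i][:obs_length] + model_prediction[i]
        for i in range(n_common)
    ]


if __name__ == "__main__":
    # Toy scene: two ground-truth tracks but three predicted tracks.
    gt = [[(0, 0), (1, 1), (2, 2)], [(5, 5), (6, 6), (7, 7)]]
    pred = [[(3, 3)], [(8, 8)], [(9, 9)]]
    print(merge_observation_and_prediction(gt, pred, obs_length=2))
```

This only hides the symptom in the visualization. If the track counts differ because the test data and the predictions were generated from differently converted files, regenerating both from the same conversion is probably the real fix.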

3. The difference between output_pre and output. What is the difference between output_pre and output during data conversion?

4. What is the --n parameter in visualize_predictions.py?

5. Best hyperparameters for sgan. I want to train the Social GAN model on the ETH and UCY datasets, just as in the Social GAN paper. Would you mind sharing the hyperparameters that achieve the results stated in the paper? I tried the command below:

    python -m trajnetbaselines.sgan.trainer --type hiddenstatemlp --augment --noise_dim 8 --k 20 --output ETH

Is there any parameter I need to add or change?

Thanks.

kangbeenlee, Apr 09 '23 12:04