DiT
Fixed sampling on datasets with a different number of classes
This pull request fixes the error below. It occurs when sampling from a model trained on a classification dataset with a different number of classes, because `y_null` and `class_labels` are hardcoded in `sample.py`.
```
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)`
```
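A minimal sketch of the idea behind the fix, written in plain Python so it runs without PyTorch. It assumes that the unconditional "null" label equals `num_classes` (in the ImageNet setup `y_null` is hardcoded to 1000, and the label embedding table appears to reserve one extra row for it), so any label at or above `num_classes + 1` indexes past the table. On the GPU that out-of-range lookup can surface as an opaque CUDA/cuBLAS error like the one above rather than a clear message. The helper name `build_cfg_labels` is hypothetical; in `sample.py` the returned list would be wrapped in `torch.tensor(..., device=device)`.

```python
from typing import List, Sequence


def build_cfg_labels(class_labels: Sequence[int], num_classes: int) -> List[int]:
    """Build the label batch for classifier-free guidance sampling (hypothetical helper).

    The first half holds the requested class labels; the second half holds the
    "null" label, assumed here to be index `num_classes` (the extra embedding row).
    """
    if num_classes <= 0:
        raise ValueError("num_classes must be positive")
    for label in class_labels:
        # Reject labels outside [0, num_classes) on the CPU, before they reach
        # the embedding lookup on the GPU.
        if not 0 <= label < num_classes:
            raise ValueError(
                f"class label {label} is out of range for {num_classes} classes"
            )
    n = len(class_labels)
    # Derive the null label from num_classes instead of hardcoding 1000.
    y_null = [num_classes] * n
    return list(class_labels) + y_null


if __name__ == "__main__":
    # Hypothetical 10-class dataset: the null label becomes 10, not 1000.
    print(build_cfg_labels([0, 3, 9], num_classes=10))  # [0, 3, 9, 10, 10, 10]
```

With this approach, the ImageNet class IDs previously hardcoded in `class_labels` (for example 207) are rejected up front for a 10-class model instead of failing inside a CUDA kernel.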