CHR
CHR copied to clipboard
`read_object_labels` is not defined (referenced in ray.py, line 97).
Actually, it does not seem to be defined anywhere in the whole project, so maybe it is leftover test code from before the release?
How did you solve the problem?
Here are the git diffs from my repo:
diff --git a/CHR/CHR/engine.py b/CHR/CHR/engine.py
index 2845706..1dd7694 100644
--- a/CHR/CHR/engine.py
+++ b/CHR/CHR/engine.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
import numpy as np
from CHR.util import AveragePrecisionMeter, Warp
-
+from CHR.ray import read_image_label
class Engine(object):
def __init__(self, state={}):
diff --git a/CHR/CHR/main.py b/CHR/CHR/main.py
index e4a2533..7ed94be 100644
--- a/CHR/CHR/main.py
+++ b/CHR/CHR/main.py
@@ -72,7 +72,7 @@ def main_ray():
global args, best_prec1, use_gpu
args = parser.parse_args()
- args.data='/DATA/disk1/mcj/dataset/'
+ args.data='/mnt/lyz/SIXray-data/'
args.resume = './CHR/models-/checkpoint.pth.tar'
@@ -81,7 +81,7 @@ def main_ray():
# define dataset
train_dataset = XrayClassification(args.data, 'train')
- val_dataset = XrayClassification(args.data, 'test_new')
+ val_dataset = XrayClassification(args.data, 'test')
num_classes = 5
# load model
diff --git a/CHR/CHR/ray.py b/CHR/CHR/ray.py
index 720472a..bab4fbb 100644
--- a/CHR/CHR/ray.py
+++ b/CHR/CHR/ray.py
@@ -85,18 +85,19 @@ class XrayClassification(data.Dataset):
# define path of csv file
- path_csv = os.path.join(self.root, 'ImageSet','train_test_10-5')
+ path_csv = os.path.join(self.root, 'ImageSet', '10')
# define filename of csv file
file_csv = os.path.join(path_csv, set + '.csv')
# create the csv file if necessary
if not os.path.exists(file_csv):
- if not os.path.exists(path_csv): # create dir if necessary
- os.makedirs(path_csv)
- # generate csv file
- labeled_data = read_object_labels(self.root, self.set)
- # write csv file
- write_object_labels_csv(file_csv, labeled_data)
+ # if not os.path.exists(path_csv): # create dir if necessary
+ # os.makedirs(path_csv)
+ # # generate csv file
+ # labeled_data = read_object_labels(self.root, self.set)
+ # # write csv file
+ # write_object_labels_csv(file_csv, labeled_data)
+ raise ValueError(file_csv + " not found.")
self.classes = object_categories
self.images = read_object_labels_csv(file_csv)
How did you solve the problem?
annotation in the project of SIXray
Here are the git diffs from my repo:
diff --git a/CHR/CHR/engine.py b/CHR/CHR/engine.py
index 2845706..1dd7694 100644
--- a/CHR/CHR/engine.py
+++ b/CHR/CHR/engine.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
import numpy as np
from CHR.util import AveragePrecisionMeter, Warp
-
+from CHR.ray import read_image_label
class Engine(object):
def __init__(self, state={}):
diff --git a/CHR/CHR/main.py b/CHR/CHR/main.py
index e4a2533..7ed94be 100644
--- a/CHR/CHR/main.py
+++ b/CHR/CHR/main.py
@@ -72,7 +72,7 @@ def main_ray():
global args, best_prec1, use_gpu
args = parser.parse_args()
- args.data='/DATA/disk1/mcj/dataset/'
+ args.data='/mnt/lyz/SIXray-data/'
args.resume = './CHR/models-/checkpoint.pth.tar'
@@ -81,7 +81,7 @@ def main_ray():
# define dataset
train_dataset = XrayClassification(args.data, 'train')
- val_dataset = XrayClassification(args.data, 'test_new')
+ val_dataset = XrayClassification(args.data, 'test')
num_classes = 5
# load model
diff --git a/CHR/CHR/ray.py b/CHR/CHR/ray.py
index 720472a..bab4fbb 100644
--- a/CHR/CHR/ray.py
+++ b/CHR/CHR/ray.py
@@ -85,18 +85,19 @@ class XrayClassification(data.Dataset):
# define path of csv file
- path_csv = os.path.join(self.root, 'ImageSet','train_test_10-5')
+ path_csv = os.path.join(self.root, 'ImageSet', '10')
# define filename of csv file
file_csv = os.path.join(path_csv, set + '.csv')
# create the csv file if necessary
if not os.path.exists(file_csv):
- if not os.path.exists(path_csv): # create dir if necessary
- os.makedirs(path_csv)
- # generate csv file
- labeled_data = read_object_labels(self.root, self.set)
- # write csv file
- write_object_labels_csv(file_csv, labeled_data)
+ # if not os.path.exists(path_csv): # create dir if necessary
+ # os.makedirs(path_csv)
+ # # generate csv file
+ # labeled_data = read_object_labels(self.root, self.set)
+ # # write csv file
+ # write_object_labels_csv(file_csv, labeled_data)
+ raise ValueError(file_csv + " not found.")
self.classes = object_categories
self.images = read_object_labels_csv(file_csv)
How did you solve the problem?
I have updated the code. You can git clone the new version. Thank you!