BuildingExtraction icon indicating copy to clipboard operation
BuildingExtraction copied to clipboard

数据测试结果疑问

Open yelimei opened this issue 2 years ago • 13 comments

您好,在复现您的代码时我用我自己的图片替换了您提供的数据,不知道为什么最后出来的结果图还是原有的提取结果图。请问这个是需要怎么修改

yelimei avatar Sep 30 '22 02:09 yelimei

需要更改csv文件,'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',

KyanChen avatar Sep 30 '22 02:09 KyanChen

您好,我修改过后在运行显示我没有对应的“Label”文件,这个需要怎么修改呢?

需要更改csv文件,'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',

yelimei avatar Sep 30 '22 06:09 yelimei

使用test脚本,做如下修改 'IS_VAL': False, 'VAL_BATCH_SIZE': 4, 'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv', # if you need the test process. 'IS_TEST': True, 'TEST_DATASET': '自己的csv', 'IMG_SIZE': [512, 512],

报错也可以出来预测的结果图

KyanChen avatar Sep 30 '22 06:09 KyanChen

austin2_01179_258_289 image 您好,按照您的修改来做,我仍然生成的是原来结果图片,对应的test_data.csv也做了相应修改。不知该如何修改了

yelimei avatar Sep 30 '22 06:09 yelimei

把代码贴出来

KyanChen avatar Sep 30 '22 07:09 KyanChen

    'IS_VAL': False,
    'VAL_BATCH_SIZE': 4,
    'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv',
    # if you need the test process.
    'IS_TEST': True,
    # NOTE(review): this must point at a csv listing YOUR images; if it still
    # references the repo's original csv, the old results are re-predicted.
    'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',
    'IMG_SIZE': [256, 256],
    'PHASE': 'seg',

yelimei avatar Sep 30 '22 07:09 yelimei

"""Inference/test script for STTNet building extraction.

Loads a trained checkpoint, runs the val and/or test dataloader, saves a
vertical [input | ground truth | prediction] panel per image, and logs
IoU / OA / F1 metrics.
"""
import os

# Change the numbers when you want to test with specific gpus.
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict

from STTNet import STTNet
from Utils.Datasets import get_data_loader
from Utils.Utils import (
    make_numpy_img, inv_normalize_img, encode_onehot_to_mask, get_metrics, Logger,
)

# FIX: the pasted script had `if name == 'main'` (markdown stripped the
# dunder underscores) — it must be the standard entry-point guard.
if __name__ == '__main__':
    # Model construction + val/test pipeline configuration.
    model_infos = {
        # backbone: one of vgg16_bn, resnet50, resnet18
        'backbone': 'resnet50',
        'pretrained': True,
        'out_keys': ['block4'],
        'in_channel': 3,
        'n_classes': 2,
        'top_k_s': 64,
        'top_k_c': 16,
        'encoder_pos': True,
        'decoder_pos': True,
        'model_pattern': ['X', 'A', 'S', 'C'],

        'log_path': 'Results',
        'NUM_WORKERS': 0,
        # if you need the validation process.
        'IS_VAL': False,
        'VAL_BATCH_SIZE': 4,
        'VAL_DATASET': 'Tools/generate_dep_info/val_data.csv',
        # if you need the test process.
        'IS_TEST': True,
        # NOTE: point TEST_DATASET at the csv listing YOUR images, otherwise
        # the script keeps predicting on the repo's original data.
        'TEST_DATASET': 'Tools/generate_dep_info/test_data.csv',
        'IMG_SIZE': [256, 256],
        'PHASE': 'seg',

        # INRIA Dataset
        'PRIOR_MEAN': [0.40672500537632994, 0.42829032416229895, 0.39331840468605667],
        'PRIOR_STD': [0.029498464618176873, 0.027740088491668233, 0.028246722411879095],
        # # # WHU Dataset
        # 'PRIOR_MEAN': [0.4352682576428411, 0.44523221318154493, 0.41307610541534784],
        # 'PRIOR_STD': [0.026973196780331585, 0.026424642808887323, 0.02791246590291434],

        # load state dict path
        'load_checkpoint_path': r'D:/building/buildingextraction/BuildingExtraction-main/Checkpoints/ckpt_latest.pt',
    }

    if model_infos['IS_VAL']:
        os.makedirs(model_infos['log_path'] + '/val', exist_ok=True)
    if model_infos['IS_TEST']:
        os.makedirs(model_infos['log_path'] + '/test', exist_ok=True)
    logger = Logger(model_infos['log_path'] + '/log.log')

    data_loaders = get_data_loader(model_infos, test_mode=True)
    loss_weight = 0.1
    model = STTNet(**model_infos)

    # Load trained weights; strip the 'module.' prefix that DataParallel
    # checkpoints carry, falling back to the raw dict if that fails.
    logger.write(f'load checkpoint from {model_infos["load_checkpoint_path"]}\n')
    state_dict = torch.load(model_infos['load_checkpoint_path'], map_location='cpu')
    model_dict = state_dict['model_state_dict']
    try:
        model_dict = OrderedDict({k.replace('module.', ''): v for k, v in model_dict.items()})
        model.load_state_dict(model_dict)
    except Exception:
        model.load_state_dict(model_dict)
    model = model.cuda()
    device_ids = range(torch.cuda.device_count())
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.write(f'Use GPUs: {device_ids}\n')
    else:
        logger.write(f'Use GPUs: 1\n')

    patterns = ['val', 'test']
    for pattern_id, is_pattern in enumerate([model_infos['IS_VAL'], model_infos['IS_TEST']]):
        if not is_pattern:
            continue
        # pred: logits, tensor, nBatch * nClass * W * H
        # target: labels, tensor, nBatch * nClass * W * H
        collect_result = {'pred': [], 'target': []}
        pattern = patterns[pattern_id]
        model.eval()
        for batch_id, batch in enumerate(data_loaders[pattern]):
            # Get data
            img_batch = batch['img'].cuda()
            label_batch = batch['label'].cuda()
            img_names = batch['img_name']
            collect_result['target'].append(label_batch.data.cpu())

            # inference
            with torch.no_grad():
                logits, att_branch_output = model(img_batch)

            collect_result['pred'].append(logits.data.cpu())
            # class index map -> 0/255 mask for visualization
            pred_label = torch.argmax(logits, 1)
            pred_label *= 255

            # output the segmentation result
            if pattern == 'test' or batch_id % 5 == 1:
                for img_id in range(pred_label.size(0)):
                    img = img_batch[img_id].detach().cpu()
                    target = label_batch[img_id].detach().cpu()
                    pred = pred_label[img_id].detach().cpu()
                    img_name = img_names[img_id]

                    img = make_numpy_img(
                        inv_normalize_img(img, model_infos['PRIOR_MEAN'], model_infos['PRIOR_STD']))
                    # tint the ground-truth mask pink for easier comparison
                    target = make_numpy_img(encode_onehot_to_mask(target)) * (255, 190, 190)
                    pred = make_numpy_img(pred)

                    # BUG FIX (the issue in this thread): the prediction panel
                    # was dropped from the concatenation, so the saved image
                    # showed only input + old label and looked like the old
                    # result. Stack input / GT / prediction vertically.
                    vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0)
                    vis = np.clip(vis, a_min=0, a_max=1)
                    file_name = os.path.join(
                        model_infos['log_path'], pattern, f'{img_name.split(".")[0]}.png')
                    plt.imsave(file_name, vis)

        collect_result['pred'] = torch.cat(collect_result['pred'], dim=0)
        collect_result['target'] = torch.cat(collect_result['target'], dim=0)
        IoU, OA, F1_score = get_metrics('seg', **collect_result)
        logger.write(f'{pattern}: Iou:{IoU[-1]:.4f} OA:{OA[-1]:.4f} F1:{F1_score[-1]:.4f}\n')

yelimei avatar Sep 30 '22 07:09 yelimei

vis = np.concatenate([img / 255., target / 255., pred / 255.], axis=0)

KyanChen avatar Sep 30 '22 07:09 KyanChen

也不对哎

yelimei avatar Sep 30 '22 07:09 yelimei

咋不对

KyanChen avatar Sep 30 '22 07:09 KyanChen

第1张图是我自己的数据集,第2、3张是原有label中的

yelimei avatar Sep 30 '22 07:09 yelimei

这是什么数据集?训练过吗

KyanChen avatar Sep 30 '22 07:09 KyanChen

我自己随便截图的,不过图片属性也是根据您的数据中进行调整设置了,训练过

yelimei avatar Sep 30 '22 07:09 yelimei