text-to-image icon indicating copy to clipboard operation
text-to-image copied to clipboard

Training on the birds dataset

Open mohitrgiit opened this issue 8 years ago • 1 comments

Did you train it on the birds dataset?

mohitrgiit avatar Apr 25 '17 18:04 mohitrgiit

yes we did.

To train on a different dataset, simply write your own data_loader.py.

Here is an example of a data loader for the bird dataset. Enjoy!

if xxx
   ...
elif dataset == '200birds':

    cwd = os.getcwd()
    img_dir = os.path.join(cwd, '200birds/200birds_box')
    caption_dir = os.path.join(cwd, '200birds/text_c10')
    VOC_FIR = cwd + '/vocab.txt'
    TRAIN_CLASS_FILE = os.path.join(cwd, '200birds/trainvalclasses.txt')
    TEST_CLASS_FILE = os.path.join(cwd, '200birds/testclasses.txt')


    train_class_set = set()
    with open(TRAIN_CLASS_FILE, 'r') as f:
        for line in f:
            train_class_set.add(line.rstrip())

    test_class_set = set()
    with open(TEST_CLASS_FILE, 'r') as f:
        for line in f:
            test_class_set.add(line.rstrip())

    caption_sub_dir = sorted(load_folder_list( caption_dir ))
    processed_capts = []
    train_captions_list = []
    test_captions_list = []

    # Walk every class folder; each image has one caption .txt file holding
    # exactly 10 captions (one per line).
    for sub_dir in caption_sub_dir: # get caption file list
        with tl.ops.suppress_stdout():
            # Raw string: '\.' in a plain literal is an invalid escape
            # sequence (DeprecationWarning on modern Python).
            files = sorted(tl.files.load_file_list(path=sub_dir, regx=r'_[0-9]+_[0-9]+\.txt'))
            basename = os.path.basename(sub_dir)

            # Route captions into the train or test split by class name.
            if basename in train_class_set:
                target_list = train_captions_list
            elif basename in test_class_set:
                target_list = test_captions_list
            else:
                raise KeyError("class %r is in neither the train nor the test split" % basename)

            for fname in files:
                file_dir = os.path.join(sub_dir, fname)
                line_num = 0
                # 'with' guarantees the handle is closed; the original code
                # leaked one open file per caption file.
                with open(file_dir, 'r') as t:
                    for line in t:
                        preprocessed_line = preprocess_caption(line)
                        processed_capts.append(tl.nlp.process_sentence(preprocessed_line, start_word="<S>", end_word="</S>"))
                        target_list.append(preprocessed_line)
                        line_num += 1

                assert line_num == 10, "Every bird image should have 10 captions"


    ## build vocab
    # Reuse an existing vocab file; otherwise build one from ALL processed
    # captions (train + test) so word ids are stable across runs.
    # Check the exact path we write/read (the original tested the relative
    # 'vocab.txt' while writing VOC_FIR).
    if not os.path.isfile(VOC_FIR):
        _ = tl.nlp.create_vocab(processed_capts, word_counts_output_file=VOC_FIR, min_word_count=1)
    else:
        print("WARNING: vocab.txt already exists")
    vocab = tl.nlp.Vocabulary(VOC_FIR, start_word="<S>", end_word="</S>", unk_word="<UNK>")


    ## create captions id
    # Map every caption to a list of word ids (unknown words -> <UNK>).
    train_captions_ids = [
        [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(caption)]
        for caption in train_captions_list]
    # NOTE(review): captions vary in length, so this builds a ragged object
    # array; NumPy >= 1.24 requires an explicit dtype=object for that —
    # confirm against the installed NumPy version.
    train_captions_ids = np.asarray(train_captions_ids)

    test_captions_ids = [
        [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(caption)]
        for caption in test_captions_list]
    test_captions_ids = np.asarray(test_captions_ids)

    print(" * tokenized %d captions for training" % len(train_captions_ids))
    print(" * tokenized %d captions for testing" % len(test_captions_ids))
    print(" * tokenized %d captions total" % (len(train_captions_ids) + len(test_captions_ids)))

    ## check
    # Sanity check: round-trip the first training caption through the vocab.
    img_capt = train_captions_list[0]
    print("img_capt: %s" % img_capt)
    print("nltk.tokenize.word_tokenize(img_capt): %s" % nltk.tokenize.word_tokenize(img_capt))
    img_capt_ids = [vocab.word_to_id(word) for word in nltk.tokenize.word_tokenize(img_capt)]
    print("img_capt_ids: %s" % img_capt_ids)
    print("id_to_word: %s" % [vocab.id_to_word(id) for id in img_capt_ids])

    # Raw caption strings are no longer needed; free them before the
    # memory-heavy image loading below.
    del train_captions_list
    del test_captions_list

    ## load images
    s = time.time()

    train_image = []
    test_image = []
    train_image_256 = []
    test_image_256 = []
    image_sub_dir = sorted(load_folder_list(img_dir))
    for sub_dir in image_sub_dir:  # get image file list
        with tl.ops.suppress_stdout():  # get image files list
            # Raw string for the regex (plain '\.' is an invalid escape).
            files = sorted(tl.files.load_file_list(path=sub_dir, regx=r'_[0-9]+_[0-9]+\.jpg'))
            basename = os.path.basename(sub_dir)

            # Route images into the same train/test split as the captions.
            if basename in train_class_set:
                target_list = train_image
                if need_256:
                    target_list_256 = train_image_256
            elif basename in test_class_set:
                target_list = test_image
                if need_256:
                    target_list_256 = test_image_256
            else:
                raise KeyError("class %r is in neither the train nor the test split" % basename)

            for f in files:
                # NOTE(review): scipy.misc.imread was removed in SciPy >= 1.2;
                # on modern installs switch to imageio.imread.
                img_raw = scipy.misc.imread(os.path.join(sub_dir, f))
                if len(img_raw.shape) < 3:
                    # Grayscale image: replicate to 3 channels.
                    img_raw = color.gray2rgb(img_raw)
                img = tl.prepro.imresize(img_raw, size=[64, 64])  # (64, 64, 3)
                img = img.astype(np.float32)
                target_list.append(img)
                if need_256:
                    img = tl.prepro.imresize(img_raw, size=[256, 256])  # (256, 256, 3)
                    img = img.astype(np.float32)
                    target_list_256.append(img)

    print(" * %d images for training" % len(train_image))
    print(" * %d images for testing" % len(test_image))
    print(" * %d images for total" % (len(train_image) + len(test_image)))

    print(" * loading and resizing took %ss" % (time.time()-s))

    # CUB provides 10 captions per image.
    n_captions_per_image = 10

    captions_ids_train, captions_ids_test = train_captions_ids, test_captions_ids
    images_train, images_test = train_image, test_image
    if need_256:
        images_train_256, images_test_256 = train_image_256, test_image_256
    n_images_train = len(images_train)
    n_images_test = len(images_test)
    n_captions_train = len(captions_ids_train)
    n_captions_test = len(captions_ids_test)
    print("n_images_train:%d n_captions_train:%d" % (n_images_train, n_captions_train))
    print("n_images_test:%d  n_captions_test:%d" % (n_images_test, n_captions_test))

zsdonghao avatar Apr 25 '17 21:04 zsdonghao