CRAFT-pytorch

Gaussian heatmap ?

Open tinhchuquang opened this issue 5 years ago • 36 comments

In the paper, they create the ground-truth labels using a Gaussian heatmap generated by another application. Can you show me the algorithm for creating the Gaussian heatmap? Thanks.

tinhchuquang avatar Jun 21 '19 04:06 tinhchuquang

You can use a standard normal distribution to calculate the probability associated with any pixel as a function of pixel's distance from the center.

import cv2
import numpy as np
from math import exp

# Probability as a function of distance from the center derived
# from a gaussian distribution with mean = 0 and stdv = 1
scaledGaussian = lambda x : exp(-(1/2)*(x**2))


imgSize = 512
isotropicGrayscaleImage = np.zeros((imgSize,imgSize),np.uint8)

for i in range(imgSize):
  for j in range(imgSize):

    # find euclidian distance from center of image (imgSize/2,imgSize/2) 
    # and scale it to range of 0 to 2.5 as scaled Gaussian
    # returns highest probability for x=0 and approximately
    # zero probability for x > 2.5

    distanceFromCenter = np.linalg.norm(np.array([i-imgSize/2,j-imgSize/2]))
    distanceFromCenter = 2.5*distanceFromCenter/(imgSize/2)
    scaledGaussianProb = scaledGaussian(distanceFromCenter)
    isotropicGrayscaleImage[i,j] = np.clip(scaledGaussianProb*255,0,255)

# Convert Grayscale to HeatMap Using Opencv
isotropicGaussianHeatmapImage = cv2.applyColorMap(isotropicGrayscaleImage, 
                                                  cv2.COLORMAP_JET)

You can find a more intuitive implementation here: IsotropicGaussianMap implementation using Python
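The nested loops above can also be written with NumPy broadcasting, which is usually much faster for a 512x512 image. This is only a minimal sketch, assuming the same image size and the same 2.5 distance scaling as above; the function name `isotropic_gaussian` and its parameters are made up for illustration.

```python
import numpy as np


def isotropic_gaussian(img_size=512, max_dist=2.5):
    """Return a uint8 image whose value falls off as exp(-d**2 / 2) from the center."""
    center = img_size / 2
    # Coordinate grids: rows[i, j] = i, cols[i, j] = j
    rows, cols = np.mgrid[0:img_size, 0:img_size]
    # Euclidean distance from the center, rescaled so that a distance of
    # img_size / 2 pixels maps to max_dist
    dist = np.sqrt((rows - center) ** 2 + (cols - center) ** 2)
    dist = max_dist * dist / center
    prob = np.exp(-0.5 * dist ** 2)
    return np.clip(prob * 255, 0, 255).astype(np.uint8)


heatmap = isotropic_gaussian()
print(heatmap.shape, heatmap.max())
```

The result can be passed to cv2.applyColorMap in the same way as the loop version.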

hetul-patel avatar Jun 21 '19 19:06 hetul-patel

First, thanks for showing me the Gaussian heatmap. It works for the character heatmap, but when I use it with the links between characters in a word (affinity boxes) to create the heatmap, it doesn't work, because I don't know where each word ends. Example: Heatmap_word. Can you show your solution? Thanks.

tinhchuquang avatar Jun 25 '19 05:06 tinhchuquang

A working example of creating the Gaussian heat map with perspective transform


UPDATE 27-06-19 Added functionality to incorporate out of bound character bbox

UPDATE 22-08-19 Checking if character bbox is valid using shapely.geometry.Polygon

from torch.utils import data
import matplotlib.pyplot as plt
import numpy as np
import cv2
from shapely.geometry import Polygon

DEBUG = True


def four_point_transform(image, pts):

	max_x, max_y = np.max(pts[:, 0]).astype(np.int32), np.max(pts[:, 1]).astype(np.int32)

	dst = np.array([
		[0, 0],
		[image.shape[1] - 1, 0],
		[image.shape[1] - 1, image.shape[0] - 1],
		[0, image.shape[0] - 1]], dtype="float32")

	M = cv2.getPerspectiveTransform(dst, pts)
	warped = cv2.warpPerspective(image, M, (max_x, max_y))

	return warped


class DataLoader(data.Dataset):

	def __init__(self, type_):

		self.type_ = type_
		self.base_path = '<Path for Images>'
		if DEBUG:
			import os
			if not os.path.exists('cache.pkl'):
				with open('cache.pkl', 'wb') as f:
					import pickle
					from scipy.io import loadmat
					mat = loadmat('Path for gt.mat')
					pickle.dump([mat['imnames'][0][0:1000], mat['charBB'][0][0:1000], mat['txt'][0][0:1000]], f)
					print('Created the pickle file, rerun the program')
					exit(0)
			else:
				with open('cache.pkl', 'rb') as f:
					import pickle
					self.imnames, self.charBB, self.txt = pickle.load(f)
					print('Loaded DEBUG')

		else:

			from scipy.io import loadmat
			mat = loadmat('Path for gt.mat')

			total_number = mat['imnames'][0].shape[0]
			train_images = int(total_number * 0.9)

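			if self.type_ == 'train':

				self.imnames = mat['imnames'][0][0:train_images]
				self.charBB = mat['charBB'][0][0:train_images]  # number of images, 2, 4, num_character
				self.txt = mat['txt'][0][0:train_images]  # needed by the word-splitting loop below

			else:

				self.imnames = mat['imnames'][0][train_images:]
				self.charBB = mat['charBB'][0][train_images:]  # number of images, 2, 4, num_character
				self.txt = mat['txt'][0][train_images:]  # needed by the word-splitting loop below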

		for no, i in enumerate(self.txt):
			all_words = []
			for j in i:
				all_words += [k for k in ' '.join(j.split('\n')).split() if k!='']
			self.txt[no] = all_words

		sigma = 10
		spread = 3
		extent = int(spread * sigma)
		self.gaussian_heatmap = np.zeros([2 * extent, 2 * extent], dtype=np.float32)

		for i in range(2 * extent):
			for j in range(2 * extent):
				self.gaussian_heatmap[i, j] = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
					-1 / 2 * ((i - spread * sigma - 0.5) ** 2 + (j - spread * sigma - 0.5) ** 2) / (sigma ** 2))

		self.gaussian_heatmap = (self.gaussian_heatmap / np.max(self.gaussian_heatmap) * 255).astype(np.uint8)

	def add_character(self, image, bbox):

		if not Polygon(bbox.reshape([4, 2]).astype(np.int32)).is_valid:
			return image
		top_left = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(np.int32)
		if top_left[1] > image.shape[0] or top_left[0] > image.shape[1]:
			# This means there is some bug in the character bbox
			# Will have to look into more depth to understand this
			return image
		bbox -= top_left[None, :]
		transformed = four_point_transform(self.gaussian_heatmap.copy(), bbox.astype(np.float32))

		start_row = max(top_left[1], 0) - top_left[1]
		start_col = max(top_left[0], 0) - top_left[0]
		end_row = min(top_left[1]+transformed.shape[0], image.shape[0])
		end_col = min(top_left[0]+transformed.shape[1], image.shape[1])

		image[max(top_left[1], 0):end_row, max(top_left[0], 0):end_col] += transformed[start_row:end_row - top_left[1], start_col:end_col - top_left[0]]

		return image

	def generate_target(self, image_size, character_bbox):

		character_bbox = character_bbox.transpose(2, 1, 0)

		channel, height, width = image_size

		target = np.zeros([height, width], dtype=np.uint8)

		for i in range(character_bbox.shape[0]):

			target = self.add_character(target, character_bbox[i])

		return target/255, np.float32(target != 0)

	def add_affinity(self, image, bbox_1, bbox_2):

		center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
		tl = np.mean([bbox_1[0], bbox_1[1], center_1], axis=0)
		bl = np.mean([bbox_1[2], bbox_1[3], center_1], axis=0)
		tr = np.mean([bbox_2[0], bbox_2[1], center_2], axis=0)
		br = np.mean([bbox_2[2], bbox_2[3], center_2], axis=0)

		affinity = np.array([tl, tr, br, bl])

		return self.add_character(image, affinity)

	def generate_affinity(self, image_size, character_bbox, text):

		"""

		:param image_size: shape = [3, image_height, image_width]
		:param character_bbox: [2, 4, num_characters]
		:param text: [num_words]
		:return:
		"""

		character_bbox = character_bbox.transpose(2, 1, 0)

		channel, height, width = image_size

		target = np.zeros([height, width], dtype=np.uint8)

		total_letters = 0

		for word in text:
			for char_num in range(len(word)-1):
				target = self.add_affinity(target, character_bbox[total_letters].copy(), character_bbox[total_letters+1].copy())
				total_letters += 1
			total_letters += 1

		return target / 255, np.float32(target != 0)

	def __getitem__(self, item):

		image = plt.imread(self.base_path+'/'+self.imnames[item][0]).transpose(2, 0, 1)/255
		weight, target = self.generate_target(image.shape, self.charBB[item].copy())
		weight_affinity, target_affinity = self.generate_affinity(image.shape, self.charBB[item].copy(), self.txt[item].copy())

		return image, weight, target, weight_affinity, target_affinity

	def __len__(self):

		return len(self.imnames)


if __name__ == "__main__":

	dataloader = DataLoader('train')
	image, weight, target, weight_affinity, target_affinity = dataloader[0]

	plt.imsave('image.png', image.transpose(1, 2, 0))
	plt.imsave('target.png', target)
	plt.imsave('weight.png', weight)
	plt.imsave('weight_affinity.png', weight_affinity)
	plt.imsave('target_affinity.png', target_affinity)
	plt.imsave('together.png', np.concatenate([weight[:, :, None], weight_affinity[:, :, None], np.zeros_like(weight)[:, :, None]], axis=2))

Reference Code - https://www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/

Do point out any bugs you find; I will try my best to address them.
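To make the affinity box in add_affinity easier to follow, here is the same construction pulled out into a standalone function. It is only a sketch: the name `affinity_box` and the two sample boxes are made up, and it assumes each character box lists its corners as top-left, top-right, bottom-right, bottom-left. The corners of the affinity box are the centroids of the upper and lower triangles of the two neighbouring characters.

```python
import numpy as np


def affinity_box(bbox_1, bbox_2):
    """Build the affinity quadrilateral between two adjacent characters.

    Each bbox is a (4, 2) array of (x, y) corners, assumed to be ordered
    top-left, top-right, bottom-right, bottom-left.
    """
    center_1, center_2 = bbox_1.mean(axis=0), bbox_2.mean(axis=0)
    # Centroid of the upper triangle (top edge + center) and of the lower triangle
    tl = np.mean([bbox_1[0], bbox_1[1], center_1], axis=0)
    bl = np.mean([bbox_1[2], bbox_1[3], center_1], axis=0)
    tr = np.mean([bbox_2[0], bbox_2[1], center_2], axis=0)
    br = np.mean([bbox_2[2], bbox_2[3], center_2], axis=0)
    return np.array([tl, tr, br, bl])


# Two hypothetical 10x10 characters side by side, 2 px apart
char_a = np.array([[0, 0], [10, 0], [10, 10], [0, 10]], dtype=np.float32)
char_b = char_a + np.array([12, 0], dtype=np.float32)
box = affinity_box(char_a, char_b)
print(box)
```

For these two boxes the affinity quadrilateral spans x = 5 to x = 17 and roughly y = 1.67 to y = 8.33, so it sits between the two character centers and overlaps both characters.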

mayank-git-hub avatar Jun 26 '19 14:06 mayank-git-hub

A working example of creating the Gaussian heat map with perspective transform

from torch.utils import data
import matplotlib.pyplot as plt
import numpy as np
import cv2

DEBUG = True


def four_point_transform(image, pts):

	max_x, max_y = np.max(pts[:, 0]).astype(np.int32), np.max(pts[:, 1]).astype(np.int32)

	dst = np.array([
		[0, 0],
		[image.shape[1] - 1, 0],
		[image.shape[1] - 1, image.shape[0] - 1],
		[0, image.shape[0] - 1]], dtype="float32")

	M = cv2.getPerspectiveTransform(dst, pts)
	warped = cv2.warpPerspective(image, M, (max_x, max_y))

	return warped


class DataLoader(data.Dataset):

	def __init__(self, type_):

		self.type_ = type_
		self.base_path = '<Path for Images>'
		if DEBUG:
			import os
			if not os.path.exists('cache.pkl'):
				with open('cache.pkl', 'wb') as f:
					import pickle
					from scipy.io import loadmat
					mat = loadmat('Path for gt.mat')
					pickle.dump([mat['imnames'][0][0:1000], mat['charBB'][0][0:1000], mat['txt'][0][0:1000]], f)
					print('Created the pickle file, rerun the program')
					exit(0)
			else:
				with open('cache.pkl', 'rb') as f:
					import pickle
					self.imnames, self.charBB, self.txt = pickle.load(f)
					print('Loaded DEBUG')

		else:

			from scipy.io import loadmat
			mat = loadmat('Path for gt.mat')

			total_number = mat['imnames'][0].shape[0]
			train_images = int(total_number * 0.9)

			if self.type_ == 'train':

				self.imnames = mat['imnames'][0][0:train_images]
				self.charBB = mat['charBB'][0][0:train_images]  # number of images, 2, 4, num_character

			else:

				self.imnames = mat['imnames'][0][train_images:]
				self.charBB = mat['charBB'][0][train_images:]  # number of images, 2, 4, num_character

		for no, i in enumerate(self.txt):
			all_words = []
			for j in i:
				all_words += [k for k in ' '.join(j.split('\n')).split() if k!='']
			self.txt[no] = all_words

		sigma = 10
		spread = 3
		extent = int(spread * sigma)
		self.gaussian_heatmap = np.zeros([2 * extent, 2 * extent], dtype=np.float32)

		for i in range(2 * extent):
			for j in range(2 * extent):
				self.gaussian_heatmap[i, j] = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
					-1 / 2 * ((i - spread * sigma - 0.5) ** 2 + (j - spread * sigma - 0.5) ** 2) / (sigma ** 2))

		self.gaussian_heatmap = (self.gaussian_heatmap / np.max(self.gaussian_heatmap) * 255).astype(np.uint8)

	def add_character(self, image, bbox):

		top_left = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(np.int32)
		bbox -= top_left[None, :]
		transformed = four_point_transform(self.gaussian_heatmap.copy(), bbox.astype(np.float32))
		image[top_left[1]:top_left[1]+transformed.shape[0], top_left[0]:top_left[0]+transformed.shape[1]] += transformed
		return image

	def generate_target(self, image_size, character_bbox):

		character_bbox = character_bbox.transpose(2, 1, 0)

		channel, height, width = image_size

		target = np.zeros([height, width], dtype=np.uint8)

		for i in range(character_bbox.shape[0]):

			target = self.add_character(target, character_bbox[i])

		return target/255, np.float32(target != 0)

	def add_affinity(self, image, bbox_1, bbox_2):

		center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
		tl = np.mean([bbox_1[0], bbox_1[1], center_1], axis=0)
		bl = np.mean([bbox_1[2], bbox_1[3], center_1], axis=0)
		tr = np.mean([bbox_2[0], bbox_2[1], center_2], axis=0)
		br = np.mean([bbox_2[2], bbox_2[3], center_2], axis=0)

		affinity = np.array([tl, tr, br, bl])

		return self.add_character(image, affinity)

	def generate_affinity(self, image_size, character_bbox, text):

		"""

		:param image_size: shape = [3, image_height, image_width]
		:param character_bbox: [2, 4, num_characters]
		:param text: [num_words]
		:return:
		"""

		character_bbox = character_bbox.transpose(2, 1, 0)

		channel, height, width = image_size

		target = np.zeros([height, width], dtype=np.uint8)

		total_letters = 0

		for word in text:
			for char_num in range(len(word)-1):
				target = self.add_affinity(target, character_bbox[total_letters], character_bbox[total_letters+1])
				total_letters += 1
			total_letters += 1

		return target / 255, np.float32(target != 0)

	def __getitem__(self, item):

		image = plt.imread(self.base_path+'/'+self.imnames[item][0]).transpose(2, 0, 1)/255
		weight, target = self.generate_target(image.shape, self.charBB[item].copy())
		weight_affinity, target_affinity = self.generate_affinity(image.shape, self.charBB[item].copy(), self.txt[item].copy())

		return image, weight, target, weight_affinity, target_affinity

	def __len__(self):

		return len(self.imnames)


if __name__ == "__main__":

	dataloader = DataLoader('train')
	image, weight, target, weight_affinity, target_affinity = dataloader[0]

	plt.imsave('image.png', image.transpose(1, 2, 0))
	plt.imsave('target.png', target)
	plt.imsave('weight.png', weight)
	plt.imsave('weight_affinity.png', weight_affinity)
	plt.imsave('target_affinity.png', target_affinity)
	plt.imsave('together.png', np.concatenate([weight[:, :, None], weight_affinity[:, :, None], np.zeros_like(weight)[:, :, None]], axis=2))

Reference Code - https://www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/

Do point me out if there is a bug, I will try my best to address it.

The "add_character" function has a bug: "operands could not be broadcast together with shapes (20,0) (20,29) (20,0)". This is the code I use to test the dataloader:

def load_data():
    dataloader = DataLoader('train')
    trainloader = torch.utils.data.DataLoader(dataloader, batch_size=1, shuffle=True, num_workers=8)
    for batch_idx, (image, weight, target, weight_affinity, target_affinity) in enumerate(trainloader):
        print(batch_idx, ' -- ', image.shape, '--', weight.shape, '--', weight_affinity.shape)

tinhchuquang avatar Jun 27 '19 08:06 tinhchuquang

Sometimes the coordinates of the character boxes fall outside the image dimensions, and that is what triggers this error. I will try to update the code to handle the out-of-image case; until then, you can add an if/else block to discard character bboxes that have values greater than the image dimensions or less than 0.

mayank-git-hub avatar Jun 27 '19 11:06 mayank-git-hub

if np.any(bbox < 0) or np.any(bbox[:, 0] > image.shape[1]) or np.any(bbox[:, 1] > image.shape[0]):
	return image

You can add these lines to the add_character function.

mayank-git-hub avatar Jun 27 '19 12:06 mayank-git-hub

def add_character(self, image, bbox):

		top_left = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(np.int32)
		if top_left[1] > image.shape[0] or top_left[0] > image.shape[1]:
			# This means there is some bug in the character bbox
			# Will have to look into more depth to understand this
			return image
		bbox -= top_left[None, :]
		transformed = four_point_transform(self.gaussian_heatmap.copy(), bbox.astype(np.float32))

		start_row = max(top_left[1], 0) - top_left[1]
		start_col = max(top_left[0], 0) - top_left[0]
		end_row = min(top_left[1]+transformed.shape[0], image.shape[0])
		end_col = min(top_left[0]+transformed.shape[1], image.shape[1])

		image[max(top_left[1], 0):end_row, max(top_left[0], 0):end_col] += transformed[start_row:end_row - top_left[1], start_col:end_col - top_left[0]]

		return image

I have made these changes to the code and no longer get the error. I hope this resolves your error too.

mayank-git-hub avatar Jun 27 '19 12:06 mayank-git-hub

Thanks a lot. The code that works for me is:

def add_character(self, image, bbox):
    if np.any(bbox < 0) or np.any(bbox[:, 0] > image.shape[1]) or np.any(bbox[:, 1] > image.shape[0]):
        return image

    top_left = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(np.int32)
    bbox -= top_left[None, :]
    transformed = four_point_transform(self.gaussian_heatmap.copy(), bbox.astype(np.float32))

    start_row = max(top_left[1], 0) - top_left[1]
    start_col = max(top_left[0], 0) - top_left[0]
    end_row = min(top_left[1] + transformed.shape[0], image.shape[0])
    end_col = min(top_left[0] + transformed.shape[1], image.shape[1])

    image[max(top_left[1], 0):end_row, max(top_left[0], 0):end_col] += transformed[start_row:end_row - top_left[1],
                                                                       start_col:end_col - top_left[0]]

    return image

Based on your Gaussian heatmap, I am implementing the paper. Thanks very much.

tinhchuquang avatar Jul 03 '19 10:07 tinhchuquang

Thanks, @mayank-git-hub for your explanation of Gaussian heatmap.

One comment about the SynthText dataset: some transcriptions are incorrectly labeled, so a word box is not guaranteed to correspond one-to-one to a text transcription. Using the code below to split the transcriptions, you will obtain the start and end points of the transcription corresponding to each word box.

import re
import itertools

texts = [re.split(' \n|\n |\n| ',t.strip()) for t in texts]
texts = list(itertools.chain(*texts))
texts = [t for t in texts if len(t)>0]

YoungminBaek avatar Jul 07 '19 13:07 YoungminBaek

Welcome @YoungminBaek , @tinhchuquang !

@YoungminBaek isn't my code for creating word start and end points functionally the same as yours? (Though yours seems a bit cleaner!)

for no, i in enumerate(self.txt):
	all_words = []
	for j in i:
		all_words += [k for k in ' '.join(j.split('\n')).split() if k!='']
	self.txt[no] = all_words
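A quick way to check this is to run both snippets on the same input. The sketch below wraps each one in a function (the names `split_regex` and `split_join` are mine) and feeds them a few made-up entries in the style of SynthText's 'txt' field, where one entry can hold several words separated by spaces and newlines and padded with trailing spaces.

```python
import itertools
import re


def split_regex(texts):
    # The regex-based version
    texts = [re.split(' \n|\n |\n| ', t.strip()) for t in texts]
    texts = list(itertools.chain(*texts))
    return [t for t in texts if len(t) > 0]


def split_join(texts):
    # The join/split version from the dataloader
    all_words = []
    for j in texts:
        all_words += [k for k in ' '.join(j.split('\n')).split() if k != '']
    return all_words


# Made-up transcription entries, not taken from the real dataset
sample = ['Lines:\nI lost\nKevin ', 'will   ', 'line\nand  ']
print(split_regex(sample))
print(split_join(sample))
```

Both calls print ['Lines:', 'I', 'lost', 'Kevin', 'will', 'line', 'and'] for this sample.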

mayank-git-hub avatar Jul 07 '19 13:07 mayank-git-hub

@mayank-git-hub Oh, your code already has that functionality. I missed it. Please forget about my previous comment. :)

YoungminBaek avatar Jul 12 '19 15:07 YoungminBaek

Hi! When there is more than one character in an image, is each character bounding box in mat['charBB'] stored as its own [xmin, ymin, xmax, ymax] vector, or are all the characters in the image stored together in one vector? In other words, is one row of mat['charBB'][0] [xmin, ymin, xmax, ymax] or [xmin_char1, ..., ymax_char1, xmin_char2, ..., ymax_char2, ...]? If it is the latter, how do I concatenate multiple vectors that differ in length?

namedysx avatar Jul 29 '19 01:07 namedysx

I want to transform my dataset from .txt format to .mat

namedysx avatar Jul 29 '19 01:07 namedysx

@namedysx Do you want to create a .mat file like the one in the SynthText dataset, where mat['charBB'][0] has shape [2, 4, 6] and mat['charBB'][1] has shape [2, 4, 10]?

tinhchuquang avatar Jul 29 '19 04:07 tinhchuquang

Thanks for your reply. Yes, and what does each channel mean?

namedysx avatar Jul 30 '19 00:07 namedysx

My way is to create an array whose entries are matrices of different shapes:

import numpy as np

def create_word():
    char_bb = []
    length_word = np.random.randint(1, 4)  # random word length (1 to 3 characters)
    for i in range(length_word):
        # add the 4 corner points of one character
        for _ in range(4):
            char_bb.append([np.random.randint(1, 255), np.random.randint(1, 255)])
    return np.array(char_bb)

words = [create_word(), create_word()]

# dtype=object is required because the entries can have different shapes
char_bbs = np.empty((1, len(words)), dtype=object)  # shape = (1, number of entries)
for i, word in enumerate(words):
    char_bbs[0, i] = word.reshape(-1, 4, 2).transpose(2, 1, 0)  # -> (2, 4, num_chars)

print(char_bbs.shape)
print(char_bbs[0][0].shape, char_bbs[0][1].shape)
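As for what each axis means: judging from the dataloader code earlier in this thread (which transposes each entry with transpose(2, 1, 0) and then reads bbox[:, 0] as x and bbox[:, 1] as y), an entry of shape [2, 4, num_characters] holds x/y on the first axis, the four corners on the second, and the characters on the third. A small sketch with made-up coordinates:

```python
import numpy as np

# Two hypothetical characters, each given as 4 (x, y) corners
chars = np.array([
    [[10, 20], [30, 20], [30, 50], [10, 50]],
    [[35, 20], [55, 20], [55, 50], [35, 50]],
], dtype=np.float32)  # shape (num_chars, 4, 2)

# SynthText-style layout: (2, 4, num_chars)
char_bb = chars.transpose(2, 1, 0)
print(char_bb.shape)

xs = char_bb[0]  # x of every corner, shape (4, num_chars)
ys = char_bb[1]  # y of every corner, shape (4, num_chars)

# Back to the 4 (x, y) corners of character 0
first_char = char_bb[:, :, 0].T
print(first_char)
```

Because the last axis is the character index, entries for different images can have different lengths without any padding.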

tinhchuquang avatar Jul 30 '19 02:07 tinhchuquang

What are the max and min values of the output Gaussian maps, score text and score link?

brooklyn1900 avatar Aug 15 '19 08:08 brooklyn1900

I want to transform my dataset from .txt format to .mat

Hi, can you share your code for converting txt to mat? Thanks a lot!

jjprincess avatar Aug 21 '19 03:08 jjprincess

@mayank-git-hub
How appropriate is the value of spread? I think 1 is the most similar to the picture in the paper. what do you think..? or Is there a way to get the normal distribution tightly in the box?

hanish3464 avatar Aug 22 '19 05:08 hanish3464

I thought that since the affinity and character bboxes should overlap, the spread should be large so that there is less chance of a word being broken. That is why I kept the spread large. I am implementing the CRAFT training procedure, so I will try out smaller values of spread and check the results. Thanks for pointing out that the spread can also be a hyper-parameter to consider.

In theory the normal distribution spreads out infinitely, but in practice floating point limits cut it off. You can play with the spread and configure it to get the edge values you want.
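To get a feel for the edge values: the template is cut off `spread` standard deviations from its center, so once the peak is normalised to 1, the value at the middle of each edge is roughly exp(-spread^2 / 2). A rough sketch (it ignores the half-pixel offset used in the template code):

```python
import math

# Approximate value at the middle of the template edge, relative to the peak
for spread in (1, 1.5, 2, 3):
    edge = math.exp(-0.5 * spread ** 2)
    print(f"spread={spread}: edge value ~ {edge:.3f} of the peak")
```

With spread = 3 the edge is about 1% of the peak, while with spread = 1 it is still about 61%, so a smaller spread gives a flatter response that stays high right up to the box boundary.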

mayank-git-hub avatar Aug 22 '19 05:08 mayank-git-hub

Also, the code I gave above (https://github.com/clovaai/CRAFT-pytorch/issues/3#issuecomment-505903264) breaks down if the character quadrilateral is not valid. To check, you can use:

from shapely.geometry import Polygon
Polygon(bbox.reshape([4, 2]).astype(np.int32)).is_valid

If this is False, there is no need to add that character bbox.

mayank-git-hub avatar Aug 22 '19 06:08 mayank-git-hub

@mayank-git-hub How appropriate is the value of spread? I think 1 is the most similar to the picture in the paper. what do you think..? or Is there a way to get the normal distribution tightly in the box?

import numpy as np

sigma = 10
spread = 3
extent = int(spread * sigma)
center = spread * sigma / 2
gaussian_heatmap = np.zeros([extent, extent], dtype=np.float32)

for i_ in range(extent):
    for j_ in range(extent):
        gaussian_heatmap[i_, j_] = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
            -1 / 2 * ((i_ - center - 0.5) ** 2 + (j_ - center - 0.5) ** 2) / (sigma ** 2))

gaussian_heatmap = (gaussian_heatmap / np.max(gaussian_heatmap) * 255).astype(np.uint8)

This seems to work well for me.
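The double loop can also be replaced by broadcasting. This is a sketch under the same settings (sigma = 10, spread = 3); the function name `gaussian_template` is mine, and the result should match the loop version up to floating point rounding.

```python
import numpy as np


def gaussian_template(sigma=10, spread=3):
    """Square isotropic Gaussian template with its peak scaled to 255 (uint8)."""
    extent = int(spread * sigma)
    center = spread * sigma / 2
    # Same offsets as the loop version: index - center - 0.5
    offsets = np.arange(extent, dtype=np.float32) - center - 0.5
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    heatmap = np.exp(-0.5 * sq_dist / sigma ** 2)
    # The 1 / (2 * pi * sigma**2) factor cancels when dividing by the maximum
    return (heatmap / heatmap.max() * 255).astype(np.uint8)


template = gaussian_template()
print(template.shape, template.min(), template.max())
```

Since the template is built only once and then warped for every character, the speed-up matters mostly when experimenting with many sigma and spread values.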

mayank-git-hub avatar Aug 24 '19 12:08 mayank-git-hub

Hi @mayank-git-hub, this Gaussian heatmap is used to generate the region score and the affinity score, and we use it for training, right? But as far as I know, the Gaussian heatmap is the same for every character and differs only in the size of each character's annotation, so why should we use it for training?

I use this code to generate the heatmap; in this example, one image is one character:

import cv2
import numpy as np
from math import exp
import matplotlib.pyplot as plt

# Probability as a function of distance from the center derived
# from a gaussian distribution with mean = 0 and stdv = 1
scaledGaussian = lambda x : exp(-(1/2)*(x**2))

image = cv2.imread('./test_imgs/3.jpg')
h,w,c = image.shape
isotropicGrayscaleImage = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
# isotropicGrayscaleImage = np.zeros((imgSize,imgSize),np.uint8)

for i in range(h):
  for j in range(w):

    # find euclidian distance from center of image (h/2,w/2)
    # and scale it to range of 0 to 2.5 as scaled Gaussian
    # returns highest probability for x=0 and approximately
    # zero probability for x > 2.5

    distanceFromCenter = np.linalg.norm(np.array([i-h/2,j-w/2]))
    distanceFromCenter = 2.5*distanceFromCenter/(h/2)
    scaledGaussianProb = scaledGaussian(distanceFromCenter)
    isotropicGrayscaleImage[i,j] = np.clip(scaledGaussianProb*255,0,255)

# Convert Grayscale to HeatMap Using Opencv
isotropicGaussianHeatmapImage = cv2.applyColorMap(isotropicGrayscaleImage,
                                                  cv2.COLORMAP_JET)

plt.imshow(cv2.cvtColor(isotropicGaussianHeatmapImage, cv2.COLOR_BGR2RGB))
plt.show()

uname0x96 avatar Sep 01 '19 11:09 uname0x96

You are assuming that the character bbox would always be horizontal which is not the case.

Also, if you create a new Gaussian heatmap every time for each character, it will be very computationally costly.

So, to generate a skewed Gaussian heatmap with less computation time, you can generate a template once and apply a perspective transformation to it, as the authors mention in the paper.

mayank-git-hub avatar Sep 01 '19 11:09 mayank-git-hub

@mayank-git-hub I have a question about four_point_transform in your code. The paper warps an isotropic 2D Gaussian into the skewed box, but your code seems to warp the isotropic 2D Gaussian after making the box isotropic. Please let me know if I have misunderstood.

hanish3464 avatar Sep 02 '19 02:09 hanish3464

@hanish3464 You could try this explanation of the perspective transform.

https://www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/

They are trying to bring a skewed image back to horizontal rectangle using four_point_transform, while I am doing the opposite.

mayank-git-hub avatar Sep 02 '19 04:09 mayank-git-hub

Does that mean you are using polygon annotations instead of rectangles?

uname0x96 avatar Sep 03 '19 09:09 uname0x96

Quadrilateral instead of rectangle, not polygon

mayank-git-hub avatar Sep 04 '19 10:09 mayank-git-hub

@mayank-git-hub Hmm, is the flow you follow to export the heatmap as below?

Step 1: Read the annotation and get the position of each char.
Step 2: Crop the image using the points from step 1 and deskew the crop to a horizontal rectangle.
Step 3: Convert the image from step 2 to a Gaussian heatmap.
Step 4: Warp the heatmap from the horizontal rectangle back to the original quadrilateral shape.
Step 5: Put the result of step 4 back into the original image.

Is that right?

uname0x96 avatar Sep 04 '19 12:09 uname0x96

Umm, not quite; your step 2 seems to be unnecessary.

  1. Create an isotropic square Gaussian heatmap.
  2. Read the annotation and get the position of each char.
  3. Skew the square Gaussian heatmap to the quadrilateral annotation from step 2.
  4. Add the output of step 3 to the target annotations.

mayank-git-hub avatar Sep 04 '19 16:09 mayank-git-hub