dcgan.torch
Handle file errors without crashing
Not all file corpora are flawless: sometimes files are empty, the suffix doesn't match the actual format, files get deleted during the run, and so on. Since dcgan.torch assumes every file read will succeed, it crashes if anything is amiss with any one of the thousands or millions of files it may read.
This can be fixed by checking the read-error status and skipping, with a warning message, any image that fails to load. An optional verbose mode would also help find the exact filename of the offending file, since getByClass doesn't propagate its randomly chosen file path upwards.
FeepingCreature provided a patch implementing this in data/dataset.lua, which we have been using without any problems for several days now:
diff --git a/data/dataset.lua b/data/dataset.lua
index 0d39e27..a9d28eb 100644
--- a/data/dataset.lua
+++ b/data/dataset.lua
@@ -232,7 +232,6 @@ function dataset:__init(...)
end
runningIndex = runningIndex + length
end
-
--==========================================================================
-- clean up temporary files
print('Cleaning up temporary files')
@@ -313,6 +312,7 @@ end
function dataset:getByClass(class)
local index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
+ if self.verbose then print('Image path: ' .. imgpath) end
return self:sampleHookTrain(imgpath)
end
@@ -322,7 +322,7 @@ local function tableToOutput(self, dataTable, scalarTable)
local quantity = #scalarTable
assert(dataTable[1]:dim() == 3)
data = torch.Tensor(quantity,
- self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
+ self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
scalarLabels = torch.LongTensor(quantity):fill(-1111)
for i=1,#dataTable do
data[i]:copy(dataTable[i])
@@ -336,11 +336,15 @@ function dataset:sample(quantity)
assert(quantity)
local dataTable = {}
local scalarTable = {}
- for i=1,quantity do
+ while table.getn(dataTable)<quantity do
local class = torch.random(1, #self.classes)
- local out = self:getByClass(class)
- table.insert(dataTable, out)
- table.insert(scalarTable, class)
+ local success, out = pcall(function() return self:getByClass(class) end)
+ if success then
+ table.insert(dataTable, out)
+ table.insert(scalarTable, class)
+ else
+ print("failed to get an instance of "..class)
+ end
end
local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
return data, scalarLabels