lightning-flash
refactor/fix default_uncollate
What does this PR do?
I found a problem when the metadata includes `size` as a tuple, which is used as an argument for upscaling predictions back to images after image segmentation. I started playing around with a fix, though I do not claim this is the right solution... :)
>>> import torch
>>> from pprint import pprint
>>> from flash.core.data.batch import default_uncollate
>>> batch = {"input": torch.zeros([5, 3, 224, 224]), "target": torch.zeros([5, 3, 224, 224]),
... "metadata": {
... 'size': [torch.tensor([266, 266, 266, 266, 266]), torch.tensor([266, 266, 266, 266, 266])],
... 'height': torch.tensor([266, 266, 266, 266, 266]),
... 'width': torch.tensor([266, 266, 266, 266, 266])
... }}
>>> bbatch = default_uncollate(batch)
>>> len(bbatch)
5
>>> print(bbatch[0].keys())
dict_keys(['input', 'target', 'metadata'])
>>> print(bbatch[0]["input"].size(), bbatch[0]["target"].size())
torch.Size([3, 224, 224]) torch.Size([3, 224, 224])
>>> pprint(bbatch[0]["metadata"])
{'height': tensor(266),
'size': (tensor(266), tensor(266)),
'width': tensor(266)}
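To make the expected behavior concrete, here is a hypothetical, simplified sketch of an uncollate that supports nested metadata like the example above (`uncollate_sketch` is an illustrative name, not the actual `default_uncollate` implementation in `flash/core/data/batch.py`): tensors are split along the batch dimension, dicts recurse per key, and lists/tuples of batched tensors are zipped into per-sample tuples.

```python
import torch


def uncollate_sketch(batch):
    """Hypothetical sketch (not the real default_uncollate): split a batch
    of size B into a list of B per-sample structures."""
    if isinstance(batch, dict):
        # Uncollate each field, then rebuild one dict per sample.
        per_key = {k: uncollate_sketch(v) for k, v in batch.items()}
        lengths = {len(v) for v in per_key.values()}
        assert len(lengths) == 1, "all fields must share the batch dimension"
        n = lengths.pop()
        return [{k: v[i] for k, v in per_key.items()} for i in range(n)]
    if isinstance(batch, (list, tuple)):
        # e.g. 'size': [heights, widths] -> per-sample (height_i, width_i)
        return list(zip(*[uncollate_sketch(v) for v in batch]))
    if isinstance(batch, torch.Tensor):
        # Iterating a tensor splits it along dim 0.
        return list(batch)
    raise ValueError(f"unsupported batch type: {type(batch)}")
```

Running this sketch on the batch from the example yields 5 samples, each with an `input` of shape `[3, 224, 224]` and a `metadata['size']` that is a per-sample tuple of two scalar tensors, matching the output shown above.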
Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [x] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests? [not needed for typos/docs]
- [ ] Did you verify new and existing tests pass locally with your changes?
- [ ] If you made a notable change (that affects users), did you update the CHANGELOG?
PR review
- [ ] Is this pull request ready for review? (if not, please submit in draft mode)
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
Codecov Report
Merging #1347 (96a8cf0) into master (e7adb6a) will not change coverage. The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master    #1347   +/-   ##
=======================================
  Coverage   92.28%   92.28%
=======================================
  Files         287      287
  Lines       13050    13050
=======================================
  Hits        12043    12043
  Misses       1007     1007
| Flag | Coverage Δ | |
|---|---|---|
| pytest | 12.12% <33.33%> (ø) | |
| tpu | 12.12% <33.33%> (ø) | |
| unittests | 92.83% <100.00%> (ø) | |
Flags with carried forward coverage won't be shown.
| Impacted Files | Coverage Δ | |
|---|---|---|
| flash/core/data/batch.py | 96.55% <100.00%> (ø) |
The recent commits should fix the tests (if not, I'll fix them further). I also improved the error message so that the raised ValueError reports the type of the input:
>>> default_uncollate({"preds"})
# error raised from:
#     if isinstance(batch, (list, tuple, Tensor)):
#         return list(batch)
#     raise ValueError(
#         "The batch of outputs to be uncollated is expected to be a `dict` or list-like "
#         f"(e.g. `torch.Tensor`, `list`, `tuple`, etc.), but got input of type: {type(batch)}"
#     )
ValueError: The batch of outputs to be uncollated is expected to be a `dict` or list-like (e.g. `torch.Tensor`, `list`, `tuple`, etc.), but got input of type: <class 'set'>
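To double-check the new guard in isolation, here is a standalone copy of the same validation logic (the `check_uncollate_input` name is made up for illustration; in the PR this check lives inside `default_uncollate`). A `set` is rejected because it is unordered and cannot be split into per-sample entries.

```python
from torch import Tensor


def check_uncollate_input(batch):
    """Standalone sketch of the guard added in this PR: accept dicts and
    list-like batches, reject anything else with a message naming the type."""
    if isinstance(batch, dict):
        return batch
    if isinstance(batch, (list, tuple, Tensor)):
        return list(batch)
    raise ValueError(
        "The batch of outputs to be uncollated is expected to be a `dict` or list-like "
        f"(e.g. `torch.Tensor`, `list`, `tuple`, etc.), but got input of type: {type(batch)}"
    )
```

Passing `{"preds"}` (a set literal, not a dict) raises a ValueError whose message ends with `<class 'set'>`, as shown above.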