[WIP] batching in MultiDiracDeterminant::mw_accept_rejectMove
Proposed changes
Restructure MultiDiracDeterminant::mw_accept_rejectMove to handle all accepted moves together and all rejected moves together (minimize branching).
A first pass through all walker moves is done to get a vector of accepted walker indices and a vector of rejected walker indices, then the necessary data movement is done first for all accepted moves and then for all rejected moves.
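The first pass described above can be sketched as follows (illustrative C++ only; `partitionMoves` and its argument names are hypothetical, not the actual QMCPACK signature):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch of the first pass: split walker indices into
// accepted and rejected lists so that all subsequent data movement can be
// done per-group, without branching inside the copy loops.
inline void partitionMoves(const std::vector<bool>& isAccepted,
                           std::vector<std::size_t>& acceptedIndices,
                           std::vector<std::size_t>& rejectedIndices)
{
  acceptedIndices.clear();
  rejectedIndices.clear();
  for (std::size_t iw = 0; iw < isAccepted.size(); ++iw)
    (isAccepted[iw] ? acceptedIndices : rejectedIndices).push_back(iw);
}
```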
What type(s) of changes does this code introduce?
- Refactoring (no functional changes, no API changes)
Does this introduce a breaking change?
- Yes (see TODO below for list of things that still need to be figured out)
What systems has this change been tested on?
Checklist
- No. This PR is up to date with the current state of 'develop'
- Yes. Code added or changed in the PR has been clang-formatted
- N/A. This PR adds tests to cover any new code, or to catch a bug that is being fixed
- N/A. Documentation has been added (if appropriate)
TODO:
- [ ] figure out data location: is it safe to assume data is up-to-date on the device before this is called? Does anything need to be updated on the host before this returns? Do we need separate code depending on the value of `ENABLE_OFFLOAD`?
There are three different cases for data members of MultiDiracDeterminant that are affected by mw_accept_rejectMove:
The ones marked with (*) are updated on the device in the current code (in MultiDiracDeterminant::acceptMove)
- already included in the MW resource
- psiMinv_temp
- psiV
- dpsiV
- (*) psiMinv
- (*) TpsiM
- (*) psiM
- (*) dpsiM
- not included in the MW resource but are dual-space allocated
- new_ratios_to_ref_
- (*) ratios_to_ref_
- d2psiV
- d2psiM
- not included in the MW resource and not dual-space allocated
- dspin_psiV
- dspin_psiM
- new_grads
- grads
- new_lapls
- lapls
- new_spingrads
- spingrads
In this PR, I currently have it implemented so that anything that is in the MW resource or handled by a dual-space allocator is updated on the device only. Anything that is not dual-space allocated is only updated on the host.
Currently, there is no separation based on anything like `ENABLE_OFFLOAD` or `Phi->isOMPoffload()`, but I can add that if needed; I'm not sure what the state of the data is when this function is called, or what it's supposed to be when this function returns. For now, I've assumed that any dual-space data is up-to-date on the device when this function is called, and I update it on the device before returning; anything that is not dual-space allocated is assumed to be up-to-date on the host when this function is called, and I update it on the host before returning.
If the platform is added as a template parameter, at least the copy_batched calls should be handled cleanly by that.
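One possible shape for that dispatch (purely hypothetical; the names `Platform` and `BatchedCopy` are illustrative and QMCPACK's actual platform dispatch machinery may differ) is a template parameter selecting the copy implementation at compile time:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical sketch: select the batched-copy implementation via a
// platform template parameter. Only a host fallback is shown here; a
// device specialization would call the vendor batched copy instead.
enum class Platform { HOST, DEVICE };

template <Platform P>
struct BatchedCopy;

template <>
struct BatchedCopy<Platform::HOST>
{
  // Host fallback: one plain copy per walker.
  static void apply(const std::vector<const double*>& src,
                    const std::vector<double*>& dst,
                    std::size_t n)
  {
    for (std::size_t i = 0; i < src.size(); ++i)
      std::copy(src[i], src[i] + n, dst[i]);
  }
};
```

A `BatchedCopy<Platform::DEVICE>` specialization could then wrap the device `copy_batched` call without any branching at the call site.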
I assume the pointers from device_data() and from the mw resource collection deviceptr_lists should also be correct already (i.e. if compiled without offload, device_data() and data() should both return the same (host) pointer), but it would be good to have verification from someone more familiar with those parts of the code.
@ye-luo @anbenali if you know what data is supposed to be up-to-date upon entering and exiting this function, I'd appreciate any feedback (if not, then I'll start looking at what happens upstream and downstream from here and check with you to see if it makes sense)
I can also add any of the data in the list above to the resource collection if it makes sense to do that, but that could also be a separate PR.
Test this please
@kgasperich Are you clear on the next steps?
@prckent yes, I talked to Ye a few days ago and I think the next steps are mostly clear. I'm going to resolve some of the data movement issues by more closely mirroring the behavior of the existing accept function, and then performance improvements on top of that (e.g. not doing unnecessary H2D transfers in cases where we can just do a device copy) will be added in a subsequent PR.
I've updated this so that it has the same data movement patterns as the old code, but there are some differences in how it's done compared to the old code.
- I separated accepted and rejected moves so that similar work is all done together (i.e. loop over all accepted then all rejected, rather than loop over all walkers and branch based on accept/reject for each walker)
- I built pointer lists for the data that needs to be copied
- old:
  - loop over walkers; for each walker, do the necessary copies with `std::copy` and do the H2D transfer
- new:
  - loop over accepted walkers, build lists of pointers to the start of each set of data elements (psiMinv_temp, psiMinv, psiV, TpsiM, etc.) that need to be copied
  - loop over accepted walkers, copy with `BLAS::copy` using the pointer lists (can be trivially replaced with `copy_batched` in a subsequent PR once the data movement requirements are figured out)
  - loop over accepted walkers, do the H2D transfer
- similar for rejected walkers
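As a rough illustration of the "new" copy step (not the actual QMCPACK code; the function name and the use of raw `double` pointers are simplifications), the loop over the gathered pointer lists looks like:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative only: perform one contiguous copy per accepted walker using
// previously gathered source/destination pointer lists. Each iteration is
// independent of the others, so this loop is trivially replaceable by a
// single batched copy on the device.
inline void copyFromPointerLists(const std::vector<const double*>& src_list,
                                 const std::vector<double*>& dst_list,
                                 std::size_t n_elements)
{
  for (std::size_t i = 0; i < src_list.size(); ++i)
    std::copy(src_list[i], src_list[i] + n_elements, dst_list[i]);
}
```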
todo in subsequent PR:
- change copy loop to copy_batched (if available on platform?)
- consider pre-allocating some workspace for the accepted/rejected index lists and pointer lists
- look into upstream/downstream providers/users of this data to see if we can get away with fewer data transfers
- if quantities like psiMinv_temp, psiV, dpsiV, new_ratios_to_ref, etc are already up to date on device when this is called, then just do the device copy_batched instead of copy on host followed by H2D transfer
- maybe some opportunity for async copy/copy_batched with H2D transfer? (similar to what's described here)
There is a clear need for more refactoring and optimization of this and related functions, and that is why some things are structured the way that they are. I've made some choices (and left in some comments) that I think will make some of the next steps clearer/easier, but if anyone is opposed to that I can tidy it up more.
Test this please
CI passes. @kgasperich could you update the PR description to reflect what the current code does?
I edited the description to summarize what it does and add some more info that might be useful in future refactoring
Test this please