Why 4D states in RL interface?

Open terryfrankcombe opened this issue 3 months ago • 20 comments

Firstly, is there an alternative forum for general support? Calling the below a code development "issue" sounds like a category error to me, but I don't see anywhere else to go.

Anyway, I saw somewhere else that only limited state and action dimensions were implemented for RL. Looking in src/fsrc/torchfort_m.F90 I see that there are only two: 4D/4D, and 4D/2D. I feel like I am missing something very basic here, but why these combinations? E.g. why isn't there at least a 1D/1D serialised interface?

terryfrankcombe avatar Sep 12 '25 07:09 terryfrankcombe

Hey Terry! In my experience with TorchFort, the developers have provided only a few combinations of dimensions for state, action, reward and so on. If you want a 1D state and a 1D action, you need to create a specific module procedure, add it under the generic interface, and then call the generic interface directly from your simulation code. The compiler will then automatically select the correct procedure variant for you based on your array dimensions, data type (c_float vs c_double), and target device (CPU or GPU).
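For illustration, a minimal sketch of that pattern for a 1D state / 1D action predict variant (schematic only; the exact argument list and the call into the C binding must be copied from one of the existing 4D specifics in torchfort_m.F90, and the names below are illustrative):

```fortran
! Sketch of a new specific procedure; mirror one of the existing 4D variants
! in torchfort_m.F90 and only change the ranks/kinds of state and act.
function torchfort_rl_off_policy_predict_float_1d_1d(mname, state, act, stream) result(res)
  use, intrinsic :: iso_c_binding, only: c_float, c_int
  character(len=*), intent(in)    :: mname
  real(c_float),    intent(in)    :: state(:)   ! 1D state
  real(c_float),    intent(inout) :: act(:)     ! 1D action
  integer(c_int),   intent(in)    :: stream     ! kind/meaning as in the existing variants
  integer(c_int)                  :: res
  ! ... forward to torchfort_rl_off_policy_predict_c here, passing TORCHFORT_FLOAT
  !     and the array shapes, exactly as the existing specifics do ...
end function torchfort_rl_off_policy_predict_float_1d_1d

! Then register the new specific under the generic interface so that a plain
! call to torchfort_rl_off_policy_predict resolves to it for 1D/1D arguments:
!
!   interface torchfort_rl_off_policy_predict
!     module procedure torchfort_rl_off_policy_predict_float_4d_4d   ! existing (name illustrative)
!     module procedure torchfort_rl_off_policy_predict_float_1d_1d   ! new
!   end interface
```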

SachinBM-CE avatar Sep 12 '25 08:09 SachinBM-CE

My progress is slow.

Inside my torchfort_m.F90 I have extended the torchfort_rl_off_policy_predict interface with a version torchfort_rl_off_policy_predict_double_2d_1d that takes a (real64) 2D environment and returns a (real64) 1D action. torchfort_rl_off_policy_predict_c is called with a TORCHFORT_DOUBLE argument. The code compiles and runs; I'm passing it kind=8 arrays for the environment and action (the latter a 6-element array).

My config.yaml now defines a policy_model as SACMLP type with layer_sizes: [16384, 1024, 512, 6] and a critic_model that is MLP with layer_sizes: [16390, 1024, 1].

When I call torchfort_rl_off_policy_predict it throws an error:

```
terminate called after throwing an instance of 'c10::Error'
  what():  output with shape [6] doesn't match the broadcast shape [2, 6]
Exception raised from mark_resize_outputs at /pytorch/aten/src/ATen/TensorIterator.cpp:1213 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x12d4c81785e8 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x12d4c810d4a2 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: at::TensorIteratorBase::mark_resize_outputs(at::TensorIteratorConfig const&) + 0x225 (0x12d4a9ac21c5 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #3: at::TensorIteratorBase::build(at::TensorIteratorConfig&) + 0x64 (0x12d4a9ac4eb4 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #4: + 0x1a57ff0 (0x12d4a9e57ff0 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #5: at::native::copy_(at::Tensor&, at::Tensor const&, bool) + 0x5c (0x12d4a9e59aac in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
...
```

The 6 in the shapes suggests that this error is something to do with the action. Does this sound like the problem is in my network definition?

terryfrankcombe avatar Sep 13 '25 12:09 terryfrankcombe

As per your comments, I assume your state array is 2D and your action is 1D, and that you have 16384 data samples. Since Fortran is column-major, the input arrays to the predict function should be of shape (number of dimensions, number of samples). Hence:

- Shape of the state array should be (2, 16384)
- Shape of the action array should be (1, 16384)
- Layer sizes for the policy: [state dimension, neurons in layer 1, neurons in layer 2, ..., action dimension]; in your case [2, 1024, 512, 1]
- Layer sizes for the critic: [state dimension + action dimension, neurons in layer 1, ..., 1]; in your case [3, 1024, 1]

The output of the critic is always 1, as it is the Q-value.
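A small Fortran sketch of the shapes described above (names hypothetical; this assumes a matching 2D-state/2D-action specific has been added to the generic predict interface as described earlier):

```fortran
use, intrinsic :: iso_fortran_env, only: real64
real(real64) :: state(2, 16384)   ! (state_dim, n_samples), column-major
real(real64) :: act(1, 16384)     ! (action_dim, n_samples)
integer      :: istat

! Generic call; the compiler picks the specific matching these ranks/kinds.
! (stream argument omitted here; include it if your specific requires one)
istat = torchfort_rl_off_policy_predict("my_policy", state, act)
```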

SachinBM-CE avatar Sep 13 '25 13:09 SachinBM-CE

I'm a numpty, I changed bits of the code and forgot other bits. :-/

So I've simplified again. My state contains 32768 elements, and my action 6. Now the prediction interface in torchfort_m.F90 is function torchfort_rl_off_policy_predict_double_1d_1d(mname, state, act, stream) result(res), with state and act 1D: real(real64) :: state(:), act(:). In the calling routine the state is stored in a 2D array (Nstate, m), but I'm passing a single state W2(:,t), so that should give a contiguous chunk of memory. The six-element action is stored similarly.

I now have policy_model as an SACMLP with layer_sizes: [32768, 1024, 512, 6] and critic_model as an MLP with layer_sizes: [32774, 1024, 1].

I have a new error:

```
terminate called after throwing an instance of 'c10::IndexError'
  what():  Dimension out of range (expected to be in range of [-1, 0], but got 1)
Exception raised from maybe_wrap_dim_slow at /pytorch/c10/core/WrapDimMinimal.cpp:23 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7903d3785e8 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: long c10::detail::maybe_wrap_dim_slow(long, long, bool) + 0x27b (0x7903d35de3b in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: at::meta::structured_cat::meta(c10::IListRef<at::Tensor> const&, long) + 0x13e (0x7901f183a9e in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #3: + 0x298a534 (0x7901fd8a534 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #4: + 0x298a5c0 (0x7901fd8a5c0 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #5: at::_ops::cat::redispatch(c10::DispatchKeySet, c10::IListRef<at::Tensor> const&, long) + 0x7b (0x7901f61bd3b in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #6: + 0x482f6c5 (0x79021c2f6c5 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #7: + 0x4830153 (0x79021c30153 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #8: at::_ops::cat::call(c10::IListRef<at::Tensor> const&, long) + 0x1a7 (0x7901f659047 in /home/terryk/PyTorchEnv/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
frame #9: torchfort::SACMLPModel::forward(std::vector<at::Tensor, std::allocator<at::Tensor> > const&) + 0x6b (0x7903dc7287b in /home/terryk/InstallTrees/TorchFort/build/lib/libtorchfort.so)
frame #10: torchfort::ModelWrapper::forward(std::vector<at::Tensor, std::allocator<at::Tensor> > const&) const + 0x49 (0x7903dc2e689 in /home/terryk/InstallTrees/TorchFort/build/lib/libtorchfort.so)
...
```

Surely it's OK to feed in 1D state vectors? However, when I look at the cart_pole example where state_dim is passed explicitly in the C++ interface, the 4-element state is passed with a state_dim of 2. Do you know why that is?

Incidentally, I haven't found any documentation for SACMLP. Do you know how this differs from a regular MLP network?

terryfrankcombe avatar Sep 14 '25 04:09 terryfrankcombe

Hi @terryfrankcombe, I will take a swing at this but @azrael417 will know a bit more about the RL implementation, specifically the question about the SACMLP implementation. The model code for that can be found here: https://github.com/NVIDIA/TorchFort/blob/master/src/csrc/models/sac_model.cpp

Looking at the error in your most recent comment, the issue is that the input state vectors in our default implementation are being passed into an MLP model which expects standard dimensioning of the arrays. For MLP inputs, the standard array dimensioning for a 1D state would be (batch_size, n_channels, width), also known as "NCW" format. This is a row-major indexing order so the corresponding Fortran array should be sized (width, n_channels, batch_size).

Now, as you point out, the current Fortran RL interface only has functions with 4D input state arrays exposed. This 4D Fortran array corresponds to a "NCHW" (batch_size, n_channels, height, width) formatted MLP input for a 2D problem. @azrael417, maybe we want to add interfaces for states from problems with 1D and 3D spatial dimensions?

The easiest way to address your problem using the existing interfaces is to size your arrays appropriately for a 2D state (i.e. an NCHW-dimensioned input), but set either the height or width of the array to 1, and set batch_size and n_channels appropriately.

Relatedly, if you do end up passing a slice like W2(:,t) as an input, you may need to adjust the indexing to preserve the array dimensions (i.e., W2(:,t:t) will keep the Fortran array 2D, but with a singleton dimension).
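For concreteness, a rough sketch of that workaround for the 32768-element state / 6-element action case, using the existing double-precision 4D-state / 2D-action interface and the array names from your earlier comment (width = 32768, height = 1, channels = 1, batch = 1):

```fortran
use, intrinsic :: iso_fortran_env, only: real64
real(real64) :: state(32768, 1, 1, 1)  ! Fortran (W, H, C, N) <-> row-major NCHW
real(real64) :: act(6, 1)              ! Fortran (action_dim, batch)
integer      :: istat

state(:, 1, 1, 1) = W2(:, t)           ! copy one state sample, keeping all four dims
! stream argument omitted here; pass it if the specific requires one
istat = torchfort_rl_off_policy_predict("my_policy", state, act)
```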

Hope this helps!

romerojosh avatar Sep 15 '25 17:09 romerojosh

Hello, sorry for the delayed reply, I was off for the last few days. I will have a look tomorrow.

azrael417 avatar Sep 15 '25 18:09 azrael417

Thanks @romerojosh for your excellent reply!

terryfrankcombe avatar Sep 15 '25 22:09 terryfrankcombe

Ok, so here is the thing: SACMLP is different from a regular MLP in the sense that it produces a log-variance (or log-std to be more precise) prediction as well. For the SAC algorithm, you need your policy to be inherently noisy. Therefore, in order to work with the tests, I had to implement my own class. Generally, I recommend writing a SAC Policy tailored to the problem.

```cpp
state_dependent_sigma = params.get_param("state_dependent_sigma", true)[0];
double log_sigma_init = params.get_param("log_sigma_init", 0.)[0];
```

Here, the SAC Policy has a sigma parameter which is either predicted from the state (state_dependent_sigma=True) or kept constant (the latter is recommended, because otherwise the policy can become unstable). This mimics the stable-baselines implementation of a SAC-enabled MLP here.

When you look at stable-baselines, the SAC policy returns the mean of the action as well as the log-std. This is important for the algorithm to work, since it will sample from a distribution N(mu, sigma), where mu and sigma are predicted by the policy.
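In equations, a standard SAC policy (as in the stable-baselines implementation mentioned above) outputs both a mean and a log-std and draws a squashed sample:

$$
\mu_\theta(s),\ \log\sigma_\theta(s) = \pi_\theta(s), \qquad
u \sim \mathcal{N}\big(\mu_\theta(s),\, \sigma_\theta(s)^2\big), \qquad
a = \tanh(u)
$$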

Concerning the shape issues: it is a lot of work to implement interfaces for all shape combinations a user might want to use, but I agree it would be useful to have a 1D state and 1D action interface which can be used for all cases.

Note that the SACMLP expects MLP input of shape (batch, features), so your state should be one-dimensional.

azrael417 avatar Sep 16 '25 05:09 azrael417

SAC is not the easiest algorithm to get working properly; may I ask why you are not using something like PPO? As far as I understand, PPO is widely used in the RL community because it is relatively stable without much tuning, whereas SAC is more sensitive to replay buffer sizes and the like. If you want to use off-policy algorithms, does TD3 not work for you?

azrael417 avatar Sep 16 '25 05:09 azrael417

I'm using SAC because I have a one-to-many problem and I want the model to return diverse actions, so the stochastic nature of SAC is attractive. I can quickly evaluate the quality of any action, but there is no single right answer.

Re PPO: I'm a bit confused; the TorchFort docs say TorchFort will only do off-policy RL? Nonetheless, my valid state space isn't terribly large, so off-policy sounds like a better option.

terryfrankcombe avatar Sep 16 '25 09:09 terryfrankcombe

@romerojosh, can I return to this for a moment:

> Looking at the error in your most recent comment, the issue is that the input state vectors in our default implementation are being passed into an MLP model which expects standard dimensioning of the arrays. For MLP inputs, the standard array dimensioning for a 1D state would be (batch_size, n_channels, width), also known as "NCW" format. This is a row-major indexing order so the corresponding Fortran array should be sized (width, n_channels, batch_size).

> Now, as you point out, the current Fortran RL interface only has functions with 4D input state arrays exposed. This 4D Fortran array corresponds to a "NCHW" (batch_size, n_channels, height, width) formatted MLP input for a 2D problem. @azrael417, maybe we want to add interfaces for states from problems with 1D and 3D spatial dimensions?

My data is actually 2D and complex, so 4D with two channels does fit. (Previously I was trying to simplify to get an MWE running.) Then the first dim in the above, batch_size, would be iterating over distinct environment states, with corresponding actions coming out [(batch_size, actions) in the 2D case]? How significant is the ordering of my data? I assume that the above refers to C, so that if I feed it the natural Fortran order W(1:N,1:N,1:2,t:t) for a single state at "time" t the actions will come out of the Fortran interface as (1:n_actions,t:t)?

terryfrankcombe avatar Sep 16 '25 13:09 terryfrankcombe

Using complex states is a tricky thing. In this case, I would write my own action policy to deal with the complex degrees of freedom, either using a neural operator architecture or just using the magnitude of the complex numbers. Concerning ordering: it depends on how you implement the backend network. The SAC MLP expects an input of shape (batch size, number of features) in C ordering, so row major. Since you are using the replay buffer, you only pass one sample at a time, so I assume you would pass a 1D state array (your features) and a 1D action array.
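A minimal Fortran sketch of the two options for the complex field (names hypothetical): either feed the magnitude only, or split real and imaginary parts into two channels (the 4D-with-two-channels layout mentioned above):

```fortran
use, intrinsic :: iso_fortran_env, only: real64
integer, parameter :: N = 64                ! grid size, placeholder
complex(real64)    :: W(N, N)               ! complex 2D field
real(real64)       :: state_mag(N, N)       ! option 1: magnitude only
real(real64)       :: state_ri(N, N, 2)     ! option 2: real/imag as two channels

state_mag = abs(W)
state_ri(:, :, 1) = real(W)
state_ri(:, :, 2) = aimag(W)
```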

azrael417 avatar Sep 16 '25 13:09 azrael417

If you want to use an action policy (and value function) which can handle more complex states, you need to write your own in pytorch, export it, and then use the torch.jit logic to load it into torchfort. The architecture of the backend network depends on what the states look like. If you have a multi-dimensional grid with a feature vector at each site, you could use a multi-dimensional convolutional neural network with kernel size = 1. This mimics applying an MLP at each site. Then you need to think about what your action space should be, and you need to map the input states to the outputs. In the multi-d conv case, you could then use some scale-independent site aggregation method, such as global average pooling, i.e. average the extracted features over the whole grid and feed that to an MLP to produce an action estimate.

azrael417 avatar Sep 16 '25 13:09 azrael417

> How significant is the ordering of my data?

It is very significant. Once your array is handed off to the TorchFort backend, it is wrapped in a C++ torch tensor. We infer the dimensions of the array from Fortran (as a convenience for users), but also know that the dimension ordering is column major. As a result, we invert the dimension ordering to create the equivalent row major view of the data for our libtorch backend.
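A small illustration of that inversion (sizes arbitrary):

```fortran
use, intrinsic :: iso_c_binding, only: c_float
! Fortran array, column-major: (W, H, C, N) = (32, 32, 2, 8)
real(c_float) :: state(32, 32, 2, 8)
! TorchFort infers this shape and reverses the dimension order, so the libtorch
! backend sees a row-major tensor of shape [8, 2, 32, 32], i.e. NCHW.
```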

As for all the comments about required array dimensioning, these are mostly a result of how the internal MLP and SACMLP models are written. With that said, the MLP model used to be fine with 1D input arrays, but a concatenation operation on dim=1 was added at some point to handle multiple input tensors to the MLP model. This has the unfortunate effect of requiring the input tensors to have at least two dimensions. I actually think this concatenation operation is not really the right thing to do here anyway, so I will put together a PR to resolve this and re-enable 1D inputs to the MLP model.

romerojosh avatar Sep 16 '25 15:09 romerojosh

I think I understand where some of the confusion is coming from: the prediction and exploration routines expect a batch dimension, which is why they are called 2D_2D. This goes together with the 1D_1D replay buffer update routine. Here is the catch: you feed tuples (s, a, s', r, d) to the replay buffer (old state, old action, new state, reward earned, terminal flag). In the absence of multiple environments, the state is an N-dimensional array, the action an M-dimensional array, and r and d are always single numbers (since rewards and terminal flags are always scalar valued). This is because you only push one sample at a time (again, with multiple environments you do not do that). However, when you run the policy prediction or state evaluation (basically forward passes through pi and Q), a batch dimension is expected. So predict expects an (N+1)-dimensional state, where the last dimension (in Fortran) is the batch dimension; this can be 1, however. And it produces an (M+1)-dimensional output, the batched action.

I think we can make the naming more consistent, but this is not supposed to be user-facing and is more or less hidden behind the unified interface. However, if you want to use the SAC MLP, what you likely want are the 2D_2D explore and eval functions and the 1D_1D replay buffer update functions.
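Putting that together, a rough Fortran sketch of one interaction step with the SAC MLP (names hypothetical; the replay-buffer update routine name and argument order below are an assumption based on the docs, so check the generic interfaces in torchfort_m.F90):

```fortran
use, intrinsic :: iso_fortran_env, only: real32
integer, parameter :: n_state = 32768, n_action = 6       ! placeholders
! 2D_2D prediction/exploration: trailing (Fortran) dimension is the batch, here 1
real(real32) :: state(n_state, 1), act(n_action, 1)
! 1D_1D replay-buffer update: one (s, a, s', r, d) tuple at a time
real(real32) :: s_old(n_state), a_old(n_action), s_new(n_state), reward
logical      :: final_state
integer      :: istat

istat = torchfort_rl_off_policy_predict("my_policy", state, act)
! Assumed name/argument order for the replay-buffer push; verify against the docs:
! istat = torchfort_rl_off_policy_update_replay_buffer("my_policy", s_old, a_old, &
!                                                      s_new, reward, final_state)
```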

azrael417 avatar Sep 17 '25 05:09 azrael417

I agree that in practice it would be nice to be able to simply feed a 1D array (i.e. a single state sample) and implicitly get back a single action vector. I will think about how to do that without breaking the interface.

Concerning double vs float: a lot of the internal workings of the RL code use float, so some of the values you pass as double will likely be cast to float internally. This should not be a problem though, since especially for these systems with a lot of noise, double should not be better than float. I recommend using float throughout.
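On the Fortran side that just means declaring the buffers you pass to TorchFort in single precision, e.g. (a sketch, names hypothetical):

```fortran
use, intrinsic :: iso_c_binding, only: c_float
! single-precision buffers for TorchFort calls (instead of real64 / kind=8)
real(c_float), allocatable :: state(:,:,:,:)   ! NCHW-style state, Fortran order (W, H, C, N)
real(c_float), allocatable :: act(:,:)         ! (action_dim, batch)
```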

azrael417 avatar Sep 17 '25 05:09 azrael417