rl
rl copied to clipboard
[Feature Request] ActionDiscretizer custom sampler
Motivation
Currently you cannot implement a custom sampling technique for the ActionDiscretizer transform.
Solution
Move `custom_arange` out of `transform_input_spec` and make it a method of `ActionDiscretizer`. Subclasses or wrappers of `ActionDiscretizer` can then override that method to implement a custom sampling technique.
Alternatives
Additional context
class SamplingStrategy(IntEnum):
    """How a continuous sub-interval is mapped to its discrete representative value.

    Used by ``_custom_arange`` to pick which point of each interval
    represents the discretized action.
    """

    MEDIAN = 0  # midpoint of each interval
    LOW = 1  # lower edge of each interval
    HIGH = 2  # upper edge of each interval
    # NOTE(review): ``_custom_arange`` returns LOW-style edges for RANDOM;
    # the randomization presumably happens elsewhere — confirm against callers.
    RANDOM = 3
def __init__(
    self,
    num_intervals: int | torch.Tensor,
    action_key: NestedKey = "action",
    out_action_key: NestedKey = None,
    sampling=None,
    categorical: bool = True,
):
    """Set up the discretizer transform.

    Args:
        num_intervals: number of discretization bins, either a single int
            shared by all action dims or a tensor of per-dim counts.
        action_key: key of the continuous action to discretize.
        out_action_key: key under which the discretized action is written;
            defaults to ``action_key`` (in-place replacement).
        sampling: a ``SamplingStrategy`` member; defaults to ``MEDIAN``.
        categorical: whether the output action is categorical.
    """
    # Writing back to the same key is the default behavior.
    out_action_key = action_key if out_action_key is None else out_action_key
    super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
    self.action_key = action_key
    self.out_action_key = out_action_key
    # Tensor-valued counts are registered as a buffer so they follow the
    # module across devices; a plain int is kept as a regular attribute.
    if isinstance(num_intervals, torch.Tensor):
        self.register_buffer("num_intervals", num_intervals)
    else:
        self.num_intervals = num_intervals
    self.sampling = self.SamplingStrategy.MEDIAN if sampling is None else sampling
    self.categorical = categorical
def __repr__(self):
    """Return a multi-line representation listing every configuration field."""
    fields = (
        f"num_intervals={self.num_intervals}",
        f"action_key={self.action_key}",
        f"out_action_key={self.out_action_key}",
        f"sampling={self.sampling}",
        f"categorical={self.categorical}",
    )
    # Each field on its own line, indented one level under the class name.
    body = ",\n".join(indent(field, 4 * " ") for field in fields)
    return f"{type(self).__name__}(\n{body})"
def _custom_arange(self, nint, device):
    """Return ``nint`` representative points in [0, 1), one per interval.

    The unit range is split into ``nint`` equal bins; which point of each
    bin is returned depends on ``self.sampling``:

    - ``LOW`` (and ``RANDOM``): the lower edges ``0, 1/nint, ...``.
    - ``HIGH``: the upper edges ``1/nint, ..., 1``.
    - ``MEDIAN``: the bin midpoints.
    """
    # NOTE(review): a float step can occasionally yield an extra element
    # due to rounding in torch.arange — confirm nint is always honored.
    low_edges = torch.arange(
        start=0.0,
        end=1.0,
        step=1 / nint,
        dtype=self.dtype,
        device=device,
    )
    if self.sampling == self.SamplingStrategy.MEDIAN:
        # Upper edges are the mirrored lower edges; averaging gives midpoints.
        high_edges = (1 - low_edges).flip(0)
        return (low_edges + high_edges) / 2
    if self.sampling == self.SamplingStrategy.HIGH:
        return (1 - low_edges).flip(0)
    # LOW and RANDOM both use the lower edges here.
    return low_edges
def transform_input_spec(self, input_spec):
    # Build the per-dimension discretization grids (``self.intervals``)
    # from the bounded continuous action spec.
    # NOTE(review): this definition is truncated in the excerpt — the
    # ``try`` below has no visible ``except``/``finally``; comments cover
    # only the visible portion.
    try:
        action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
        # Only bounded (low/high) continuous specs can be discretized.
        if not isinstance(action_spec, Bounded):
            raise TypeError(
                f"action spec type {type(action_spec)} is not supported."
            )
        # Number of action dimensions: last axis of the spec shape,
        # or 1 for a scalar (empty-shape) spec.
        n_act = action_spec.shape
        if not n_act:
            n_act = 1
        else:
            n_act = n_act[-1]
        self.n_act = n_act
        self.dtype = action_spec.dtype
        # Width of each action dimension; trailing singleton dim added so it
        # broadcasts against the per-interval grid below.
        interval = (action_spec.high - action_spec.low).unsqueeze(-1)
        num_intervals = self.num_intervals
        if isinstance(num_intervals, int):
            # Uniform bin count: a single (n_act, num_intervals) grid,
            # scaled by each dim's width and shifted by its lower bound.
            arange = (
                self._custom_arange(num_intervals, action_spec.device).expand(
                    n_act, num_intervals
                )
                * interval
            )
            self.register_buffer(
                "intervals", action_spec.low.unsqueeze(-1) + arange
            )
        else:
            # Per-dimension bin counts: one 1D grid per action dim; kept as a
            # plain list (ragged lengths cannot form a single buffer).
            arange = [
                self._custom_arange(_num_intervals, action_spec.device) * interval
                for _num_intervals, interval in zip(
                    num_intervals.tolist(), interval.unbind(-2)
                )
            ]
            self.intervals = [
                low + arange
                for low, arange in zip(
                    action_spec.low.unsqueeze(-1).unbind(-2), arange
                )
            ]
Checklist
- [X] I have checked that there is no similar issue in the repo (required)