rl icon indicating copy to clipboard operation
rl copied to clipboard

[Feature Request] ActionDiscretizer custom sampler

Open oslumbers opened this issue 1 year ago • 0 comments

Motivation

Currently you cannot implement a custom sampling technique for the ActionDiscretizer transform.

Solution

Move `custom_arange` out of `transform_input_spec` and make it a method of `ActionDiscretizer`. Subclasses or wrappers of `ActionDiscretizer` could then override that method to implement a custom sampling strategy.

Alternatives

Additional context

class SamplingStrategy(IntEnum):
        """Controls where, inside each discretization interval, the action value is taken."""

        MEDIAN = 0  # interval midpoint (average of the two edges)
        LOW = 1  # lower interval edge
        HIGH = 2  # upper interval edge
        RANDOM = 3  # presumably a random draw within the interval — handling not shown in this snippet

    def __init__(
        self,
        num_intervals: int | torch.Tensor,
        action_key: NestedKey = "action",
        out_action_key: NestedKey = None,
        sampling=None,
        categorical: bool = True,
    ):
        """Set up the discretizer.

        Args:
            num_intervals: number of bins per action dimension — a single int
                shared by all dimensions, or a tensor with one entry per dim.
            action_key: key of the continuous action in the input tensordict.
            out_action_key: key under which the discrete action is written;
                defaults to ``action_key`` when ``None``.
            sampling: a ``SamplingStrategy`` member; defaults to ``MEDIAN``.
            categorical: whether the discrete action is categorical
                (index-valued) rather than one-hot.
        """
        if out_action_key is None:
            out_action_key = action_key
        super().__init__(in_keys_inv=[action_key], out_keys_inv=[out_action_key])
        self.action_key = action_key
        self.out_action_key = out_action_key
        # Tensor-valued counts are registered as a buffer so they follow the
        # module across devices and state_dict round-trips; plain ints are
        # stored as an ordinary attribute.
        if isinstance(num_intervals, torch.Tensor):
            self.register_buffer("num_intervals", num_intervals)
        else:
            self.num_intervals = num_intervals
        self.sampling = self.SamplingStrategy.MEDIAN if sampling is None else sampling
        self.categorical = categorical

    def __repr__(self):
        """Return a multi-line representation listing the transform's settings."""
        fields = [
            f"num_intervals={self.num_intervals}",
            f"action_key={self.action_key}",
            f"out_action_key={self.out_action_key}",
            f"sampling={self.sampling}",
            f"categorical={self.categorical}",
        ]
        # Each field on its own line, indented four spaces under the class name.
        body = ",\n".join(indent(field, 4 * " ") for field in fields)
        return f"{type(self).__name__}(\n{body})"

    def _custom_arange(self, nint, device):
        """Return ``nint`` offsets in ``[0, 1)`` locating each bin's sample point.

        The placement inside each bin follows ``self.sampling``: LOW keeps the
        lower edge, HIGH the upper edge, MEDIAN the midpoint of the two.
        RANDOM falls through to the lower-edge grid here — any actual
        randomization presumably happens elsewhere (not visible in this snippet).
        """
        low_edges = torch.arange(
            start=0.0,
            end=1.0,
            step=1 / nint,
            dtype=self.dtype,
            device=device,
        )
        # Mirror of the lower edges: upper edge of each bin, in ascending order.
        high_edges = (1 - low_edges).flip(0)
        strategy = self.sampling
        if strategy == self.SamplingStrategy.MEDIAN:
            return (low_edges + high_edges) / 2
        if strategy == self.SamplingStrategy.HIGH:
            return high_edges
        return low_edges

    def transform_input_spec(self, input_spec):
        """Discretize the bounded continuous action spec into per-dim intervals.

        Side effects: caches ``self.n_act`` and ``self.dtype`` from the spec,
        and stores the interval lower bounds either as a registered buffer
        ``intervals`` (scalar ``num_intervals``) or as a plain list of tensors
        (per-dimension ``num_intervals``).

        NOTE(review): this excerpt is truncated — the ``try`` below has no
        visible ``except``/``finally`` and the method continues past the end
        of the snippet; the remainder presumably builds and returns the
        transformed spec.
        """
        try:
            # Only bounded (low/high) continuous action specs can be discretized.
            action_spec = input_spec["full_action_spec", self.in_keys_inv[0]]
            if not isinstance(action_spec, Bounded):
                raise TypeError(
                    f"action spec type {type(action_spec)} is not supported."
                )

            # Number of action dimensions: last axis of the spec's shape,
            # or 1 for a scalar (empty-shape) spec.
            n_act = action_spec.shape
            if not n_act:
                n_act = 1
            else:
                n_act = n_act[-1]
            self.n_act = n_act

            self.dtype = action_spec.dtype
            # Per-dimension width of the action range, shaped (..., n_act, 1)
            # so it broadcasts against the per-bin offsets appended last-dim.
            interval = (action_spec.high - action_spec.low).unsqueeze(-1)

            num_intervals = self.num_intervals

            if isinstance(num_intervals, int):
                # Same bin count for every dimension: one (n_act, num_intervals)
                # grid of offsets, scaled by each dimension's range.
                arange = (
                    self._custom_arange(num_intervals, action_spec.device).expand(
                        n_act, num_intervals
                    )
                    * interval
                )
                self.register_buffer(
                    "intervals", action_spec.low.unsqueeze(-1) + arange
                )
            else:
                # Per-dimension bin counts: ragged grids, so keep a Python list
                # of tensors instead of a single buffer.
                arange = [
                    self._custom_arange(_num_intervals, action_spec.device) * interval
                    for _num_intervals, interval in zip(
                        num_intervals.tolist(), interval.unbind(-2)
                    )
                ]
                self.intervals = [
                    low + arange
                    for low, arange in zip(
                        action_spec.low.unsqueeze(-1).unbind(-2), arange
                    )
                ]

Checklist

  • [X] I have checked that there is no similar issue in the repo (required)

oslumbers avatar Nov 25 '24 15:11 oslumbers