
Extending generated dataclasses with get/set conversion

Open · michaelosthege opened this issue 3 years ago • 3 comments

I have two cases where I'd like to convert field types automatically on creation/get/set.

I understand that this is a little complicated with dataclasses to begin with, but I was hoping that the fields and Message class might enable something. At least a similar idea was mentioned in #253.

Both scenarios for automatic conversion come up when serializing NumPy arrays with protobuf. I found two related projects, but neither of them works for the generic case:

  • https://github.com/telamonian/numpy-protobuf recreates the NumPy array object model, but doesn't work for generic types.
  • https://github.com/josteinbf/numproto has a PyPI release, but it's trivial and incomplete: it stores only the data of 1D arrays, without shape & dtype information.

The following protobuf doesn't cover all generic dtypes either, but it's easy to work with:

syntax = "proto3";
package npproto;


// Represents a NumPy array of arbitrary shape or dtype.
// Note that the array must support the buffer protocol.
message ndarray {
    bytes data = 1;
    string dtype = 2;
    repeated int64 shape = 3;
    repeated int64 strides = 4;
}

With betterproto, we get this generated dataclass:

npproto.py 👇

# Generated by the protocol buffer compiler.  DO NOT EDIT!
# sources: npproto/ndarray.proto
# plugin: python-betterproto
from dataclasses import dataclass
from typing import List

import betterproto
from betterproto.grpc.grpclib_server import ServiceBase


@dataclass(eq=False, repr=False)
class Ndarray(betterproto.Message):
    """
    Represents a NumPy array of arbitrary shape or dtype. Note that the array
    must support the buffer protocol.
    """

    data: bytes = betterproto.bytes_field(1)
    dtype: str = betterproto.string_field(2)
    shape: List[int] = betterproto.int64_field(3)
    strides: List[int] = betterproto.int64_field(4)

I wrote the following code to convert between actual NumPy arrays and Ndarray messages:

utils.py 👇

import numpy
from npproto import Ndarray


def ndarray_from_numpy(arr: numpy.ndarray) -> Ndarray:
    dt = str(arr.dtype)
    if "datetime64" in dt:
        # datetime64 doesn't support the buffer protocol.
        # See https://github.com/numpy/numpy/issues/4983
        # This is a hack that automatically encodes it as int64.
        arr = arr.astype("int64")
    return Ndarray(
        shape=list(arr.shape),
        dtype=dt,
        data=bytes(arr.data),
        strides=list(arr.strides),
    )


def ndarray_to_numpy(nda: Ndarray) -> numpy.ndarray:
    if "datetime64" in nda.dtype:
        # Backwards conversion: The data was stored as int64.
        arr = numpy.ndarray(
            buffer=nda.data,
            shape=nda.shape,
            dtype="int64",
            strides=nda.strides,
        ).astype(nda.dtype)
    else:
        arr = numpy.ndarray(
            buffer=nda.data,
            shape=nda.shape,
            dtype=numpy.dtype(nda.dtype),
            strides=nda.strides,
        )
    return arr

Here are some tests for it:

from datetime import datetime

import numpy
import pytest

import npproto
import utils


class TestUtils:
    @pytest.mark.parametrize(
        "arr",
        [
            numpy.arange(5),
            numpy.random.uniform(size=(2, 3)),
            numpy.array(5),
            numpy.array(["hello", "world"]),
            numpy.array(
                [datetime(2020, 3, 4, 5, 6, 7, 8), datetime(2020, 3, 4, 5, 6, 7, 9)],
                dtype="datetime64[us]",  # without an explicit dtype NumPy creates an object array
            ),
            numpy.array([(1, 2), (3, 2, 1)], dtype=object),
        ],
    )
    def test_conversion(self, arr: numpy.ndarray):
        nda = utils.ndarray_from_numpy(arr)
        enc = bytes(nda)
        dec = npproto.Ndarray().parse(enc)
        assert isinstance(dec.data, bytes)
        result = utils.ndarray_to_numpy(dec)
        numpy.testing.assert_array_equal(result, arr)

This gives two cases where automatic conversion would be neat: tuple ↔ list and numpy.ndarray ↔ Ndarray.

This message has a need for both:

message MyMessage {
    repeated string dimnames = 1;
    npproto.ndarray arr = 2;
}

The generated dataclass:

@dataclass(eq=False, repr=False)
class MyMessage(betterproto.Message):
    dimnames: List[str] = betterproto.string_field(1)
    arr: "npproto.Ndarray" = betterproto.message_field(2)

The dataclass constructor doesn't automatically convert types, so creating a MyMessage with a tuple of dimnames is problematic:

msg = MyMessage(dimnames=("first", "second"))
bytes(msg)

Traceback (most recent call last):
...
  File "...\betterproto\__init__.py", line 772, in __bytes__
    output += _serialize_single(
  File "...\betterproto\__init__.py", line 358, in _serialize_single
    value = _preprocess_single(proto_type, wraps, value)
  File "...\betterproto\__init__.py", line 319, in _preprocess_single
    return encode_varint(value)
  File "...\betterproto\__init__.py", line 297, in encode_varint
    if value < 0:
TypeError: '<' not supported between instances of 'tuple' and 'str'

If the type annotation were Tuple[str], it would have worked.
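For reference, a plain dataclass (not a betterproto.Message) can convert on creation and on assignment with a data descriptor. This is a minimal sketch: `Converting` and `TupleFriendlyMessage` are hypothetical names, and I haven't checked whether the approach coexists with the `betterproto.*_field(...)` defaults and the attribute handling in `betterproto.Message`, so treat it as an assumption rather than a drop-in fix.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional


class Converting:
    """Hypothetical data descriptor: runs `convert` on every assignment."""

    def __init__(self, convert: Callable[[Any], Any], default: Any) -> None:
        self._convert = convert
        # Must be immutable: dataclasses rejects list/dict/set defaults.
        self._default = default

    def __set_name__(self, owner: type, name: str) -> None:
        # Converted values live in a private instance attribute, e.g. "_dimnames".
        self._attr = "_" + name

    def __get__(self, obj: Any, objtype: Optional[type] = None) -> Any:
        if obj is None:
            # Class access: dataclasses uses this value as the field's default.
            return self._default
        return getattr(obj, self._attr)

    def __set__(self, obj: Any, value: Any) -> None:
        # Called by the generated __init__ and by every later assignment.
        setattr(obj, self._attr, self._convert(value))


@dataclass
class TupleFriendlyMessage:
    # Any iterable (e.g. a tuple) is stored as a fresh list.
    dimnames: List[str] = Converting(list, default=())


msg = TupleFriendlyMessage(dimnames=("first", "second"))
print(msg.dimnames)  # ['first', 'second']
msg.dimnames = ("x",)
print(msg.dimnames)  # ['x']
print(TupleFriendlyMessage().dimnames)  # []
```

The default is a tuple because dataclasses refuses mutable defaults; since the generated `__init__` assigns it through `__set__`, every instance still ends up with its own list.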

Likewise, it would be neat if the conversion code above were applied automatically, so that we could do this:

msg = MyMessage(arr=numpy.ones((1, 2)))
encoded = bytes(msg)
decoded = MyMessage().parse(encoded)
assert isinstance(decoded.arr, numpy.ndarray)

Thoughts, @Gobot1234? Would it be feasible to add built-in support for NumPy arrays this way? Maybe via an API that registers conversion functions on the generated message types?

MyMessage.register_converters(get=utils.ndarray_to_numpy, set=utils.ndarray_from_numpy)

michaelosthege · Dec 31 '21 18:12

I like the idea overall, although I think it could be more transparent to type checkers. Have you considered using typing.Annotated to store the converter functions? Then the actual type could be used to annotate the field.

Syntax-wise, it could look like this:

@dataclass(eq=False, repr=False)
class MyMessage(betterproto.Message):
    dimnames: List[str] = betterproto.string_field(1)
    arr: Annotated[
        numpy.ndarray, betterproto.converter(from_array, to_array, npproto.Ndarray)
    ] = betterproto.message_field(2)
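To illustrate how converter metadata stored in `Annotated` can be read back at runtime, here is a standalone sketch using only the standard library (Python 3.9+). `Converter`, `to_wire` and `from_wire` are hypothetical stand-ins for the proposed `betterproto.converter(...)`, and a plain dict plays the role of the serialized message; the annotation keeps the user-facing type, so type checkers see `tuple`.

```python
from dataclasses import dataclass, fields
from typing import Annotated, Any, Callable, Dict, Optional, get_args, get_origin, get_type_hints


@dataclass(frozen=True)
class Converter:
    """Hypothetical marker stored in Annotated metadata (not a betterproto API)."""
    to_wire: Callable[[Any], Any]    # user type -> type the serializer understands
    from_wire: Callable[[Any], Any]  # serializer type -> user type


def find_converter(cls: type, name: str) -> Optional[Converter]:
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    hint = get_type_hints(cls, include_extras=True)[name]
    if get_origin(hint) is Annotated:
        # get_args(Annotated[T, m1, m2]) returns (T, m1, m2).
        for meta in get_args(hint)[1:]:
            if isinstance(meta, Converter):
                return meta
    return None


def to_wire(obj: Any) -> Dict[str, Any]:
    out = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        conv = find_converter(type(obj), f.name)
        out[f.name] = conv.to_wire(value) if conv else value
    return out


def from_wire(cls: type, data: Dict[str, Any]) -> Any:
    kwargs = {}
    for name, value in data.items():
        conv = find_converter(cls, name)
        kwargs[name] = conv.from_wire(value) if conv else value
    return cls(**kwargs)


@dataclass
class AnnotatedMessage:
    # Stored as a tuple; converted to a list only when "serializing".
    dimnames: Annotated[tuple, Converter(to_wire=list, from_wire=tuple)] = ()
    name: str = ""


msg = AnnotatedMessage(dimnames=("first", "second"), name="demo")
wire = to_wire(msg)
print(wire)  # {'dimnames': ['first', 'second'], 'name': 'demo'}
print(from_wire(AnnotatedMessage, wire) == msg)  # True
```

Presumably the same lookup could run inside the (de)serialization path rather than on dicts, but that is an assumption about betterproto's internals.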

Gobot1234 · Dec 31 '21 20:12

Interesting, I didn't know about typing.Annotated. Two questions:

  1. Can the getter/setter of the field be modified to automatically apply a converter from the annotation?
  2. How can we inject the converter functions and tell betterproto to write the Annotation?

I suppose we'd prefer to store values as the user-defined type (e.g. ndarray or tuple) and only apply the converter for (de)serialization?

michaelosthege · Jan 01 '22 11:01

Yes, I think this would be good, as we could actually simplify the code that handles datetime and timedelta. I'd agree with the last comment as well.
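As an example of a converter pair that could replace special-case handling, here is a sketch for `timedelta` ↔ `google.protobuf.Duration`. `DurationSketch` is a hypothetical stand-in with Duration's two fields (`seconds`, `nanos`), not betterproto's generated class. Per the Duration spec, both fields must carry the same sign, and `timedelta` only has microsecond resolution, so sub-microsecond `nanos` are truncated.

```python
from dataclasses import dataclass
from datetime import timedelta


@dataclass
class DurationSketch:
    """Hypothetical stand-in with the two fields of google.protobuf.Duration."""
    seconds: int = 0
    nanos: int = 0  # must carry the same sign as `seconds`


def timedelta_to_duration(td: timedelta) -> DurationSketch:
    # timedelta normalizes to days/seconds/microseconds; work in whole microseconds.
    total_us = (td.days * 86_400 + td.seconds) * 1_000_000 + td.microseconds
    sign = -1 if total_us < 0 else 1
    # Split the magnitude so that seconds and nanos share one sign (truncation toward zero).
    secs, rem_us = divmod(abs(total_us), 1_000_000)
    return DurationSketch(seconds=sign * secs, nanos=sign * rem_us * 1_000)


def duration_to_timedelta(d: DurationSketch) -> timedelta:
    # Drop sub-microsecond precision, truncating toward zero.
    sign = -1 if d.nanos < 0 else 1
    return timedelta(seconds=d.seconds, microseconds=sign * (abs(d.nanos) // 1_000))


d = timedelta_to_duration(timedelta(seconds=-1.5))
print(d)  # DurationSketch(seconds=-1, nanos=-500000000)
print(duration_to_timedelta(d) == timedelta(seconds=-1.5))  # True
```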

Gobot1234 · Jan 02 '22 21:01