
Display funsor terms nicely when they break into multiple lines

Open • fehiepsi opened this issue 4 years ago • 2 comments

fehiepsi • Mar 31 '21 01:03

How about

import black

class Foo:
    def __str__(self):
        # the unformatted one-line string; it must be valid Python source for black to parse it
        ugly = ...
        return black.format_str(ugly, mode=black.FileMode())
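A slightly more defensive variant of this `__str__` hook, as a minimal sketch: it assumes `repr()` already returns a one-line, valid Python expression (true for the funsor reprs in this thread). The `PrettyRepr` mixin and the toy `Point` class are hypothetical, not part of funsor. If black is not installed, or cannot parse the string, `__str__` falls back to the plain `repr()`.

```python
import ast


class PrettyRepr:
    """Hypothetical mixin: str() pretty-prints repr() with black when possible."""

    def __str__(self):
        ugly = repr(self)  # assumed to be a one-line, valid Python expression
        try:
            import black  # optional dependency
        except ImportError:
            return ugly  # black is not installed: keep the one-line repr
        try:
            # black.format_str appends a trailing newline; drop it for str().
            return black.format_str(ugly, mode=black.Mode()).rstrip("\n")
        except ValueError:  # black's InvalidInput is a ValueError subclass
            return ugly  # repr() was not parseable Python source


class Point(PrettyRepr):
    """Toy class whose repr() is valid Python source."""

    def __init__(self, name, coords):
        self.name = name
        self.coords = coords

    def __repr__(self):
        return f"Point(name={self.name!r}, coords={self.coords!r})"


p = Point("origin", [0.0] * 30)
print(p)  # multi-line if black is installed, one line otherwise

# Formatting only changes layout: both strings parse to the same syntax tree.
assert ast.dump(ast.parse(str(p))) == ast.dump(ast.parse(repr(p)))
```

The final assertion is a cheap sanity check that holds either way: black only rewrites layout (quotes, line breaks, trailing commas), and the fallback returns the repr unchanged.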

For example:

>>> import black, funsor, torch
>>> funsor.set_backend("torch")
>>> from funsor.torch.distributions import Multinomial
>>> x = Multinomial(10, torch.tensor([[0.2, 0.8], [0.3, 0.7]]))
>>> print(x)
Multinomial(tensor(10.), tensor([[0.2000, 0.8000],
        [0.3000, 0.7000]]), value)
>>> print(black.format_str(repr(x), mode=black.FileMode()))
Multinomial(
    total_count=tensor(10.0),
    probs=tensor([[0.2000, 0.8000], [0.3000, 0.7000]]),
    value=value,
)

... though I guess this won't work on large tensors whose printed form is abbreviated, like [1, 2, ..., 999].
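In principle abbreviation need not be a problem: `...` is Python's `Ellipsis` literal, so a summarized tensor such as `tensor([1, 2, ..., 999])` is still a syntactically valid expression that black can parse. A quick stdlib check, where `parses_as_python` is a hypothetical helper and the strings are illustrative examples in the style of torch's summarized output:

```python
import ast


def parses_as_python(text):
    """Return True if `text` is a syntactically valid Python expression."""
    try:
        ast.parse(text, mode="eval")
    except SyntaxError:
        return False
    return True


abbreviated = "Tensor(tensor([[1.1031, 1.4902, ..., 3.2117], ..., [1.8849, ..., 0.9956]]))"
default_repr = "<__main__.Foo object at 0x7f3a2c1d0e50>"

print(parses_as_python(abbreviated))   # True: each `...` is an Ellipsis constant
print(parses_as_python(default_repr))  # False: angle-bracket reprs are not Python
```

The second case is where a black-based `__str__` would actually fail: any term whose repr falls back to the default `<... object at 0x...>` form.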

fritzo • Mar 31 '21 11:03

Nice solution! I tested it on abbreviated tensors and it seems to work. Running yapf on black's output gives an even more compact result.

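import black
import funsor
from funsor import Reals, Tensor, Variable, ops, testing
from funsor.delta import Delta
from yapf.yapflib.yapf_api import FormatCode

funsor.set_backend("torch")  # the output below shows torch tensors

# The abbreviated output below needs a tensor with more than 1000 elements
# (torch's default print threshold); with shape (3, 2) nothing is elided.
shape = (3, 2)
point = Tensor(testing.randn(shape))
x = Variable("x", Reals[shape])
actual = Delta("y", point)(y=ops.log(x))

# Pass 1: black explodes the long one-line repr.
bx = black.format_str(repr(actual), mode=black.FileMode())
print(bx)

# Pass 2: yapf re-packs black's output; FormatCode returns (source, changed).
print(FormatCode(bx)[0])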

which prints the black output followed by the yapf output:

Delta(
    (
        (
            "x",
            (
                Tensor(
                    tensor(
                        [
                            [1.1031, 1.4902, 1.0148, ..., 0.1421, 2.2505, 3.2117],
                            [2.6469, 2.3936, 0.5754, ..., 1.1973, 0.2250, 2.7428],
                            [1.7854, 2.3582, 0.1898, ..., 0.4366, 0.3680, 0.4215],
                            ...,
                            [2.0760, 1.0118, 4.1144, ..., 0.7696, 0.2702, 4.7514],
                            [1.7888, 1.4948, 1.5240, ..., 0.9670, 0.4326, 1.2505],
                            [1.8849, 1.2876, 0.7254, ..., 1.6765, 0.3074, 0.9956],
                        ]
                    )
                ),
                Tensor(tensor(30.6611)),
            ),
        ),
    )
)

Delta(((
    "x",
    (
        Tensor(
            tensor([
                [1.1031, 1.4902, 1.0148, ..., 0.1421, 2.2505, 3.2117],
                [2.6469, 2.3936, 0.5754, ..., 1.1973, 0.2250, 2.7428],
                [1.7854, 2.3582, 0.1898, ..., 0.4366, 0.3680, 0.4215],
                ...,
                [2.0760, 1.0118, 4.1144, ..., 0.7696, 0.2702, 4.7514],
                [1.7888, 1.4948, 1.5240, ..., 0.9670, 0.4326, 1.2505],
                [1.8849, 1.2876, 0.7254, ..., 1.6765, 0.3074, 0.9956],
            ])),
        Tensor(tensor(30.6611)),
    ),
),))

This could be the right path for us. I'll try to see whether we can avoid adding an extra dependency...
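For comparison, the core of what black does on these reprs can be approximated without any dependency: print a node on one line if it fits within the width, otherwise put each argument on its own indented line with a trailing comma, recursively. This is a minimal sketch on a hypothetical toy `Node(name, args)` tree, not on real funsor terms; leaves are printed with `repr()`, and the width check ignores the trailing comma, so it is only approximate.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a term: a constructor name plus positional args."""
    name: str
    args: list = field(default_factory=list)


def one_line(x):
    """Render x on a single line."""
    if isinstance(x, Node):
        return f"{x.name}({', '.join(one_line(a) for a in x.args)})"
    return repr(x)


def pretty(x, width=40, indent=0):
    """Render x, moving a Node's args onto separate lines when it is too wide."""
    flat = one_line(x)
    if not isinstance(x, Node) or not x.args or indent + len(flat) <= width:
        return flat
    pad = " " * (indent + 4)
    lines = [f"{x.name}("]
    for a in x.args:
        # Children carry their own absolute indentation on continuation lines.
        lines.append(f"{pad}{pretty(a, width, indent + 4)},")
    lines.append(" " * indent + ")")
    return "\n".join(lines)


term = Node("Delta", [Node("Tensor", [[1.5, 2.25, 3.0]]), Node("Tensor", [30.66])])
print(pretty(term))
# Delta(
#     Tensor([1.5, 2.25, 3.0]),
#     Tensor(30.66),
# )
```

Unlike black, this never re-packs short arguments onto a shared line (which is roughly what the yapf pass adds), but it already gives the "one argument per line only when needed" behaviour.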

fehiepsi • Mar 31 '21 18:03