
Expose py.typed

Open · NeilGirdhar opened this issue 3 years ago · 2 comments

Adding the py.typed marker file lets type checkers such as MyPy use the type annotations that ship with tensorflow_datasets. Without it, MyPy reports errors whenever classes from tensorflow_datasets are used.

See PEP 561 for details.

Fixes: #3998
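Under PEP 561, a package opts in to type checking by shipping a py.typed file inside its package directory. To check whether an installed copy of tensorflow_datasets (or any other package) has that marker, a minimal sketch follows. has_py_typed is a hypothetical helper, not part of any library, and it only looks for the inline marker; it does not detect separately installed stub packages.

```python
import importlib.util
from pathlib import Path


def has_py_typed(package: str) -> bool:
    """Return True if the importable *package* ships a PEP 561 py.typed marker."""
    spec = importlib.util.find_spec(package)
    # find_spec returns None when the name cannot be found. Plain modules have
    # no submodule_search_locations; the py.typed marker applies to packages.
    if spec is None or spec.submodule_search_locations is None:
        return False
    # A package (including a namespace package) may span several directories.
    return any(
        (Path(location) / "py.typed").is_file()
        for location in spec.submodule_search_locations
    )


if __name__ == "__main__":
    # The result depends on which tensorflow_datasets version is installed.
    print(has_py_typed("tensorflow_datasets"))
```

If this prints False, MyPy has no marker telling it to trust the package's inline annotations, which matches the behavior described above.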

NeilGirdhar · Apr 25 '22 21:04

@Conchylicultor Any chance you could review this one-line change?

NeilGirdhar · May 03 '22 07:05

This tiny pull request would fix the following type errors in my code:

cmm/problem/deduction.py:64: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:65: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:69: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:70: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:73: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:88: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:89: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:92: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:93: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:97: error: Statement is unreachable  [unreachable]

These errors come from code like this:

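    # Imports assumed for this excerpt. Project-specific names such as
    # DeductionProblemState, SimplePoolingMessage, DistributionInfo,
    # tf_dtype_to_jax_dtype and the Generator type come from the author's
    # own code and are not shown here.
    from typing import Any

    import jax.numpy as jnp
    import numpy as np
    from jax.nn import one_hot
    from tensorflow_datasets.features import ClassLabel, Image

    def initial_state(self, example_rng: Generator) -> DeductionProblemState: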
        ds_input_, ds_label = next(self.dataset_iterator)
        label: None | SimplePoolingMessage = None
        input_: None | SimplePoolingMessage = None
        for key, value in self.info.features.items():  # pyright: ignore
            if key == 'label':
                ds_value = ds_label
            else:
                ds_value = ds_input_

            def tf_tensor_to_jax_array(tensor: Any) -> Any:
                dtype = tf_dtype_to_jax_dtype(tensor.dtype)
                return jnp.asarray(tensor, dtype)

            match value:
                case Image(shape=shape):
                    assert ds_value.shape[1:] == shape
                    ds_value = tf_tensor_to_jax_array(ds_value)
                    new_shape = (ds_value.shape[0], -1)
                    expectation_parameters = jnp.reshape(ds_value, new_shape)
                case ClassLabel(num_classes=num_classes):
                    expectation_parameters = one_hot(tf_tensor_to_jax_array(ds_value), num_classes)
                case _:
                    raise TypeError
            message = SimplePoolingMessage(jnp.ones(self.batch_size), expectation_parameters)
            if key == 'label':
                label = message
            else:
                input_ = message
        assert input_ is not None
        assert label is not None
        return DeductionProblemState(input_, label)

    def observation_info(self) -> CompoundDistributionInfo[DeductionProblemObservation]:
        label: None | DistributionInfo = None
        input_: None | DistributionInfo = None
        distribution_cls: type[NaturalParametrization[Any, Any]]
        for key, value in self.info.features.items():  # pyright: ignore
            match value:
                case Image(shape=shape, dtype=dtype):
                    dtype = tf_dtype_to_jax_dtype(dtype)
                    features = int(np.prod(shape))  # pyright: ignore
                    distribution_cls = MultivariateUnitNormalNP
                case ClassLabel(num_classes=num_classes):
                    features = num_classes
                    distribution_cls = MultinomialNP
                case _:
                    raise TypeError
            message = DistributionInfo(features, distribution_cls, {})
            if key == 'label':
                label = message
            else:
                input_ = message
        assert input_ is not None
        assert label is not None
        return CompoundDistributionInfo(DeductionProblemObservation,
                                        {'input_': input_, 'label': label})

NeilGirdhar · May 05 '22 07:05

@pierrot0 @fineguy Would it be possible to have this one-line change reviewed?

NeilGirdhar · Dec 07 '22 21:12
