Expose py.typed
This lets type checkers see the type annotations that tensorflow_datasets already contains. Without it, mypy reports errors wherever classes from tensorflow_datasets are used.
See PEP 561 for details.
Fixes: #3998
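Under PEP 561, a package opts in to type checking by shipping an empty marker file named `py.typed` inside its package directory. Type checkers such as mypy only analyze an installed package's inline annotations when that marker is present. The following is a minimal standard-library sketch that checks whether an installed package ships the marker. The helper name `ships_py_typed` is hypothetical and is not part of TFDS or mypy.

```python
import importlib.util
from pathlib import Path


def ships_py_typed(package: str) -> bool:
    """Return True if an installed top-level package contains a PEP 561 py.typed marker."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        # Not installed, or a plain module with no package directory to hold the marker.
        return False
    # A package (including a namespace package) may span several directories.
    return any(
        (Path(location) / "py.typed").is_file()
        for location in spec.submodule_search_locations
    )


if __name__ == "__main__":
    print(ships_py_typed("tensorflow_datasets"))
```

For the marker to reach users it also has to be included in the built distribution, not just in the source tree.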
@Conchylicultor Any chance you could review this one-line change?
This tiny pull request would fix the following type errors in my code:
```
cmm/problem/deduction.py:64: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:65: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:69: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:70: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:73: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:88: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:89: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:92: error: Expected type in class pattern; found "Any"  [misc]
cmm/problem/deduction.py:93: error: Statement is unreachable  [unreachable]
cmm/problem/deduction.py:97: error: Statement is unreachable  [unreachable]
```
They are caused by code like this:
```python
def initial_state(self, example_rng: Generator) -> DeductionProblemState:
    ds_input_, ds_label = next(self.dataset_iterator)
    label: None | SimplePoolingMessage = None
    input_: None | SimplePoolingMessage = None
    for key, value in self.info.features.items():  # pyright: ignore
        if key == 'label':
            ds_value = ds_label
        else:
            ds_value = ds_input_

        def tf_tensor_to_jax_array(tensor: Any) -> Any:
            dtype = tf_dtype_to_jax_dtype(tensor.dtype)
            return jnp.asarray(tensor, dtype)

        match value:
            case Image(shape=shape):
                assert ds_value.shape[1:] == shape
                ds_value = tf_tensor_to_jax_array(ds_value)
                new_shape = (ds_value.shape[0], -1)
                expectation_parameters = jnp.reshape(ds_value, new_shape)
            case ClassLabel(num_classes=num_classes):
                expectation_parameters = one_hot(tf_tensor_to_jax_array(ds_value), num_classes)
            case _:
                raise TypeError
        message = SimplePoolingMessage(jnp.ones(self.batch_size), expectation_parameters)
        if key == 'label':
            label = message
        else:
            input_ = message
    assert input_ is not None
    assert label is not None
    return DeductionProblemState(input_, label)

def observation_info(self) -> CompoundDistributionInfo[DeductionProblemObservation]:
    label: None | DistributionInfo = None
    input_: None | DistributionInfo = None
    distribution_cls: type[NaturalParametrization[Any, Any]]
    for key, value in self.info.features.items():  # pyright: ignore
        match value:
            case Image(shape=shape, dtype=dtype):
                dtype = tf_dtype_to_jax_dtype(dtype)
                features = int(np.prod(shape))  # pyright: ignore
                distribution_cls = MultivariateUnitNormalNP
            case ClassLabel(num_classes=num_classes):
                features = num_classes
                distribution_cls = MultinomialNP
            case _:
                raise TypeError
        message = DistributionInfo(features, distribution_cls, {})
        if key == 'label':
            label = message
        else:
            input_ = message
    assert input_ is not None
    assert label is not None
    return CompoundDistributionInfo(DeductionProblemObservation,
                                    {'input_': input_, 'label': label})
```
@pierrot0 @fineguy Would it be possible to have this one-line change reviewed?