
Cannot merge options for dataset of type IO>BigQueryDataset

Open · kardiff18 opened this issue 2 years ago · 0 comments

When trying to read data via a `BigQueryReadSession` from TensorFlow I/O, I get an error saying the dataset options cannot be merged. The full error message is:

UNIMPLEMENTED: Cannot merge options for dataset of type IO>BigQueryDataset, because the dataset does not implement `InputDatasets`

I believe the bug is that `InputDatasets` is defined on `BigQueryDatasetOp`, when it should be defined on `BigQueryDatasetOp::Dataset` (and, while we're at it, it should also clear the inputs vector):

tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc
42,45d41
<   Status InputDatasets(std::vector<const DatasetBase *> *inputs) const {
<     return Status();
<   }
< 
214a211,216
>     }
> 
>     Status InputDatasets(
>         std::vector<const DatasetBase *> *inputs) const override {
>       inputs->clear();
>       return Status();

I've tried a variety of library versions and dependency combinations, but they all fail:

google-cloud-bigquery==2.31.0
google-cloud-bigquery-storage==2.10.1
tensorflow==2.8.0
tensorflow-io==0.25.0
tensorflow-recommenders==0.6.0
pyyaml==5.3.1

Reproduction code is below; you need a GCP project to run it:

from tensorflow.python.framework import dtypes
from tensorflow_io.bigquery import BigQueryClient
from tensorflow_io.bigquery import BigQueryReadSession

import tensorflow as tf

def read_bigquery(project, dataset, table):
  tensorflow_io_bigquery_client = BigQueryClient()
  read_session = tensorflow_io_bigquery_client.read_session(
      "projects/" + project,
      project,
      table,
      dataset,
      selected_fields=['unique_key', 'taxi_id'],
      output_types=[dtypes.string, dtypes.string],
  )
  dataset = read_session.parallel_read_rows()
  return dataset

taxi_data = read_bigquery("ENTER YOUR GCP PROJECT HERE", 
                          "chicago_taxi_trips", 
                          "taxi_trips_300k"
)

taxi_ids = taxi_data.batch(1_000).map(lambda x: x["taxi_id"])
taxi_id_vocab = tf.keras.layers.StringLookup(
    mask_token=None, name="taxi_id_lookup")

taxi_id_vocab.adapt(taxi_ids)  # fails with the UNIMPLEMENTED error above
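For comparison, the same vocabulary-building step succeeds when the source is a built-in in-memory dataset (stand-in data I made up, assuming TensorFlow is installed), which suggests the problem is specific to the BigQuery dataset op rather than to the pipeline shape:

```python
import tensorflow as tf

# Hypothetical stand-in for the "taxi_id" column that would come from BigQuery.
rows = tf.data.Dataset.from_tensor_slices(
    {"taxi_id": ["a1", "b2", "a1", "c3"]}
)

# Same batch/map/adapt shape as the BigQuery reproduction above.
taxi_ids = rows.batch(2).map(lambda x: x["taxi_id"])

taxi_id_vocab = tf.keras.layers.StringLookup(
    mask_token=None, name="taxi_id_lookup")
taxi_id_vocab.adapt(taxi_ids)

# The learned vocabulary is the OOV token plus the unique ids.
print(taxi_id_vocab.get_vocabulary())
```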

kardiff18 · Nov 29 '22 14:11