recommenders-addons
recommenders-addons copied to clipboard
dynamic embedding doesn't work with tf.distribute.experimental.ParameterServerStrategy
System information
- OS Platform and Distribution Manjaro:
- TensorFlow version and how it was installed (source or binary): 2.5.1 binary
- TensorFlow-Recommenders-Addons version and how it was installed (source or binary): 0.3.0 binary
- Python version: 3.7.9
- Is GPU used? (yes/no): no
Describe the bug
dynamic embedding doesn't work with tf.distribute.experimental.ParameterServerStrategy.
This may be related to issue #167, since dynamic embedding also doesn't work with @tf.function.
Code to reproduce the issue
The code is adapted from https://www.tensorflow.org/tutorials/distribute/parameter_server_training and https://github.com/tensorflow/recommenders-addons/blob/master/docs/tutorials/dynamic_embedding_tutorial.ipynb. It works fine when use_de=False.
import os
import multiprocessing
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders_addons as tfra
import numpy as np
import portpicker
from tensorflow.keras.layers import Dense
class NCFModel(tf.keras.Model):
    """Rating model built from a user embedding and a movie embedding.

    The two embeddings are concatenated and passed through three Dense
    layers (256 -> 64 -> 1).  Depending on ``use_de`` the embeddings are
    either TFRA dynamic-embedding variables or fixed-size Keras weights of
    shape ``(10000, embedding_size)``.
    """

    def __init__(self, use_de):
        """Build the Dense tower and the two embedding tables.

        Args:
            use_de: if truthy, create the embeddings with
                ``tfra.dynamic_embedding.get_variable``; otherwise create
                them with ``self.add_weight``.
        """
        super().__init__()
        self.embedding_size = 32
        self.use_de = use_de

        # Dense tower; the last layer has no activation and a single output.
        self.d0 = self._make_dense(256, activation='relu')
        self.d1 = self._make_dense(64, activation='relu')
        self.d2 = self._make_dense(1)

        if use_de:
            self.user_embeddings = self._make_dynamic_table(
                "user_dynamic_embeddings")
            self.movie_embeddings = self._make_dynamic_table(
                "moive_dynamic_embeddings")
        else:
            self.user_embeddings = self._make_static_table("user_embeddings")
            self.movie_embeddings = self._make_static_table("movie_embeddings")

    @staticmethod
    def _make_dense(units, activation=None):
        """Return a Dense layer whose kernel and bias use RandomNormal(0.0, 0.1)."""
        return Dense(
            units,
            activation=activation,
            kernel_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1),
            bias_initializer=tf.keras.initializers.RandomNormal(0.0, 0.1))

    def _make_dynamic_table(self, name):
        """Return a TFRA dynamic-embedding variable keyed by int64 ids."""
        return tfra.dynamic_embedding.get_variable(
            name=name,
            dim=self.embedding_size,
            initializer=tf.keras.initializers.RandomNormal(-1, 1),
            key_dtype=tf.int64)

    def _make_static_table(self, name):
        """Return a trainable float32 weight of shape (10000, embedding_size)."""
        return self.add_weight(
            name=name,
            shape=(10000, self.embedding_size),
            dtype=tf.float32,
            initializer=tf.keras.initializers.RandomNormal(-1, 1),
            trainable=True,
        )

    def call(self, batch):
        """Score a batch of (user, movie) pairs.

        Args:
            batch: mapping that provides ``"movie_id"`` and ``"user_id"``.

        Returns:
            A tuple ``(out, trainable_wrappers)``.  ``out`` is the output of
            the last Dense layer reshaped to rank 1.  ``trainable_wrappers``
            holds the two objects returned by the dynamic-embedding lookups
            when ``use_de`` is set, and is an empty list otherwise.
        """
        movie_id = batch["movie_id"]
        user_id = batch["user_id"]

        if self.use_de:
            # return_trainable=True makes each lookup return a second value;
            # the caller appends those to the variables passed to the optimizer.
            user_vectors, user_wrapper = tfra.dynamic_embedding.embedding_lookup_unique(
                params=self.user_embeddings,
                ids=user_id,
                name="user-id-weights",
                return_trainable=True,
            )
            movie_vectors, movie_wrapper = tfra.dynamic_embedding.embedding_lookup_unique(
                params=self.movie_embeddings,
                ids=movie_id,
                name="movie-id-weights",
                return_trainable=True,
            )
            trainable_wrappers = [user_wrapper, movie_wrapper]
        else:
            user_vectors = tf.gather(self.user_embeddings, user_id)
            movie_vectors = tf.gather(self.movie_embeddings, movie_id)
            trainable_wrappers = []

        features = tf.concat([user_vectors, movie_vectors], axis=1)
        hidden = self.d2(self.d1(self.d0(features)))
        return tf.reshape(hidden, shape=[-1]), trainable_wrappers
def create_in_process_cluster(num_workers, num_ps):
    """Start in-process gRPC servers on localhost and describe them.

    Args:
        num_workers: number of "worker" tasks to start.
        num_ps: number of "ps" tasks to start; the "ps" job is left out of
            the cluster spec when this is 0.

    Returns:
        A ``SimpleClusterResolver`` built from the cluster spec, using the
        "grpc" rpc layer.
    """

    def _pick_addresses(count):
        # One freshly picked unused port per task.
        return [
            "localhost:%s" % portpicker.pick_unused_port()
            for _ in range(count)
        ]

    # Worker ports are picked before ps ports.
    worker_addresses = _pick_addresses(num_workers)
    ps_addresses = _pick_addresses(num_ps)

    cluster_dict = {"worker": worker_addresses}
    if num_ps > 0:
        cluster_dict["ps"] = ps_addresses
    cluster_spec = tf.train.ClusterSpec(cluster_dict)

    # Workers need some inter-op threads to work properly, so raise the
    # setting when the machine has fewer CPUs than num_workers + 1.
    worker_config = tf.compat.v1.ConfigProto()
    if multiprocessing.cpu_count() < num_workers + 1:
        worker_config.inter_op_parallelism_threads = num_workers + 1

    for worker_index in range(num_workers):
        tf.distribute.Server(
            cluster_spec,
            job_name="worker",
            task_index=worker_index,
            config=worker_config,
            protocol="grpc",
        )

    for ps_index in range(num_ps):
        tf.distribute.Server(
            cluster_spec,
            job_name="ps",
            task_index=ps_index,
            protocol="grpc",
        )

    return tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, rpc_layer="grpc")
# Set before the cluster is created, as in the original script.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

NUM_WORKERS = 2
NUM_PS = 1

cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=None,
)

use_de = True  # code works fine if use_de=False

# The model and the optimizer are both created under the strategy scope.
with strategy.scope():
    model = NCFModel(use_de)
    optimizer = tf.keras.optimizers.Adam()
    if use_de:
        # Wrap the Keras optimizer with the TFRA dynamic-embedding wrapper.
        optimizer = tfra.dynamic_embedding.DynamicEmbeddingOptimizer(optimizer)
@tf.function
def step_fn(iterator):
    """Run one training step on the next batch and return the summed loss.

    Args:
        iterator: iterator yielding batches with the keys read by the model
            plus ``'user_rating'``.

    Returns:
        The per-replica losses reduced with ``ReduceOp.SUM``.
    """

    def _train_on_replica(batch):
        # Forward pass and loss are recorded on the tape.
        with tf.GradientTape() as tape:
            pred, trainable_wrappers = model(batch, training=True)
            rating = batch['user_rating']
            per_example_loss = (pred - rating)**2
            loss = tf.nn.compute_average_loss(per_example_loss)

        # The model's own variables plus whatever wrappers it returned.
        variables = model.trainable_variables + trainable_wrappers
        gradients = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss

    batch_data = next(iterator)
    per_replica_losses = strategy.run(_train_on_replica, args=(batch_data,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
def get_dataset_fn(input_context):
    """Build the repeating, batched MovieLens 100k training dataset.

    Args:
        input_context: object providing ``get_per_replica_batch_size`` and
            ``input_pipeline_id``.

    Returns:
        A dataset of dicts with ``"movie_id"`` and ``"user_id"`` converted to
        int64 and ``"user_rating"`` passed through, shuffled once with the
        input pipeline id as seed, batched and repeated.
    """
    global_batch_size = 256
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)

    def _select_features(example):
        # The ids are parsed from strings into int64 numbers.
        return {
            "movie_id": tf.strings.to_number(example["movie_id"], tf.int64),
            "user_id": tf.strings.to_number(example["user_id"], tf.int64),
            "user_rating": example["user_rating"]
        }

    ratings = tfds.load("movielens/100k-ratings", split="train")
    ratings = ratings.map(_select_features)

    # Seeded by the input pipeline id and not reshuffled between iterations.
    pipeline_id = input_context.input_pipeline_id
    shuffled = ratings.shuffle(
        100_000, seed=pipeline_id, reshuffle_each_iteration=False)

    taken = shuffled.take(100_000)
    batched = taken.batch(batch_size)
    return batched.repeat()
@tf.function
def per_worker_dataset_fn():
    """Return the dataset that the strategy distributes from ``get_dataset_fn``."""
    distributed_dataset = strategy.distribute_datasets_from_function(
        get_dataset_fn)
    return distributed_dataset
# The coordinator schedules step_fn calls and hands out the per-worker dataset.
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)

num_epoches = 20
steps_per_epoch = 100

for epoch in range(num_epoches):
    # Each step is scheduled and its result fetched before the next one is
    # scheduled, exactly as in the original loop.
    total_loss = [
        coordinator.schedule(step_fn, args=(per_worker_iterator,)).fetch()
        for _ in range(steps_per_epoch)
    ]
    coordinator.join()
    print("epoch", epoch, "loss", np.mean(total_loss))
Other info / logs
Error:
Traceback (most recent call last):
File "ps_test.py", line 193, in <module>
remote = coordinator.schedule(step_fn, args=(per_worker_iterator,))
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/coordinator/cluster_coordinator.py", line 1150, in schedule
remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/coordinator/cluster_coordinator.py", line 977, in schedule
kwargs=kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/coordinator/cluster_coordinator.py", line 363, in __init__
**nest.map_structure(_maybe_as_type_spec, replica_kwargs))
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1367, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 1273, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 764, in _initialize
*args, **kwds))
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3289, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
raise e.ag_error_metadata.to_exception(e)
AttributeError: in user code:
ps_test.py:156 replica_fn *
optimizer.apply_gradients(zip(gradients, model.trainable_variables + trainable_wrappers))
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:636 apply_gradients **
self._create_all_weights(var_list)
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:823 _create_all_weights
self._create_slots(var_list)
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/adam.py:124 _create_slots
self.add_slot(var, 'm')
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py:177 add_slot
with strategy.extended.colocate_vars_with(var):
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2217 colocate_vars_with
self._validate_colocate_with_variable(colocate_with_variable)
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/parameter_server_strategy.py:339 _validate_colocate_with_variable
distribute_utils.validate_colocate(colocate_with_variable, self)
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_utils.py:246 validate_colocate
_validate_colocate_extended(v, extended)
/home/npbool/Projects/tfra_test/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_utils.py:226 _validate_colocate_extended
if variable_strategy.extended is not extended:
AttributeError: 'NoneType' object has no attribute 'extended'
@rhdong please have a look
@rhdong please have a look
Hi @npbool, thank you for the feedback; we will try to resolve this in the next release. Until then, you may be able to implement your distributed pipeline using the Estimator APIs (refer to the demo). Thank you!
This issue was solved by commit 1a5dfca.
As I mentioned in an issue in the TF repo that references this one, I'm facing the same error in TF 2.16.1 with a simple model fit (not using recommenders-addons). Could you please tell me whether this problem comes from TF or from my model? In that issue, I also mentioned a Kaggle notebook with this error.
thank you @MoFHeka @npbool @rhdong 🌹