recommenders

`load_weights` does not work as expected

Session 1 (train, save, and restore in the same session):
def build_model(learning_rate=0.01):
    """Build and compile a new ``RetrievalModel``.

    The model is constructed from the ``item_model`` and ``user_model``
    objects defined elsewhere in this script and compiled with an Adagrad
    optimizer.

    Args:
        learning_rate: Adagrad learning rate. Defaults to 0.01, the value
            that was previously hard-coded here.

    Returns:
        The compiled model.
    """
    model = RetrievalModel(item_model, user_model)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate))
    return model
# Train a freshly built model on the interaction data, then write its
# weights to `save_path`, replacing any existing files there.
model = build_model()
model.fit(x=behavior_dataset, epochs=30)
model.save_weights(filepath=save_path, overwrite=True)
# One hand-written example, used only to run a forward pass so that Keras
# builds the model's layers before the saved weights are read back.
# NOTE(review): the dtypes here (e.g. int32 for "money_goods" and "weights")
# presumably need to match what the training dataset feeds — confirm.
compute_loss_args = {
    "user_id": tf.constant(["45"]),
    "work_id": tf.constant(["45"]),
    "tags": tf.constant([""]),
    "work_uid": tf.constant(["45"]),
    "money_goods": tf.constant([100]),
    "category_id": tf.constant(["2"]),
    "bid_type": tf.constant(["normal"]),
    "is_rec": tf.constant(["1"]),
    "weights": tf.constant([1]),
}
model(compute_loss_args)

# `.expect_partial()` only silences the warnings about checkpoint values that
# were not used (e.g. optimizer slots that do not exist until training runs).
# On its own it also hides a real mismatch. So additionally assert that every
# object the model has already created found a matching entry in the
# checkpoint; this raises instead of silently leaving freshly initialised
# weights in place.
s = model.load_weights(save_path).expect_partial()
s.assert_existing_objects_matched()
k = 100
user_id = '1035369'
# Top-k index whose queries go through `model.user_model`; it is filled with
# (work_id, item embedding) pairs produced by `model.item_model`.
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model, k)
index.index_from_dataset(
    # Note: this is the global list of recommendable items.
    # NOTE(review): 100_100 == 100100; possibly meant 100_000 — confirm.
    work_dataset.shuffle(100_100).map(
        lambda x: (x["work_id"], model.item_model(x))
    )
)
print(f"rec user_id :{user_id}")
# NOTE(review): `index` is never queried in this snippet; presumably the
# lookup for `user_id` happens elsewhere — confirm.
print(user_id in unique_user_id)
Result: the output is correct.
Session 2 (a new session that only restores the saved weights, without retraining):
def build_model(learning_rate=0.01):
    """Build and compile a new ``RetrievalModel``.

    The model is constructed from the ``item_model`` and ``user_model``
    objects defined elsewhere in this script and compiled with an Adagrad
    optimizer.

    Args:
        learning_rate: Adagrad learning rate. Defaults to 0.01, the value
            that was previously hard-coded here.

    Returns:
        The compiled model.
    """
    model = RetrievalModel(item_model, user_model)
    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate))
    return model
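Lines removed for session 2 (the model is not retrained or re-saved):
- `model = build_model()`
- `model.fit(behavior_dataset, epochs=30)`
- `model.save_weights(save_path, overwrite=True)`

Note: the code below still uses `model`, so it must be created (e.g. with `model = build_model()`) before the weights are restored.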
# One hand-written example, used only to run a forward pass so that Keras
# builds the model's layers before the saved weights are read back.
# NOTE(review): the dtypes here (e.g. int32 for "money_goods" and "weights")
# presumably need to match what the training dataset feeds — confirm.
compute_loss_args = {
    "user_id": tf.constant(["45"]),
    "work_id": tf.constant(["45"]),
    "tags": tf.constant([""]),
    "work_uid": tf.constant(["45"]),
    "money_goods": tf.constant([100]),
    "category_id": tf.constant(["2"]),
    "bid_type": tf.constant(["normal"]),
    "is_rec": tf.constant(["1"]),
    "weights": tf.constant([1]),
}
model(compute_loss_args)

# `.expect_partial()` only silences the warnings about checkpoint values that
# were not used (e.g. optimizer slots that do not exist until training runs).
# On its own it also hides a real mismatch. So additionally assert that every
# object the model has already created found a matching entry in the
# checkpoint; this raises instead of silently leaving freshly initialised
# weights in place.
s = model.load_weights(save_path).expect_partial()
s.assert_existing_objects_matched()
k = 100
user_id = '1035369'
# Top-k index whose queries go through `model.user_model`; it is filled with
# (work_id, item embedding) pairs produced by `model.item_model`.
index = tfrs.layers.factorized_top_k.BruteForce(model.user_model, k)
index.index_from_dataset(
    # Note: this is the global list of recommendable items.
    # NOTE(review): 100_100 == 100100; possibly meant 100_000 — confirm.
    work_dataset.shuffle(100_100).map(
        lambda x: (x["work_id"], model.item_model(x))
    )
)
print(f"rec user_id :{user_id}")
# NOTE(review): `index` is never queried in this snippet; presumably the
# lookup for `user_id` happens elsewhere — confirm.
print(user_id in unique_user_id)
Result: the output is wrong (inaccurate).
Environment: tensorboard 2.15.2, keras 2.15.0