ESMM Multi-Task Learning Model: Principles and Implementation (with Code)

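Source: DataFunTalk. The original article is about 2,500 characters; suggested reading time is 5 minutes.
Building on the idea of Multi-Task Learning (MTL), the paper proposes a CVR estimation model named ESMM.
Link: https://arxiv.org/abs/1804.07931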

2. Data Sparsity (DS)
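CVR = conversions / clicks. It is the probability that an item is converted given that it has been clicked. CVR estimation is not strictly tied to CTR: an item with a high CTR does not necessarily have a high CVR. Clickbait articles, for example, attract many clicks but tend to have short dwell times. This is also why a CVR model cannot be trained directly on all samples: for impressions that were not clicked, there is no way to know whether they would have converted had they been clicked. Using 0 as the label for those samples would seriously mislead the CVR model.

CTCVR = conversions / impressions. It is the probability that an item is clicked and then converted.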
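The three ratios above can be made concrete with a small sketch. The counts and the helper name `funnel_rates` are illustrative assumptions, not figures from the article. The sketch also shows that CTCVR = CTR × CVR follows directly from the definitions, and that labelling every unclicked impression as a CVR negative would collapse the estimated CVR down to CTCVR.

```python
def funnel_rates(impressions, clicks, conversions):
    """Return (ctr, cvr, ctcvr) computed from raw funnel counts."""
    ctr = clicks / impressions        # click given impression
    cvr = conversions / clicks        # conversion given click
    ctcvr = conversions / impressions # click and conversion given impression
    return ctr, cvr, ctcvr


# Hypothetical counts, chosen only for illustration.
impressions, clicks, conversions = 1000, 50, 5
ctr, cvr, ctcvr = funnel_rates(impressions, clicks, conversions)

# If unclicked impressions were all treated as CVR negatives, the
# denominator would silently become impressions instead of clicks.
naive_cvr = conversions / impressions

print(f"CTR={ctr:.3f} CVR={cvr:.3f} CTCVR={ctcvr:.4f} naive CVR={naive_cvr:.4f}")
```

With these counts the naive estimate is far below the click-conditioned CVR, which is the bias the paragraph above warns about.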


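Shared embedding. The CVR task and the CTR task use the same features and the same feature embeddings; each task learns its own parameters only after the Concatenate layer.
Implicitly learned pCVR. Here pCVR is just a variable inside the network and has no explicit supervision signal.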
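A minimal pure-Python sketch of these two points follows. It assumes the standard ESMM formulation pCTCVR = pCTR × pCVR; the one-layer `tower`, the toy embedding table, and all names here are hypothetical stand-ins for the real DNN towers, not EasyRec code.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tower(features, weights, bias):
    """One linear layer standing in for a DNN tower; returns a single logit."""
    return sum(f * w for f, w in zip(features, weights)) + bias


random.seed(0)
EMB_DIM = 4
# Shared embedding table: both tasks look up the same vectors.
embedding = {
    fid: [random.uniform(-1, 1) for _ in range(EMB_DIM)]
    for fid in ("user_1", "item_7")
}


def esmm_forward(feature_ids, cvr_params, ctr_params):
    # Concatenate the shared embeddings; the towers diverge only after this.
    all_fea = [v for fid in feature_ids for v in embedding[fid]]
    p_ctr = sigmoid(tower(all_fea, *ctr_params))
    # pCVR has no label of its own; it is trained only through the product.
    p_cvr = sigmoid(tower(all_fea, *cvr_params))
    p_ctcvr = p_ctr * p_cvr
    return p_ctr, p_cvr, p_ctcvr


n_inputs = 2 * EMB_DIM
cvr_params = ([random.uniform(-1, 1) for _ in range(n_inputs)], 0.0)
ctr_params = ([random.uniform(-1, 1) for _ in range(n_inputs)], 0.0)
p_ctr, p_cvr, p_ctcvr = esmm_forward(["user_1", "item_7"], cvr_params, ctr_params)
```

Because the supervised outputs are pCTR and pCTCVR, both of which are defined over all impressions, the CVR tower receives gradients from every sample even though it never sees a click-only label.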

EasyRec link:
https://github.com/Alibaba/EasyRec
EasyRec-ESMM link:
https://github.com/alibaba/EasyRec/blob/master/easy_rec/python/model/esmm.py
def build_predict_graph(self):
  """Forward function.

  Returns:
    self._prediction_dict: Prediction result of two tasks.
  """
  # Start from the concatenated tensor (all_fea); the code that builds it
  # is omitted here.

  # CVR tower: a DNN followed by a single-unit dense layer (the logit).
  cvr_tower_name = self._cvr_tower_cfg.tower_name
  dnn_model = dnn.DNN(
      self._cvr_tower_cfg.dnn,
      self._l2_reg,
      name=cvr_tower_name,
      is_training=self._is_training)
  cvr_tower_output = dnn_model(all_fea)
  cvr_tower_output = tf.layers.dense(
      inputs=cvr_tower_output,
      units=1,
      kernel_regularizer=self._l2_reg,
      name='%s/dnn_output' % cvr_tower_name)

  # CTR tower: same structure, fed with the same all_fea input.
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  dnn_model = dnn.DNN(
      self._ctr_tower_cfg.dnn,
      self._l2_reg,
      name=ctr_tower_name,
      is_training=self._is_training)
  ctr_tower_output = dnn_model(all_fea)
  ctr_tower_output = tf.layers.dense(
      inputs=ctr_tower_output,
      units=1,
      kernel_regularizer=self._l2_reg,
      name='%s/dnn_output' % ctr_tower_name)

  tower_outputs = {
      cvr_tower_name: cvr_tower_output,
      ctr_tower_name: ctr_tower_output
  }
  self._add_to_prediction_dict(tower_outputs)
  return self._prediction_dict
1. Loss computation
Note: when computing CVR metrics, impressions that were not clicked must be masked out.
def build_loss_graph(self):
  """Build loss graph.

  Returns:
    self._loss_dict: Weighted loss of ctr and cvr.
  """
  cvr_tower_name = self._cvr_tower_cfg.tower_name
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  cvr_label_name = self._label_name_dict[cvr_tower_name]
  ctr_label_name = self._label_name_dict[ctr_tower_name]

  # The CTCVR label is 1 only when the sample was both clicked and converted.
  ctcvr_label = tf.cast(
      self._labels[cvr_label_name] * self._labels[ctr_label_name],
      tf.float32)
  # Cross entropy on the CTCVR probability, computed over all impressions.
  cvr_losses = tf.keras.backend.binary_crossentropy(
      ctcvr_label, self._prediction_dict['probs_ctcvr'])
  cvr_loss = tf.reduce_sum(cvr_losses, name="ctcvr_loss")
  # The weight defaults to 1.
  self._loss_dict['weighted_cross_entropy_loss_%s' %
                  cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss

  # CTR loss is computed from the CTR tower logits.
  ctr_loss = tf.reduce_sum(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.cast(self._labels[ctr_label_name], tf.float32),
          logits=self._prediction_dict['logits_%s' % ctr_tower_name]),
      name="ctr_loss")
  self._loss_dict['weighted_cross_entropy_loss_%s' %
                  ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
  return self._loss_dict
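To see what the loss graph computes without TensorFlow, here is a small pure-Python sketch of the same two terms. It assumes the two weighted entries of the loss dictionary are simply added to form the total loss; `bce` and `esmm_loss` are hypothetical helper names, and the clipping constant is an arbitrary choice for numerical safety.

```python
import math


def bce(label, prob, eps=1e-7):
    """Binary cross entropy for one sample, with the probability clipped."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))


def esmm_loss(ctr_labels, cvr_labels, p_ctr, p_cvr,
              ctr_weight=1.0, cvr_weight=1.0):
    """Sum of the CTR loss and the CTCVR loss over all impressions."""
    ctr_loss = 0.0
    ctcvr_loss = 0.0
    for y_click, y_conv, pc, pv in zip(ctr_labels, cvr_labels, p_ctr, p_cvr):
        ctr_loss += bce(y_click, pc)
        # CTCVR label: clicked AND converted; prediction: pCTR * pCVR.
        ctcvr_loss += bce(y_click * y_conv, pc * pv)
    return ctr_weight * ctr_loss + cvr_weight * ctcvr_loss


# Three hypothetical impressions: click+convert, click only, no click.
ctr_labels = [1, 1, 0]
cvr_labels = [1, 0, 0]
good = esmm_loss(ctr_labels, cvr_labels, [0.9, 0.9, 0.1], [0.9, 0.1, 0.5])
bad = esmm_loss(ctr_labels, cvr_labels, [0.1, 0.1, 0.9], [0.1, 0.9, 0.5])
```

Neither term uses a CVR label on its own, so unclicked impressions never act as CVR negatives; pCVR is constrained only through the product with pCTR.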
def build_metric_graph(self, eval_config):
  """Build metric graph.

  Args:
    eval_config: Evaluation configuration.

  Returns:
    metric_dict: Calculate AUC of ctr, cvr and ctcvr.
  """
  metric_dict = {}

  cvr_tower_name = self._cvr_tower_cfg.tower_name
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  cvr_label_name = self._label_name_dict[cvr_tower_name]
  ctr_label_name = self._label_name_dict[ctr_tower_name]
  for metric in self._cvr_tower_cfg.metrics_set:
    # CTCVR metric: evaluated on all impressions.
    ctcvr_label_name = cvr_label_name + '_ctcvr'
    cvr_dtype = self._labels[cvr_label_name].dtype
    self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
        self._labels[ctr_label_name], cvr_dtype)
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._cvr_tower_cfg.loss_type,
            label_name=ctcvr_label_name,
            num_class=self._cvr_tower_cfg.num_class,
            suffix='_ctcvr'))

    # CVR metric: keep only clicked samples (ctr label > 0).
    cvr_label_masked_name = cvr_label_name + '_masked'
    ctr_mask = self._labels[ctr_label_name] > 0
    self._labels[cvr_label_masked_name] = tf.boolean_mask(
        self._labels[cvr_label_name], ctr_mask)
    pred_prefix = (
        'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION
        else 'y')
    pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
    self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
        self._prediction_dict[pred_name], ctr_mask)
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._cvr_tower_cfg.loss_type,
            label_name=cvr_label_masked_name,
            num_class=self._cvr_tower_cfg.num_class,
            suffix='_%s_masked' % cvr_tower_name))

  for metric in self._ctr_tower_cfg.metrics_set:
    # CTR metric
    metric_dict.update(
        self._build_metric_impl(
            metric,
            loss_type=self._ctr_tower_cfg.loss_type,
            label_name=ctr_label_name,
            num_class=self._ctr_tower_cfg.num_class,
            suffix='_%s' % ctr_tower_name))
  return metric_dict
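The effect of the click mask on the CVR metric can be illustrated with a small sketch. The pairwise `auc` below is a simple reference implementation written for this example, not the metric EasyRec uses internally, and the labels and scores are made up.

```python
def auc(labels, scores):
    """Pairwise AUC: share of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def masked_cvr_auc(ctr_labels, cvr_labels, p_cvr):
    """CVR AUC over clicked samples only, mirroring the boolean mask idea."""
    kept = [(y, s) for c, y, s in zip(ctr_labels, cvr_labels, p_cvr) if c > 0]
    return auc([y for y, _ in kept], [s for _, s in kept])


# Hypothetical batch: the last two impressions were never clicked, so their
# conversion label of 0 says nothing about the quality of pCVR.
ctr_labels = [1, 1, 1, 0, 0]
cvr_labels = [1, 0, 0, 0, 0]
p_cvr = [0.8, 0.3, 0.2, 0.9, 0.95]

masked = masked_cvr_auc(ctr_labels, cvr_labels, p_cvr)
unmasked = auc(cvr_labels, p_cvr)
```

In this toy batch the model ranks the clicked samples perfectly, yet the unmasked AUC is dragged down by unclicked impressions whose true conversion outcome is unknown.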
https://tianchi.aliyun.com/dataset/dataDetail?dataId=408&userId=1
