标签:时间 __init__ sum sel 公式 math range code spl
特征域感知因子分解机(Field-aware Factorization Machines,FFM)主要解决了FM的以下几个痛点:
FFM引入特征域感知(field-aware)这一概念,使每一个特征对每一种特征域的组合都有一个单独的隐向量。这意味着FFM中每个特征包含一组隐向量。
FFM无法使用FM这样的化简方法(自己看一下FM的化简公式就知道了),所以训练和推理的时间复杂度为\(O(kn^2)\)。
特征域一般表示某一种特征,例如,“职业”可以为一个特征域,其one-hot表示可能为000100
,则表示这个特征域包含6个特征。特征域也可以有更加灵活的定义,比如表示“用户”特征或者“物品”特征,而“用户”特征是由“职业”、“年龄”等特征的embedding进行pooling得到的。具体实现非常灵活,可以因地制宜。
FFM的一阶部分可以直接复用LR,只需要额外实现各个特征域的二阶交叉。我们可以表示出每个特征的embedding向量组后,对特征域依次交叉。
class FieldAwareFactorizationMachine(torch.nn.Module):
def __init__(self, field_dims, embed_dim):
super().__init__()
self.num_fields = len(field_dims) # 特征域数量
self.embeddings = torch.nn.ModuleList(
[torch.nn.Embedding(sum(field_dims), embed_dim) for _ in range(self.num_fields)]
) # 每个特征的embedding向量组
self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long)
for embedding in self.embeddings:
torch.nn.init.xavier_uniform_(embedding.weight.data)
def forward(self, x):
"""
:param x: Long tensor of size ``(batch_size, num_fields)``
"""
x = x + x.new_tensor(self.offsets).unsqueeze(0)
xs = [self.embeddings[i](x) for i in range(self.num_fields)]
ix = list()
for i in range(self.num_fields - 1):
for j in range(i + 1, self.num_fields):
ix.append(xs[j][:, i] * xs[i][:, j]) # 特征域交叉
ix = torch.stack(ix, dim=1)
return ix
最后,构建完整的FFM前向传播链路:
class FieldAwareFactorizationMachineModel(torch.nn.Module):
def __init__(self, field_dims, embed_dim=10):
super().__init__()
self.linear = FeaturesLinear(field_dims)
self.ffm = FieldAwareFactorizationMachine(field_dims, embed_dim)
def forward(self, x):
"""
:param x: Long tensor of size ``(batch_size, num_fields)``
"""
ffm_term = torch.sum(torch.sum(self.ffm(x), dim=1), dim=1, keepdim=True)
x = self.linear(x) + ffm_term
return torch.sigmoid(x.squeeze(1))
设置:
数据集:ml-100k
优化方法:Adam
学习率:0.003
效果:
收敛epoch:6
train logloss: 0.48356
val auc: 0.78662
test auc: 0.78931
【推荐算法】特征域感知因子分解机(Field-aware Factorization Machines,FFM)
标签:时间 __init__ sum sel 公式 math range code spl
原文地址:https://www.cnblogs.com/tmpUser/p/14954452.html