Contrastive learning is a self-supervised learning method that learns representations by contrasting positive examples against negative ones. For self-supervised contrastive learning, the contrastive loss takes the following form:
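In the notation of the first implementation below (a sketch of the usual form, where $z_i$ is the representation of sample $i$, $z_i^{+}$ that of its augmented positive view, $\mathrm{sim}(\cdot,\cdot)$ is cosine similarity, and $\tau$ is the temperature, written `T` in the code), the per-sample loss is

$$
\ell_i = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_i^{+})/\tau\big)}{\sum_{j \ne i} \exp\big(\mathrm{sim}(z_i, z_j^{+})/\tau\big)},
$$

and the batch loss is the mean of $\ell_i$ over all samples.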
In many cases, contrastive learning only needs to generate one positive sample for each example, while the other samples in the same batch serve as negatives. This can be implemented as follows:
```python
import torch


def contrastive_loss(x, x_aug, T):
    """
    :param x: hidden vectors of the original data, shape (batch_size, dim)
    :param x_aug: hidden vectors of the augmented (positive) data, shape (batch_size, dim)
    :param T: temperature
    :return: loss
    """
    batch_size, _ = x.size()
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)

    # pairwise cosine similarity between original and augmented views
    sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs)
    sim_matrix = torch.exp(sim_matrix / T)

    # diagonal entries are the positive pairs (x_i, x_aug_i)
    pos_sim = sim_matrix[range(batch_size), range(batch_size)]

    # positive similarity over the sum of the in-batch negative similarities
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    loss = -torch.log(loss).mean()
    return loss
```
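A minimal usage sketch, assuming the two views of a batch have already been encoded into `x` and `x_aug`; the tensor shapes and the temperature value here are illustrative, not from the original:

```python
import torch

# hypothetical encoded batch: 8 samples with 128-dimensional representations
x = torch.randn(8, 128)      # representations of the original views
x_aug = torch.randn(8, 128)  # representations of the augmented (positive) views

loss = contrastive_loss(x, x_aug, T=0.5)  # scalar tensor
print(loss.item())
```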
If the generated (augmented) samples should also be used as negatives for the contrast, the code is as follows:
```python
import torch
import torch.nn.functional as F


def info_nce_loss(self, features):
    # features: (n_views * batch_size, dim); views of the same sample are
    # batch_size rows apart, so entries sharing a label below are positives
    labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    labels = labels.to(self.args.device)

    features = F.normalize(features, dim=1)

    # cosine similarity between every pair of views
    similarity_matrix = torch.matmul(features, features.T)
    # assert similarity_matrix.shape == (
    #     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)
    # assert similarity_matrix.shape == labels.shape

    # discard the main diagonal from both the labels and the similarity matrix
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # assert similarity_matrix.shape == labels.shape

    # select and combine multiple positives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # select only the negatives
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # the positive logit sits in column 0, so every row's target class is 0
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)

    logits = logits / self.args.temperature
    return logits, labels


# in the training loop, the InfoNCE loss is computed as cross-entropy over the logits
self.criterion = torch.nn.CrossEntropyLoss()
loss = self.criterion(logits, labels)
```
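A rough sketch of how the function might be exercised end to end; the `SimpleNamespace` stand-in for `self`, the batch size, feature dimension, and temperature are illustrative assumptions, not part of the original code:

```python
import torch
from types import SimpleNamespace

# hypothetical stand-in for the object that owns info_nce_loss
trainer = SimpleNamespace(
    args=SimpleNamespace(batch_size=4, n_views=2, device='cpu', temperature=0.07)
)

# hypothetical encoder output for 2 views of 4 samples, concatenated along dim 0:
# rows [0:4] are view 1 and rows [4:8] are view 2 of the same samples
features = torch.randn(2 * 4, 16, requires_grad=True)

logits, labels = info_nce_loss(trainer, features)
loss = torch.nn.CrossEntropyLoss()(logits, labels)
loss.backward()  # gradients flow back into `features` (or an encoder, in practice)
```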