# -*- coding: utf-8 -*-

from __future__ import annotations

import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationAwareTransformerEncoder(nn.Module):

    def __init__(
        self,
        layer: nn.Module,
        n_layers: int = 6,
        n_model: int = 1024,
        pre_norm: bool = False,
    ) -> None:
        super(RelationAwareTransformerEncoder, self).__init__()

        self.n_layers = n_layers
        self.n_model = n_model
        self.pre_norm = pre_norm

        # a stack of n_layers deep copies of the given encoder layer
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        # the final LayerNorm is only needed in the pre-norm variant
        self.norm = nn.LayerNorm(n_model) if self.pre_norm else None

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        # [batch_size, seq_len, n_model] -> [seq_len, batch_size, n_model]
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(x, rels, mask)
        if self.pre_norm:
            x = self.norm(x)
        # restore [batch_size, seq_len, n_model]
        return x.transpose(0, 1)


class RelationAwareTransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        pre_norm: bool = False,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        dropout: float = 0.1,
    ) -> None:
        super(RelationAwareTransformerEncoderLayer, self).__init__()

        self.attn = RelationAwareMultiHeadAttention(n_rels=n_rels,
                                                    n_heads=n_heads,
                                                    n_model=n_model,
                                                    n_embed=n_model//n_heads,
                                                    dropout=attn_dropout)
        self.attn_norm = nn.LayerNorm(n_model)
        # PositionwiseFeedForward is not defined in this section; a minimal sketch is given at the end of this file
        self.ffn = PositionwiseFeedForward(n_model=n_model,
                                           n_inner=n_inner,
                                           activation=activation,
                                           dropout=ffn_dropout)
        self.ffn_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(dropout)
        self.pre_norm = pre_norm

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        if self.pre_norm:
            # pre-norm: normalize first, then apply the sublayer and add the residual
            n = self.attn_norm(x)
            x = x + self.dropout(self.attn(n, n, n, rels, mask))
            n = self.ffn_norm(x)
            x = x + self.dropout(self.ffn(n))
        else:
            # post-norm: apply the sublayer, add the residual, then normalize
            x = self.attn_norm(x + self.dropout(self.attn(x, x, x, rels, mask)))
            x = self.ffn_norm(x + self.dropout(self.ffn(x)))
        return x


class RelationAwareMultiHeadAttention(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_embed: int = 128,
        dropout: float = 0.1,
        attn: bool = False
    ) -> None:
        super(RelationAwareMultiHeadAttention, self).__init__()

        self.n_rels = n_rels
        self.n_heads = n_heads
        self.n_model = n_model
        self.n_embed = n_embed
        self.scale = n_embed**0.5

        # one embedding per relation type, shared across heads
        self.rel_embed = nn.Embedding(num_embeddings=n_rels, embedding_dim=n_embed)
        # the projections below assume n_model == n_heads * n_embed
        self.wq = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wk = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wv = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wo = nn.Parameter(torch.zeros(n_heads * n_embed, n_model))
        self.bu = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.bv = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.dropout = nn.Dropout(dropout)
        # if True, forward additionally returns the attention weights
        self.attn = attn

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.rel_embed.weight, 2 ** -0.5)
        # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
        nn.init.xavier_uniform_(self.wq, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wk, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wv, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wo)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        rels: torch.LongTensor,
        mask: torch.BoolTensor,
        attn_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        batch_size, _ = mask.shape
        # [seq_len, batch_size, n_heads, n_embed]
        q = F.linear(q, self.wq).view(-1, batch_size, self.n_heads, self.n_embed)
        # [src_len, batch_size * n_heads, n_embed]
        k = F.linear(k, self.wk).view(-1, batch_size * self.n_heads, self.n_embed)
        v = F.linear(v, self.wv).view(-1, batch_size * self.n_heads, self.n_embed)
        # queries with the content bias added: [seq_len, batch_size * n_heads, n_embed]
        qu = (q + self.bu).view(-1, *k.shape[1:])
        # queries with the relation bias added: [seq_len, batch_size, n_heads, n_embed]
        qv = q + self.bv

        # valid (batch, i, j) pairs carrying a relation label; negative labels mean "no relation"
        rel_mask = rels.ge(0) & mask.unsqueeze(1) & mask.unsqueeze(2)
        rel_indices = torch.where(rel_mask)
        # [batch_size * n_heads, 1, src_len]
        mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
        if attn_mask is not None:
            mask = mask & attn_mask
        # content-to-content scores: [batch_size * n_heads, seq_len, src_len]
        attn = torch.bmm(qu.transpose(0, 1), k.movedim((0, 1), (2, 0)))
        # content-to-relation scores, computed only for labeled pairs: [n_pairs, n_heads]
        rel_attn = torch.bmm(qv[rel_indices[1], rel_indices[0]], self.rel_embed(rels[rel_mask]).unsqueeze(-1)).squeeze(-1)
        # scatter the pairwise scores back into [batch_size, seq_len, src_len, n_heads]
        rel_attn = attn.new_zeros(batch_size, *attn.shape[-2:], self.n_heads).masked_scatter_(rel_mask.unsqueeze(-1), rel_attn)
        attn = attn + rel_attn.movedim(-1, 1).reshape_as(attn)
        attn = torch.softmax(attn / self.scale + torch.where(mask, 0., float('-inf')), -1)
        attn = self.dropout(attn)
        # weighted sum of values: [seq_len, batch_size * n_heads, n_embed]
        x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
        # output projection: [seq_len, batch_size, n_model]
        x = F.linear(x.reshape(-1, batch_size, self.n_heads * self.n_embed), self.wo)

        return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
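

# ---------------------------------------------------------------------------
# NOTE: `PositionwiseFeedForward` is referenced above but not defined in this
# section. The class below is a minimal sketch, assuming the conventional
# Transformer position-wise feed-forward sublayer (two linear projections
# around a nonlinearity); substitute the project's own implementation if one
# exists elsewhere in the codebase.
# ---------------------------------------------------------------------------
class PositionwiseFeedForward(nn.Module):

    def __init__(
        self,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        dropout: float = 0.1,
    ) -> None:
        super(PositionwiseFeedForward, self).__init__()

        self.w1 = nn.Linear(n_model, n_inner)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(n_inner, n_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # expand to n_inner, apply the nonlinearity, project back to n_model
        return self.w2(self.dropout(self.activation(self.w1(x))))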
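

# ---------------------------------------------------------------------------
# A hypothetical smoke test (not part of the original source): builds a tiny
# encoder on random inputs and checks the output shape. `rels` uses -1 for
# unlabeled pairs, which `rel_mask = rels.ge(0) & ...` above treats as
# "no relation".
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, seq_len, n_model, n_rels = 2, 5, 16, 4
    layer = RelationAwareTransformerEncoderLayer(n_rels=n_rels, n_heads=4, n_model=n_model, n_inner=32)
    encoder = RelationAwareTransformerEncoder(layer, n_layers=2, n_model=n_model)
    x = torch.randn(batch_size, seq_len, n_model)
    # relation labels in [0, n_rels); -1 marks pairs without a relation
    rels = torch.randint(-1, n_rels, (batch_size, seq_len, seq_len))
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    print(encoder(x, rels, mask).shape)  # expected: torch.Size([2, 5, 16])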