| import torch.nn as nn |
| import torch |
| from transformers import AutoModel |
|
|
class BERT_FFNN(nn.Module):
    """
    BERT_FFNN: BERT + feed-forward network for text classification tasks.

    A pretrained encoder loaded with ``AutoModel.from_pretrained`` produces
    per-token hidden states. These are pooled into one vector per sequence and
    passed through an MLP head that outputs ``output_dim`` logits.

    Args:
        bert_model_name: Model id or local path given to
            ``AutoModel.from_pretrained``.
        hidden_dims: Widths of the head's hidden layers, in order. Each hidden
            layer is Linear -> ReLU -> LayerNorm (optional) -> Dropout. An
            empty sequence leaves a single Linear layer.
        output_dim: Number of logits produced.
        dropout: Dropout probability applied after every hidden layer.
        pooling: How token states become one vector per sequence. The options
            are ``'mean'`` (padding-aware average), ``'max'`` (padding-aware
            max) and ``'attention'`` (learned weights via ``AttentionPooling``).
            Any other value uses the encoder's ``pooler_output`` when it is
            present and not None. Otherwise it uses the hidden state of the
            first token.
        freeze_bert: If True, disable gradients for every encoder parameter.
        freeze_layers: Only used when ``freeze_bert`` is False. Disables
            gradients for the first ``freeze_layers`` modules of
            ``self.bert.encoder.layer``. All other encoder parameters (e.g.
            embeddings, later layers) stay trainable. Requires the loaded model
            to expose ``encoder.layer``.
        use_layer_norm: Whether to insert a LayerNorm after each hidden ReLU.
    """

    def __init__(
        self,
        bert_model_name="microsoft/deberta-v3-base",
        # Tuple instead of a list: a default value is shared by every call,
        # so it must not be a mutable object.
        hidden_dims=(192, 96),
        output_dim=5,
        dropout=0.2,
        pooling='attention',
        freeze_bert=False,
        freeze_layers=0,
        use_layer_norm=True,
    ):
        super().__init__()

        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.use_layer_norm = use_layer_norm
        self.pooling = pooling

        # Only the attention pooling strategy has learnable parameters.
        if pooling == 'attention':
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            # Freeze only the lowest `freeze_layers` encoder blocks.
            for layer in self.bert.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Classification head: the pooled vector has the encoder's hidden size.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.classifier = nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask):
        """
        Return classification logits for a batch of token sequences.

        Args:
            input_ids: Token ids, passed unchanged to the encoder.
            attention_mask: Non-zero marks real tokens and zero marks padding.
                Presumably (batch, seq_len); the pooling code treats dim 1 of
                the encoder output as the sequence axis.

        Returns:
            Unnormalised logits from the classifier head, presumably of shape
            (batch, output_dim).
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state

        if self.pooling == 'mean':
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (hidden * mask).sum(1)
            # clamp avoids 0/0 for rows with no real tokens (they give zeros).
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == 'max':
            keep = attention_mask.unsqueeze(-1) != 0
            masked_emb = hidden.masked_fill(~keep, float('-inf'))
            features = masked_emb.max(dim=1).values
            # A row with no real tokens would max to -inf and turn the logits
            # into NaN. Give such rows a zero vector, as mean pooling does.
            features = features.masked_fill(~keep.any(dim=1), 0.0)
        elif self.pooling == 'attention':
            features = self.attention_pool(hidden, attention_mask)
        else:
            # Fall back to the model's pooler when it provides one (some
            # encoders return none), otherwise use the first token's state.
            pooled = getattr(outputs, 'pooler_output', None)
            features = pooled if pooled is not None else hidden[:, 0]

        return self.classifier(features)
|
|
class AttentionPooling(nn.Module):
    """
    Pool a sequence of hidden states into one vector using learned attention.

    A single linear layer scores every token. Padding positions are masked
    out, the scores are softmax-normalised over the sequence, and the hidden
    states are summed with those weights.

    Args:
        hidden_size: Size of the last dimension of ``hidden_states``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        """
        Return the attention-weighted sum of ``hidden_states`` over dim 1.

        Args:
            hidden_states: Presumably (batch, seq_len, hidden_size). Dim 1 is
                the axis that gets reduced.
            attention_mask: Matches the first two dims of ``hidden_states``;
                zero marks padding. Padded positions get (effectively) zero
                weight. A row that is entirely padding gets uniform weights
                over all positions.

        Returns:
            The pooled tensor with dim 1 summed away, i.e. (batch, hidden_size)
            for 3-D input.
        """
        scores = self.attention(hidden_states).squeeze(-1)
        # Mask with the most negative finite value of the score dtype rather
        # than a fixed -1e9. -1e9 is outside the float16 range (max ~65504), so
        # masked_fill rejects it when the Linear runs in half precision (e.g.
        # under autocast). exp(min - max) still underflows to 0 in softmax.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(attention_mask == 0, fill_value)
        weights = torch.softmax(scores, dim=-1)

        return (hidden_states * weights.unsqueeze(-1)).sum(dim=1)
|
|
|
|