# File size: 1,391 Bytes | a0e0ff1
import torch
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM
from config import Config
class CustomBERTModel(nn.Module):
    """BERT masked-LM model driven by projected feature vectors.

    Instead of token ids, the model takes feature vectors whose last
    dimension is ``config.input_dim``, maps them to ``config.hidden_size``
    with a linear layer, and passes the result to a newly constructed
    ``BertForMaskedLM`` through its ``inputs_embeds`` argument.
    """

    def __init__(self, config):
        """Build the input projection and the underlying BERT model.

        Args:
            config: Object exposing ``input_dim`` plus the BERT
                hyper-parameters read below (``vocab_size``,
                ``hidden_size``, ``num_hidden_layers``, ...).
        """
        super().__init__()
        # Maps raw features into the embedding space BERT operates on.
        self.input_proj = nn.Linear(config.input_dim, config.hidden_size)
        bert_config = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            max_position_embeddings=config.max_position_embeddings,
            type_vocab_size=config.type_vocab_size,
            initializer_range=config.initializer_range,
            layer_norm_eps=config.layer_norm_eps,
        )
        # Built from the config only; no pretrained weights are loaded here.
        self.bert = BertForMaskedLM(bert_config)

    def forward(self, x, labels=None, attention_mask=None):
        """Project ``x`` and run the masked-LM model on the projection.

        Args:
            x: Input features with last dimension ``config.input_dim``
                (presumably ``(batch, seq_len, input_dim)`` -- TODO confirm
                against callers).
            labels: Optional targets, forwarded unchanged to
                ``BertForMaskedLM``.
            attention_mask: Optional mask forwarded unchanged to
                ``BertForMaskedLM`` so padded positions can be excluded
                from attention. ``None`` (the default) forwards no mask.

        Returns:
            Whatever ``BertForMaskedLM`` returns for these arguments.
        """
        embeds = self.input_proj(x)
        return self.bert(
            inputs_embeds=embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

    def get_encoder_output(self, x, attention_mask=None):
        """Return the encoder's last hidden state for ``x``.

        Uses the inner ``self.bert.bert`` model directly, so the masked-LM
        head is not applied.

        Args:
            x: Input features with last dimension ``config.input_dim``.
            attention_mask: Optional mask forwarded unchanged to the
                encoder. ``None`` (the default) forwards no mask.

        Returns:
            The ``last_hidden_state`` attribute of the encoder output.
        """
        embeds = self.input_proj(x)
        encoded = self.bert.bert(
            inputs_embeds=embeds,
            attention_mask=attention_mask,
        )
        return encoded.last_hidden_state
|