File size: 1,391 Bytes
a0e0ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
from transformers import BertConfig, BertForMaskedLM
from config import Config

class CustomBERTModel(nn.Module):
    """BERT masked-LM model driven by continuous feature vectors instead of token ids.

    Each input vector is linearly projected to ``hidden_size`` and handed to a
    freshly initialised ``BertForMaskedLM`` via ``inputs_embeds``, so BERT's
    word-embedding lookup is skipped (position and token-type embeddings are
    still added inside BERT).
    """

    def __init__(self, config):
        """Build the input projection and the BERT masked-LM backbone.

        Args:
            config: Object exposing ``input_dim`` plus the BERT hyper-parameters
                read below (``vocab_size``, ``hidden_size``, ``num_hidden_layers``,
                ``num_attention_heads``, ``intermediate_size``, ``hidden_act``,
                dropout probabilities, ``max_position_embeddings``,
                ``type_vocab_size``, ``initializer_range``, ``layer_norm_eps``).
        """
        super().__init__()
        # Maps each feature vector (last dim == input_dim) to BERT's hidden width.
        self.input_proj = nn.Linear(config.input_dim, config.hidden_size)

        bert_config = BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            max_position_embeddings=config.max_position_embeddings,
            type_vocab_size=config.type_vocab_size,
            initializer_range=config.initializer_range,
            layer_norm_eps=config.layer_norm_eps,
        )

        # Built from a config only: weights are randomly initialised, no
        # pretrained checkpoint is loaded.
        self.bert = BertForMaskedLM(bert_config)

    def forward(self, x, labels=None, attention_mask=None):
        """Project ``x`` and run the masked-LM head.

        Args:
            x: Float tensor whose last dimension is ``input_dim``; presumably
                shaped (batch, seq_len, input_dim) — TODO confirm with callers.
            labels: Optional target ids for the MLM loss, forwarded unchanged to
                ``BertForMaskedLM`` (which ignores positions labelled -100).
            attention_mask: Optional (batch, seq_len) mask, 1 for real positions
                and 0 for padding. ``None`` keeps the previous behaviour of
                attending to every position.

        Returns:
            The ``BertForMaskedLM`` output object (``logits`` and, when
            ``labels`` is given, ``loss``).
        """
        embeddings = self.input_proj(x)
        return self.bert(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )

    def get_encoder_output(self, x, attention_mask=None):
        """Return the final hidden states of the BERT encoder, without the MLM head.

        Args:
            x: Float tensor whose last dimension is ``input_dim``.
            attention_mask: Optional padding mask; see ``forward``.

        Returns:
            ``last_hidden_state`` of the inner ``BertModel``.
        """
        embeddings = self.input_proj(x)
        # ``self.bert.bert`` is the bare BertModel wrapped by BertForMaskedLM.
        outputs = self.bert.bert(inputs_embeds=embeddings, attention_mask=attention_mask)
        return outputs.last_hidden_state