| """ |
| Original work: |
| https://github.com/sangHa0411/CloneDetection/blob/main/models/codebert.py#L169 |
| |
| Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) |
| |
| All credits to the original authors. |
| """ |
import torch.nn as nn
from transformers import (
    RobertaPreTrainedModel,
    RobertaModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput


class CloneDetectionModel(RobertaPreTrainedModel):
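    """RoBERTa/CodeBERT encoder with a clone-detection classification head.

    The hidden states at the special-token positions (<s> and </s>) are
    passed through a dropout + linear + ReLU projection, concatenated, and
    fed to a linear classifier over ``num_labels`` classes.
    """
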
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Encoder without the default pooler; the pooled representation is
        # built in forward() from the special-token hidden states instead.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.net = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
        )
        # A code-pair input <s> code1 </s></s> code2 </s> carries four
        # special tokens, whose projected states are concatenated, hence
        # the hidden_size * 4 classifier input.
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        batch_size, _, hidden_size = hidden_states.shape

        # Locate the special tokens (<s>/CLS and </s>/SEP); the token ids are
        # read from the config, which must carry tokenizer_cls_token_id and
        # tokenizer_sep_token_id.
        cls_flag = input_ids == self.config.tokenizer_cls_token_id
        sep_flag = input_ids == self.config.tokenizer_sep_token_id

        # Gather the hidden states at the special-token positions, shaped
        # (batch_size, num_special_tokens, hidden_size), and project them.
        special_token_states = hidden_states[cls_flag | sep_flag].view(
            batch_size, -1, hidden_size
        )
        special_hidden_states = self.net(special_token_states)

        # Flatten the projected states into one vector per example:
        # (batch_size, num_special_tokens * hidden_size).
        pooled_output = special_hidden_states.view(batch_size, -1)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
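

# Minimal usage sketch (not part of the original module). It assumes a
# RoBERTa-style code checkpoint such as "microsoft/codebert-base" and shows
# how the special-token ids expected by forward() can be copied onto the
# config before loading the model.
if __name__ == "__main__":
    import torch
    from transformers import AutoConfig, AutoTokenizer

    checkpoint = "microsoft/codebert-base"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    config = AutoConfig.from_pretrained(checkpoint, num_labels=2)
    # forward() reads the special-token ids from the config, so copy them over.
    config.tokenizer_cls_token_id = tokenizer.cls_token_id
    config.tokenizer_sep_token_id = tokenizer.sep_token_id

    model = CloneDetectionModel.from_pretrained(checkpoint, config=config)
    model.eval()

    # A RoBERTa pair encoding <s> code1 </s></s> code2 </s> yields the four
    # special tokens the classifier head expects.
    code1 = "def add(a, b):\n    return a + b"
    code2 = "def total(x, y):\n    return x + y"
    inputs = tokenizer(code1, code2, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits
    print(logits.softmax(dim=-1))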