from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig


class ScalingLawForecaster(nn.Module):
    """Transformer backbone whose input embeddings at numeric positions are
    replaced by MLP-encoded scalar values; a linear head predicts one scalar
    per token."""

    def __init__(
        self,
        base_model_name: str = "HuggingFaceTB/SmolLM2-135M",
        init_from_pretrained: bool = True,
        force_fp32: bool = False,
    ):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        if force_fp32:
            self.config.torch_dtype = torch.float32

        if init_from_pretrained:
            if force_fp32:
                self.base = AutoModel.from_pretrained(
                    base_model_name,
                    config=self.config,
                    torch_dtype=torch.float32,
                )
            else:
                self.base = AutoModel.from_pretrained(base_model_name, config=self.config)
        else:
            # Same architecture, randomly initialized weights.
            self.base = AutoModel.from_config(self.config)

        hidden_size = self.config.hidden_size

        # MLP that maps a single scalar value to a hidden_size embedding,
        # used in place of the token embedding at numeric positions.
        act_cls = nn.ReLU
        self.num_mlp = nn.Sequential(
            nn.Linear(1, hidden_size * 2),
            act_cls(),
            nn.Linear(hidden_size * 2, hidden_size),
        )

        # Per-token regression head: one scalar prediction per position.
        self.head = nn.Linear(hidden_size, 1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        is_number_mask: torch.BoolTensor,
        number_values_filled: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids: (batch, seq_len)
            is_number_mask: (batch, seq_len) bool mask for numeric tokens
            number_values_filled: (batch, seq_len) float values (0 for non-numeric)
            attention_mask: (batch, seq_len) optional

        Returns:
            logits: (batch, seq_len) scalar predictions per token
        """
        # Ids equal to 49152 sit one past the SmolLM2 vocabulary (valid ids are
        # 0..49151) and presumably mark numeric placeholder tokens; remap them to 0
        # so the embedding lookup stays in range. masked_fill avoids mutating the
        # caller's tensor in place.
        input_ids = input_ids.masked_fill(input_ids == 49152, 0)
        text_emb = self.base.get_input_embeddings()(input_ids)

        # Encode every scalar value with the numeric MLP:
        # (batch * seq_len, 1) -> (batch * seq_len, hidden_size),
        # then reshape to match the token embeddings.
        flat_vals = number_values_filled.view(-1, 1)
        mlp_out = self.num_mlp(flat_vals)
        mlp_out = mlp_out.view_as(text_emb)

        # At numeric positions use the value embedding, elsewhere the token embedding.
        mask = is_number_mask.unsqueeze(-1)
        inputs_embeds = torch.where(mask, mlp_out, text_emb)

        outputs = self.base(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state

        # One scalar prediction per token position.
        logits = self.head(hidden).squeeze(-1)
        return logits
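

# Minimal usage sketch: exercises forward() with dummy tensors. The shapes, the
# randomly initialized backbone (init_from_pretrained=False), and the example
# value 1.5e9 are illustrative assumptions, not a prescribed data pipeline.
if __name__ == "__main__":
    model = ScalingLawForecaster(init_from_pretrained=False, force_fp32=True)
    model.eval()

    batch, seq_len = 2, 8
    input_ids = torch.randint(0, 100, (batch, seq_len))
    attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)

    # Mark one position per sequence as numeric and attach a scalar value to it.
    is_number_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    number_values_filled = torch.zeros(batch, seq_len)
    is_number_mask[:, 3] = True
    number_values_filled[:, 3] = 1.5e9  # e.g. a parameter count

    with torch.no_grad():
        preds = model(input_ids, is_number_mask, number_values_filled, attention_mask)
    print(preds.shape)  # expected: torch.Size([2, 8])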