"""Train a custom BERT-style model with a masked-language-modelling objective.

Token sequences are read from header-less CSV files, randomly masked on every
batch, and fed to ``CustomBERTModel``.  The final weights are written to
``bert_mlm_model.pth``.
"""

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from config import Config
from model import CustomBERTModel


def load_data(file_path: str) -> torch.Tensor:
    """Read a header-less CSV file and return its values as a float32 tensor.

    Args:
        file_path: Path of the CSV file; every row is read as data
            (``header=None``).

    Returns:
        A ``torch.float32`` tensor built from the values of the parsed frame.
    """
    df = pd.read_csv(file_path, header=None)
    return torch.tensor(df.values, dtype=torch.float32)


def create_mlm_data(data, mlm_probability):
    """Apply BERT-style random masking to ``data`` and build the MLM labels.

    Each position is selected with probability ``mlm_probability``.  Of the
    selected positions, 80% are overwritten with the mask value ``0``, 10% are
    overwritten with a random id drawn from ``[0, Config.vocab_size)`` and the
    remaining 10% are left unchanged.

    NOTE: ``data`` is modified in place; the first returned tensor is the same
    object that was passed in.

    Args:
        data: Tensor of token values to corrupt.
        mlm_probability: Probability with which a position is selected.

    Returns:
        A ``(masked_data, labels)`` tuple.  ``labels`` is a copy of the
        original ``data`` in which every non-selected position is ``-100``.
    """
    # Allocate every sampling tensor on the input's device so the masks and
    # the replacement values live alongside the data: no implicit host/device
    # copies and no cross-device assignment when ``data`` is on an accelerator.
    device = data.device

    labels = data.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability, device=device)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with the mask token
    # ([MASK]).
    indices_replaced = (
        torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool()
        & masked_indices
    )
    data[indices_replaced] = 0  # Assume 0 is the representation of [MASK]

    # 10% of the time, we replace masked input tokens with a random word:
    # half of the 20% of selected positions that were not replaced above.
    indices_random = (
        torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(
        Config.vocab_size, labels.shape, dtype=torch.long, device=device
    )
    # Cast to the input's own dtype (identical to ``.float()`` for float32
    # input) so the indexed assignment never mixes dtypes.
    data[indices_random] = random_words[indices_random].to(data.dtype)

    # The remaining selected positions keep their original value.
    return data, labels


def train() -> None:
    """Run the full training loop and save the resulting model weights.

    Builds the model and an AdamW optimizer from ``Config``, then for every
    epoch trains on freshly masked batches, reports the average training
    loss, and evaluates on the validation file.  Masking is re-sampled on
    every batch, including validation batches, so the validation loss is not
    computed on a fixed set of masked positions.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    train_data = load_data(config.train_file)
    val_data = load_data(config.val_file)

    train_dataset = TensorDataset(train_data)
    val_dataset = TensorDataset(val_data)

    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True
    )
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    for epoch in range(config.num_train_epochs):
        model.train()
        total_loss = 0
        progress = tqdm(
            train_loader, desc=f"Epoch {epoch+1}/{config.num_train_epochs}"
        )
        for batch in progress:
            inputs = batch[0].to(config.device)
            masked_inputs, labels = create_mlm_data(inputs, config.mlm_probability)

            optimizer.zero_grad()
            outputs = model(masked_inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            # Clip before stepping so one bad batch cannot blow up the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(
            f"Epoch {epoch+1}/{config.num_train_epochs}, "
            f"Average training loss: {avg_train_loss:.4f}"
        )

        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch[0].to(config.device)
                masked_inputs, labels = create_mlm_data(
                    inputs, config.mlm_probability
                )
                outputs = model(masked_inputs, labels=labels)
                total_val_loss += outputs.loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Validation loss: {avg_val_loss:.4f}")

    # Save the model
    torch.save(model.state_dict(), "bert_mlm_model.pth")
    print("Model saved as bert_mlm_model.pth")


if __name__ == "__main__":
    train()