"""context

Automatically generated by Colaboratory.

Original file is located at
https://colab.research.google.com/drive/1qLh1aASQj5HIENPZpHQltTuShZny_567
"""
|
|
| |
|
|
| |
| |
import os
import json
import wandb  # fixed: was "import wanb" (ImportError); the script calls wandb.login()/wandb.sweep() below
from pprint import pprint

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AdamW  # NOTE(review): unused (torch.optim.AdamW is used later) — kept since only part of the file may be visible
from tqdm.notebook import tqdm
from transformers import BertForQuestionAnswering, BertTokenizer, BertTokenizerFast

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Authenticate with Weights & Biases (prompts for an API key in Colab).
wandb.login()
|
|
| |
# --- Weights & Biases sweep configuration -------------------------------
PROJECT_NAME = "context"
ENTITY = None

# Random search over the hyper-parameter space.
sweep_config = {'method': 'random'}

# The sweep maximizes the 'Validation accuracy' value logged by the pipeline.
metric = {'name': 'Validation accuracy', 'goal': 'maximize'}
sweep_config['metric'] = metric

# Hyper-parameters the sweep samples from.
parameters_dict = {
    'epochs': {'values': [1]},
    'optimizer': {'values': ['sgd', 'adam']},
    'momentum': {'distribution': 'uniform', 'min': 0.5, 'max': 0.99},
    'batch_size': {'distribution': 'q_log_uniform_values',
                   'q': 8, 'min': 16, 'max': 256},
}
sweep_config['parameters'] = parameters_dict

# Echo the full configuration for the notebook log.
pprint(sweep_config)
|
|
| |
# Register the sweep with W&B; the returned id is consumed by wandb.agent below.
sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME, entity=ENTITY)
|
|
| |
# Mount Google Drive so the fine-tuned model can be persisted across sessions.
from google.colab import drive
drive.mount('/content/drive')

# makedirs(..., exist_ok=True) is atomic and race-free, unlike the original
# exists()/mkdir() pair.
os.makedirs('/content/drive/MyDrive/BERT-SQuAD', exist_ok=True)
|
|
| |
| |
| |
|
|
"""Load the training dataset and take a look at it"""
with open('train-v2.0.json', 'rb') as f:
    squad = json.load(f)

# Bare expression: Colab echoes the value so one context can be eyeballed.
squad['data'][150]['paragraphs'][0]['context']
|
|
def read_data(path):
    """Flatten a SQuAD-format JSON file (train or dev) into parallel lists.

    The original header string said "Load the dev dataset", but this reader
    is used for both the train and dev splits.

    Args:
        path: Path to a SQuAD-style JSON file.

    Returns:
        (contexts, questions, answers) — three parallel lists; each answer
        is a dict with the answer text and its character start index.

    Note: questions with an empty ``answers`` list (SQuAD v2 "impossible"
    questions) are skipped, because the inner loop never runs for them.
    """
    with open(path, 'rb') as f:
        squad = json.load(f)

    contexts = []
    questions = []
    answers = []
    for group in squad['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qna in passage['qas']:
                question = qna['question']
                for answer in qna['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
    return contexts, questions, answers
|
|
|
|
| |
"""
The answers are dictionaries with the answer text and an integer which indicates the start index of the answer in the context.
"""
# (fixed "whith" typo above)
train_contexts, train_questions, train_answers = read_data('train-v2.0.json')
valid_contexts, valid_questions, valid_answers = read_data('dev-v2.0.json')
| |
|
|
| |
def end_idx(answers, contexts):
    """Add an 'answer_end' character index to every answer dict, in place.

    SQuAD start indices are occasionally off by one or two characters, so
    when the gold text does not line up we retry the span shifted left by
    1 and then 2 characters, adjusting 'answer_start' to match.

    Args:
        answers: list of dicts with 'text' and 'answer_start' keys (mutated).
        contexts: list of context strings, parallel to `answers`.
    """
    # Renamed the loop variable: the original reused `answers`, shadowing
    # the parameter of the same name.
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start = answer['answer_start']
        end = start + len(gold_text)

        if context[start:end] == gold_text:
            answer['answer_end'] = end
        elif context[start - 1:end - 1] == gold_text:
            answer['answer_start'] = start - 1
            answer['answer_end'] = end - 1
        elif context[start - 2:end - 2] == gold_text:
            answer['answer_start'] = start - 2
            answer['answer_end'] = end - 2
        else:
            # The original set no key here, which later raises KeyError in
            # add_token_positions; fall back to the unshifted span.
            answer['answer_end'] = end
|
|
|
|
"""Tokenization"""
# (fixed the stray fourth quote in the original header string)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
# Pairs (context, question) are jointly tokenized; truncation/padding give
# fixed-length encodings suitable for batching.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
valid_encodings = tokenizer(valid_contexts, valid_questions, truncation=True, padding=True)
|
|
| |
|
|
| |
def add_token_positions(encodings, answers):
    """Convert character-level answer spans to token positions, in place.

    Adds 'start_positions' and 'end_positions' lists to `encodings`.

    Args:
        encodings: fast-tokenizer output supporting char_to_token/update.
        answers: list of dicts with 'answer_start'/'answer_end' char indices.
    """
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        # 'answer_end' is exclusive, so the last answer character is at
        # answer_end - 1. The original passed answer_end, which maps to the
        # token *after* the answer (or None at the sequence end) — the
        # standard HF custom-datasets tutorial uses answer_end - 1.
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))

        # char_to_token returns None when the answer was truncated away;
        # clamp to model_max_length as an out-of-range "ignored" position.
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
|
|
|
|
| """Dataloader for the training dataset""" |
class DatasetRetriever(Dataset):
    """Dataloader for the training dataset.

    Wraps tokenizer output so a DataLoader can index it as tensors.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Materialize one sample as tensors, one entry per encoding field.
        sample = {}
        for key, values in self.encodings.items():
            sample[key] = torch.tensor(values[idx])
        return sample

    def __len__(self):
        return len(self.encodings.input_ids)
|
|
| |
# Build datasets/loaders and instantiate the model; prefer GPU when present.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_dataset = DatasetRetriever(train_encodings)
valid_dataset = DatasetRetriever(valid_encodings)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
valid_loader = DataLoader(valid_dataset, batch_size=16)
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
|
|
| |
def pipeline():
    """One wandb-sweep trial: fine-tune BERT on SQuAD, then evaluate.

    Hyper-parameters come from wandb.config. Fixes over the original:
    - removed the dead ``epochs=1,`` tuple assignment;
    - the sweep's 'optimizer', 'momentum' and 'batch_size' parameters were
      defined but ignored (a fixed AdamW and the global 16-batch loader
      were always used) — they are honored now;
    - the training loss was logged under the key 'Validation Loss'.
    """
    with wandb.init(config=None):
        config = wandb.config
        model.to(device)

        # Honor the sweep's optimizer choice ('sgd' uses the sampled momentum).
        if config.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=5e-5,
                                        momentum=config.momentum)
        else:
            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

        # Honor the sweep's batch size (the module-level loader is fixed at 16).
        loader = DataLoader(train_dataset, batch_size=int(config.batch_size),
                            shuffle=True)

        # ----- training -----
        model.train()
        for epoch in range(config.epochs):
            loop = tqdm(loader, leave=True)
            for batch in loop:
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                start_positions = batch['start_positions'].to(device)
                end_positions = batch['end_positions'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask,
                                start_positions=start_positions,
                                end_positions=end_positions)
                loss = outputs[0]
                loss.backward()
                optimizer.step()

                loop.set_description(f'Epoch {epoch+1}')
                loop.set_postfix(loss=loss.item())
                # This is the per-batch training loss (was mislabeled
                # 'Validation Loss' in the original).
                wandb.log({'Training loss': loss.item()})

        # ----- evaluation: token-level accuracy of start/end positions -----
        model.eval()
        acc = []
        for batch in tqdm(valid_loader):
            with torch.no_grad():
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                start_true = batch['start_positions'].to(device)
                end_true = batch['end_positions'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)

                start_pred = torch.argmax(outputs['start_logits'], dim=1)
                end_pred = torch.argmax(outputs['end_logits'], dim=1)

                acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
                acc.append(((end_pred == end_true).sum() / len(end_pred)).item())

        acc = sum(acc) / len(acc)

        # Show true vs. predicted spans for the last validation batch.
        print("\n\nT/P\tanswer_start\tanswer_end\n")
        for i in range(len(start_true)):
            print(f"true\t{start_true[i]}\t{end_true[i]}\n"
                  f"pred\t{start_pred[i]}\t{end_pred[i]}\n")
        wandb.log({'Validation accuracy': acc})
|
|
| |
# Launch 4 sweep trials, each running pipeline() with sampled hyper-parameters.
wandb.agent(sweep_id, pipeline, count=4)

# Save the model so we don't have to train it again.  (fixed "dont" typo)
model_path = '/content/drive/MyDrive/BERT-SQuAD'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Load the model back from Drive (also works in a fresh Colab session).
model_path = '/content/drive/MyDrive/BERT-SQuAD'
model = BertForQuestionAnswering.from_pretrained(model_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
|
|
|
|
|
|
| |
def get_prediction(context, question):
    """Return the model's extracted answer string for (context, question).

    Fix: the second parameter was named ``answer`` but never used — the body
    read an undefined global ``question`` instead. Callers pass the question
    positionally (see question_answer), so renaming the parameter is safe.
    """
    inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
    outputs = model(**inputs)
    # outputs[0]/outputs[1] are start/end logits; argmax picks the span.
    answer_start = torch.argmax(outputs[0])
    answer_end = torch.argmax(outputs[1]) + 1
    tokens = inputs['input_ids'][0][answer_start:answer_end]
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(tokens))
|
|
|
|
"""
Question testing

Official SQuAD evaluation script -->
https://colab.research.google.com/github/fastforwardlabs/ff14_blog/blob/master/_notebooks/2020-06-09-Evaluating_BERT_on_SQuAD.ipynb#scrollTo=MzPlHgWEBQ8D
"""
|
|
def normalize_text(s):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    import re
    import string

    # Pipeline: lowercase -> drop punctuation -> blank out articles -> collapse spaces.
    lowered = s.lower()
    no_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punct, flags=re.UNICODE)
    return " ".join(no_articles.split())
|
|
def exact_match(prediction, truth):
    """True iff prediction and truth are identical after SQuAD-style normalization."""
    return normalize_text(prediction) == normalize_text(truth)
|
|
def compute_f1(prediction, truth):
    """Token-level F1 between prediction and truth (SQuAD style), rounded to 2 dp.

    Fix: the original used ``set`` intersection, which miscounts repeated
    tokens; the official SQuAD evaluation script uses multiset (Counter)
    overlap, adopted here.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # If either side is empty, F1 is 1 when both are empty, else 0.
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    if num_same == 0:
        return 0

    prec = num_same / len(pred_tokens)
    rec = num_same / len(truth_tokens)

    return round(2 * (prec * rec) / (prec + rec), 2)
|
|
def question_answer(context, question, answer):
    """Run the model on one (context, question) pair and report EM/F1 against `answer`."""
    prediction = get_prediction(context, question)

    # Score the prediction against the reference answer.
    em_score = exact_match(prediction, answer)
    f1_score = compute_f1(prediction, answer)

    print(f'Question: {question}')
    print(f'Prediction: {prediction}')
    print(f'True Answer: {answer}')
    print(f'Exact match: {em_score}')
    print(f'F1 score: {f1_score}\n')
|
|
# A small hand-written sanity-check example (not drawn from SQuAD).
context = """Space exploration is a very exciting field of research. It is the
frontier of Physics and no doubt will change the understanding of science.
However, it does come at a cost. A normal space shuttle costs about 1.5 billion dollars to make.
The annual budget of NASA, which is a premier space exploring organization is about 17 billion.
So the question that some people ask is that whether it is worth it."""

# NOTE: the second question is abstractive; an extractive QA model cannot
# really answer it.  ("wil" typo preserved — it is part of the test input.)
questions = ["What wil change the understanding of science?",
             "What is the main idea in the paragraph?"]

answers = ["Space Exploration",
           "The cost of space exploration is too high"]
|
|
"""
VISUALISATION IN PROGRESS

for question, answer in zip(questions, answers):
    question_answer(context, question, answer)

#Visualize the start scores
plt.rcParams["figure.figsize"]=(20,10)
ax=sns.barplot(x=token_labels,y=start_scores)
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
ax.grid(True)
plt.title("Start word scores")
plt.show()

#Visualize the end scores
plt.rcParams["figure.figsize"]=(20,10)
ax=sns.barplot(x=token_labels,y=end_scores)
ax.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
ax.grid(True)
plt.title("End word scores")
plt.show()

#Visualize both the scores
scores=[]
for (i,token_label) in enumerate(token_labels):
    # Add the token's start score as one row.
    scores.append({'token_label':token_label,
                   'score':start_scores[i],
                   'marker':'start'})

    # Add the token's end score as another row.
    scores.append({'token_label': token_label,
                   'score': end_scores[i],
                   'marker': 'end'})

df=pd.DataFrame(scores)
group_plot=sns.catplot(x="token_label",y="score",hue="marker",data=df,
                       kind="bar",height=6,aspect=4)

group_plot.set_xticklabels(ax.get_xticklabels(),rotation=90,ha="center")
group_plot.ax.grid(True)
"""
|
|