| | |
| | |
| | |
| |
|
| | |
| | import argparse |
| | import os |
| | from evaluator.CodeBLEU import bleu, weighted_ngram_match, syntax_match, dataflow_match |
| |
|
| |
|
| | def get_codebleu(refs, hyp, lang, params='0.25,0.25,0.25,0.25'): |
| | if not isinstance(refs, list): |
| | refs = [refs] |
| | alpha, beta, gamma, theta = [float(x) for x in params.split(',')] |
| |
|
| | |
| | pre_references = [[x.strip() for x in open(file, 'r', encoding='utf-8').readlines()] for file in refs] |
| | hypothesis = [x.strip() for x in open(hyp, 'r', encoding='utf-8').readlines()] |
| |
|
| | for i in range(len(pre_references)): |
| | assert len(hypothesis) == len(pre_references[i]) |
| |
|
| | references = [] |
| | for i in range(len(hypothesis)): |
| | ref_for_instance = [] |
| | for j in range(len(pre_references)): |
| | ref_for_instance.append(pre_references[j][i]) |
| | references.append(ref_for_instance) |
| | assert len(references) == len(pre_references) * len(hypothesis) |
| |
|
| | |
| | tokenized_hyps = [x.split() for x in hypothesis] |
| | tokenized_refs = [[x.split() for x in reference] for reference in references] |
| |
|
| | ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps) |
| |
|
| | |
| | root_dir = os.path.dirname(__file__) |
| | keywords = [x.strip() for x in open(root_dir + '/keywords/' + lang + '.txt', 'r', encoding='utf-8').readlines()] |
| |
|
| | def make_weights(reference_tokens, key_word_list): |
| | return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens} |
| |
|
| | tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)] \ |
| | for reference_tokens in reference] for reference in tokenized_refs] |
| |
|
| | weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps) |
| |
|
| | |
| | syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang) |
| |
|
| | |
| | dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang) |
| |
|
| | print('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'. \ |
| | format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score)) |
| |
|
| | code_bleu_score = alpha * ngram_match_score \ |
| | + beta * weighted_ngram_match_score \ |
| | + gamma * syntax_match_score \ |
| | + theta * dataflow_match_score |
| |
|
| | return code_bleu_score |
| |
|
| |
|
| | if __name__ == '__main__': |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('--refs', type=str, nargs='+', required=True, |
| | help='reference files') |
| | parser.add_argument('--hyp', type=str, required=True, |
| | help='hypothesis file') |
| | parser.add_argument('--lang', type=str, required=True, |
| | choices=['java', 'js', 'c_sharp', 'php', 'go', 'python', 'ruby'], |
| | help='programming language') |
| | parser.add_argument('--params', type=str, default='0.25,0.25,0.25,0.25', |
| | help='alpha, beta and gamma') |
| |
|
| | args = parser.parse_args() |
| | code_bleu_score = get_codebleu(args.refs, args.hyp, args.lang, args.params) |
| | print('CodeBLEU score: ', code_bleu_score) |
| |
|
| |
|