RoBERTa for mathlib state difficulty judgement

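A RoBERTa model that takes a Lean 4 proof state as input and predicts the number of lines remaining until the proof is complete, as one of 10 classes ("1" through "10 or more").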

Usage

Using pipeline

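import os

from transformers import pipeline

# Assumption: your Hugging Face access token is stored in the HF_TOKEN
# environment variable. A token is needed because access to this repository is gated.
HF_TOKEN = os.environ["HF_TOKEN"]

pipe = pipeline("text-classification", model="fumiyau/mathlib4_state_diff", token=HF_TOKEN)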
pipe("case h.e'_2.a x x₁ x₂ x₃ x' y y₁ y₂ y₃ y' : PGame ih : ∀ (a : Args), ArgsRel a (Args.P1 x y) → P124 a hx : x.Numeric hy : y.Numeric ih' : ∀ (a : Args), ArgsRel a (Args.P24 x₁ x₂ y) → P124 a hn : (x * y).Numeric h : ∀ (i : (x * y).LeftMoves), ⟦(x * y).moveLeft i⟧ < ⟦x * y⟧ ⊢ MulOptionsLTMul (-x) (-y) ↔ ∀ (i : (-x).LeftMoves) (j : (-y).LeftMoves), ⟦-x * -y⟧ > ⟦(-x).mulOption (-y) i j⟧")
# [{'label': '1', 'score': 0.23482494056224823}]
# The score is the softmax of the logits for the predicted class.
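
To score several proof states in one call, a small wrapper along the following lines may help. This is only a sketch: `predict_remaining_lines` is a hypothetical helper name, not part of this repository, and it assumes the callable passed in behaves like the text-classification pipeline (a list of strings in, one `{"label", "score"}` dict per string out). Passing `truncation=True` is a precaution for long proof states that might exceed the model's maximum input length.

```python
def predict_remaining_lines(classifier, states):
    """Classify a batch of Lean 4 proof states.

    `classifier` is any callable with the text-classification pipeline
    interface: it takes a list of strings and returns one
    {"label": ..., "score": ...} dict per string.
    (Hypothetical helper, shown for illustration only.)
    """
    states = list(states)  # materialize so the inputs can be paired with results
    # truncation=True asks the tokenizer to cut inputs that are too long
    # for the model instead of failing on them.
    results = classifier(states, truncation=True)
    return [(state, r["label"], r["score"]) for state, r in zip(states, results)]
```

With the pipeline created as shown earlier, this would be called as `predict_remaining_lines(pipe, [state_a, state_b])`, where `state_a` and `state_b` are proof-state strings.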

Loading the model and tokenizer explicitly

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

text = """x y : ℕ,
h₁ : Prime x,
h₂ : ¬Even x,
h₃ : y > x
⊢ y ≥ 4"""

tokenizer = AutoTokenizer.from_pretrained("fumiyau/mathlib4_state_diff")
model = AutoModelForSequenceClassification.from_pretrained("fumiyau/mathlib4_state_diff")

inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
    # logits = tensor([[ 1.5015,  1.1391,  0.4959,  0.0449, -0.1418, -0.2950, -0.8557, -0.8465, -1.1877,  0.2097]])

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
# '1'
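
To see how confident the model is across all 10 classes rather than only the top one, the logits can be turned into probabilities with a softmax. The sketch below does this in plain Python using the example logits shown earlier. The label list is an assumption (index 0 maps to "1" as in the example, and the last index is assumed to be "10 or more"); in real use, read the mapping from `model.config.id2label`.

```python
import math

# Logits copied from the example output (one value per class).
logits = [1.5015, 1.1391, 0.4959, 0.0449, -0.1418,
          -0.2950, -0.8557, -0.8465, -1.1877, 0.2097]

# Assumed label order: index 0 -> "1", ..., index 9 -> "10 or more".
# Prefer model.config.id2label when the model is loaded.
labels = [str(i) for i in range(1, 10)] + ["10 or more"]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)

# Rank classes from most to least likely and show the top three.
ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
for label, prob in ranked[:3]:
    print(f"{label}: {prob:.3f}")
```

When working with the tensor directly, `torch.softmax(logits, dim=-1)` gives the same probabilities.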

Evaluation results

Results after 1.5 epochs of training.
To be updated as training progresses.

Metric Value
train_loss 2.0292
eval_loss 1.8142
eval_acc 0.39312
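
As a rough reference point for these numbers, the sketch below compares them with a uniform random guess over the 10 classes. It assumes the reported losses are cross-entropy values in natural-log units, which the card does not state explicitly. The class distribution of the dataset is also not given here, so a majority-class baseline may be stronger than the uniform one.

```python
import math

num_classes = 10      # "1" through "10 or more"
eval_loss = 1.8142    # reported above
eval_acc = 0.39312    # reported above

# Cross-entropy and accuracy of a uniform guess over all classes
# (assumption: loss is measured in nats).
uniform_loss = math.log(num_classes)
uniform_acc = 1 / num_classes

# exp(loss) can be read as an "effective number of equally likely classes".
effective_classes = math.exp(eval_loss)

print(f"uniform loss {uniform_loss:.4f} vs reported {eval_loss:.4f}")
print(f"uniform acc  {uniform_acc:.2f} vs reported {eval_acc:.5f}")
print(f"effective number of classes: {effective_classes:.2f}")
```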

Model size: 0.3B params (Safetensors, tensor type F32)