|
|
from fastapi import FastAPI, File, UploadFile |
|
|
from fastapi.responses import JSONResponse |
|
|
from sklearn.neural_network import MLPClassifier |
|
|
import torchvision.datasets as datasets |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
import seaborn as sns |
|
|
|
|
|
# ASGI application object; endpoints below register themselves on it.
app = FastAPI()

# Set seaborn's global plotting style.
sns.set_style("darkgrid")

# Load both MNIST splits, downloading them into ./data when necessary.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

# Pull the raw image tensors and their labels out as NumPy arrays.
X_train = mnist_trainset.data.numpy()
X_test = mnist_testset.data.numpy()
y_train = mnist_trainset.targets.numpy()
y_test = mnist_testset.targets.numpy()

# Flatten every image into a single feature row and divide by 255.0 so the
# features match the scaling applied to uploads at prediction time.
# The row count is taken from the data itself instead of being hard-coded,
# so the code keeps working if the split sizes ever differ.
X_train = X_train.reshape(len(X_train), -1) / 255.0
X_test = X_test.reshape(len(X_test), -1) / 255.0

# Train the classifier once at import time; the /predict endpoint reuses it.
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
mlp.fit(X_train, y_train)
|
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the module-level MLP model.

    The upload is decoded, converted to 8-bit grayscale, resized to 28x28,
    flattened and divided by 255.0 before being passed to ``mlp.predict``.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A JSON response ``{"prediction": <int>}`` on success. On failure the
        body is ``{"error": <message>}`` with HTTP status 400 when the upload
        cannot be read or decoded as an image, and 500 when inference fails.
    """
    # Stage 1: anything that goes wrong here is caused by the request payload,
    # so it is reported as a client error rather than a silent 200.
    try:
        contents = await file.read()
        image = Image.open(BytesIO(contents)).convert("L").resize((28, 28))
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    # Stage 2: failures past this point are server-side (model/array problems).
    try:
        img_array = np.array(image).flatten() / 255.0
        # predict() expects a 2-D (n_samples, n_features) array.
        prediction = mlp.predict(img_array.reshape(1, -1))[0]
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Cast the NumPy integer to a plain int so it is JSON-serializable.
    return JSONResponse(content={"prediction": int(prediction)})
|
|
|