File size: 1,388 Bytes
d20336c
 
 
 
 
 
 
 
 
 
1c6575a
 
 
 
 
 
 
 
d20336c
1c6575a
 
 
 
 
 
 
 
 
 
 
 
d20336c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from sklearn.neural_network import MLPClassifier
import torchvision.datasets as datasets
import numpy as np
from PIL import Image
from io import BytesIO
import seaborn as sns

app = FastAPI()

# Apply seaborn's "darkgrid" theme (grey axes background with white grid
# lines). Despite the name this is not a dark-mode theme, and nothing in this
# module plots, so it has no visible effect here.
# NOTE(review): presumably kept for plotting code elsewhere -- confirm or remove.
sns.set_style("darkgrid")

# Load MNIST into ./data (download=True fetches it if it is not already there).
# All of this runs at import time, so every process importing this module
# loads the data and trains its own model.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

X_train = mnist_trainset.data.numpy()
X_test = mnist_testset.data.numpy()
y_train = mnist_trainset.targets.numpy()
y_test = mnist_testset.targets.numpy()

# Flatten each image into a single row and divide pixel values by 255 so they
# lie in [0, 1], matching the preprocessing applied to uploads in /predict.
# The row count is taken from the arrays instead of being hard-coded
# (previously 60000 / 10000), so a split of a different size cannot make
# reshape fail.
X_train = X_train.reshape(len(X_train), -1) / 255.0
X_test = X_test.reshape(len(X_test), -1) / 255.0

# Train an MLP with two hidden layers of 32 units on the full training split.
# NOTE(review): X_test / y_test are not used in this file to evaluate the
# model, and no random_state is set, so each start trains a different model.
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
mlp.fit(X_train, y_train)

class _InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded into a digit image."""


def _image_bytes_to_features(contents: bytes) -> np.ndarray:
    """Decode image bytes into a single (1, 784) feature row in [0, 1].

    The image is converted to 8-bit grayscale ("L") and resized to 28x28,
    then flattened and divided by 255 -- the same scaling used for the
    training rows.

    Raises:
        _InvalidImageError: if Pillow cannot open, decode, or convert the
            bytes (unknown format, truncated data, decompression bomb, ...).
    """
    try:
        # The context manager releases Pillow's handle once we have the
        # converted copy; convert()/resize() return new, independent images.
        with Image.open(BytesIO(contents)) as image:
            gray = image.convert("L").resize((28, 28))
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        raise _InvalidImageError(str(err)) from err
    return np.asarray(gray).reshape(1, -1) / 255.0


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the digit shown in an uploaded image with the module-level ``mlp``.

    Args:
        file: The uploaded image (any format Pillow can open).

    Returns:
        ``{"prediction": <int>}`` with HTTP 200 on success;
        ``{"error": <message>}`` with HTTP 400 when the upload cannot be
        decoded as an image, or with HTTP 500 for any other failure.
        (Previously every error was reported with HTTP 200.)
    """
    try:
        features = _image_bytes_to_features(await file.read())
        prediction = mlp.predict(features)[0]
    except _InvalidImageError as e:
        # Bad input from the client: report it as such instead of 200 OK.
        return JSONResponse(status_code=400, content={"error": str(e)})
    except Exception as e:
        # Any other failure is on the server side; keep the JSON error shape.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"prediction": int(prediction)})