"""FastAPI service that classifies handwritten digits with an MNIST-trained MLP.

Importing this module loads MNIST through torchvision, trains a small
scikit-learn ``MLPClassifier`` on the training split, and exposes a
``POST /predict`` endpoint that classifies an uploaded image.
"""

from io import BytesIO

import numpy as np
import seaborn as sns
import torchvision.datasets as datasets
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from sklearn.neural_network import MLPClassifier

app = FastAPI()

# Use seaborn's "darkgrid" plot style.
sns.set_style("darkgrid")

# Width/height that uploads are resized to; 28 * 28 = 784 features, matching
# the flattened training rows below.
_IMAGE_SIZE = (28, 28)

# Load MNIST data (download=True asks torchvision to fetch it into ./data).
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=None)

X_train = mnist_trainset.data.numpy()
X_test = mnist_testset.data.numpy()
y_train = mnist_trainset.targets.numpy()
y_test = mnist_testset.targets.numpy()

# Flatten every image into one row and scale pixel values by 1/255. The row
# count is taken from the array itself instead of being hard-coded.
X_train = X_train.reshape(len(X_train), -1) / 255.0
X_test = X_test.reshape(len(X_test), -1) / 255.0

# Train the model at import time so it is ready before the first request.
mlp = MLPClassifier(hidden_layer_sizes=(32, 32))
mlp.fit(X_train, y_train)


def _image_to_features(contents: bytes) -> np.ndarray:
    """Decode raw image bytes into a single-row feature matrix for ``mlp``.

    The image is converted to 8-bit grayscale ("L"), resized to 28x28,
    flattened, and scaled by 1/255 -- the same scaling applied to the
    training data.

    Args:
        contents: Raw bytes of an image file in any format Pillow can open.

    Returns:
        Array of shape ``(1, 784)``.

    Raises:
        Exception: Whatever Pillow raises when the bytes cannot be decoded
            as an image.
    """
    # NOTE(review): MNIST digits are presumably light-on-dark; uploads with
    # dark digits on a light background are not inverted here -- confirm
    # whether that is intended.
    with Image.open(BytesIO(contents)) as uploaded:
        grayscale = uploaded.convert("L").resize(_IMAGE_SIZE)
    pixels = np.array(grayscale)
    return (pixels.flatten() / 255.0).reshape(1, -1)


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the digit shown in an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        A ``JSONResponse`` with ``{"prediction": <int>}`` on success. On
        failure the body is ``{"error": <message>}`` with status 400 when the
        upload cannot be read or decoded, or 500 when the model fails.
    """
    # Stage 1: anything that goes wrong while reading/decoding the upload is
    # reported as a client error rather than a misleading 200.
    try:
        contents = await file.read()
        features = _image_to_features(contents)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    # Stage 2: a failure inside the model is a server-side problem.
    try:
        prediction = mlp.predict(features)[0]
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Convert the NumPy integer so it is JSON-serialisable.
    return JSONResponse(content={"prediction": int(prediction)})