"""FastAPI service that classifies plant-leaf images into disease classes.

Endpoints:
    GET  /         -- API name and version.
    POST /predict  -- classify an uploaded image with the loaded Keras model.
    GET  /health   -- simple liveness response.
    GET  /classes  -- the class labels, in model-output index order.
"""

import io
import logging
import os
import tempfile

import cv2
import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

logger = logging.getLogger(__name__)

API_VERSION = "1.0.0"
MODEL_PATH = "trained_modela.keras"
# (height, width) passed to load_img; must match the model's training input.
IMAGE_SIZE = (128, 128)

# Initialize FastAPI app
app = FastAPI(title="Plant Disease Detection API", version=API_VERSION)

# Add CORS middleware to allow requests from your frontend.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time; fail fast if it is missing or corrupt.
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}") from e

# Class labels, indexed by model output position -- the order must match the
# order used when the model was trained.
class_name = [
    'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot', 'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot', 'Peach___healthy',
    'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy',
    'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch', 'Strawberry___healthy',
    'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
    'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]


@app.get("/")
async def root():
    """Return the API name and version."""
    return {"message": "Plant Disease Detection API", "version": API_VERSION}


@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """Classify an uploaded leaf image and return the most likely class.

    The upload is written to a temporary file, loaded and resized to
    ``IMAGE_SIZE`` with Keras' ``load_img``, and passed to ``model.predict``
    as a batch of one.

    Returns:
        ``{"success": True, "disease": <label>, "confidence": <float>}`` where
        ``confidence`` is the model's output value at the chosen index.

    Raises:
        HTTPException(400): the upload is not declared as an image, or is empty.
        HTTPException(500): reading, decoding, or prediction failed.
    """
    # content_type is None when the client omits the part's Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    temp_path = None
    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # load_img reads from a filesystem path, so spool the upload to disk.
        # delete=False so the file survives closing (required on Windows
        # before reopening); it is removed in the finally block below.
        with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp:
            temp_path = tmp.name
            tmp.write(contents)

        image = tf.keras.preprocessing.image.load_img(
            temp_path, target_size=IMAGE_SIZE
        )
        input_arr = tf.keras.preprocessing.image.img_to_array(image)
        input_arr = np.array([input_arr])  # Convert single image to batch

        prediction = model.predict(input_arr)
        scores = prediction[0]  # scores for the single image in the batch
        result_index = int(np.argmax(scores))
        confidence = float(scores[result_index])
        disease_name = class_name[result_index]
    except HTTPException:
        # Client errors raised above must not be rewrapped as 500s.
        raise
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        raise HTTPException(
            status_code=500, detail=f"Prediction error: {str(e)}"
        ) from e
    finally:
        # Always remove the spooled upload so requests don't leak temp files.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_path)

    return {
        "success": True,
        "disease": disease_name,
        "confidence": confidence,
    }


@app.get("/health")
async def health_check():
    """Return a static liveness status."""
    return {"status": "healthy"}


@app.get("/classes")
async def get_classes():
    """Get all available disease classes"""
    return {"classes": class_name}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)