Spaces:
Sleeping
Sleeping
File size: 3,998 Bytes
db33294 b5017bd db33294 b5017bd db33294 2d81861 db33294 b5017bd db33294 b5017bd db33294 f280e6e db33294 b5017bd db33294 b5017bd db33294 f280e6e db33294 472e9bb fdf5004 db33294 b5017bd c22f142 db33294 f2d04f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import io
import os
import tempfile

import cv2
import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Application object that the route decorators in this module attach to.
app = FastAPI(version="1.0.0", title="Plant Disease Detection API")

# CORS policy: any origin, method and header is accepted, with credentials.
# In production, replace the wildcard origin with your frontend URL.
# NOTE(review): FastAPI's CORS documentation advises listing origins
# explicitly when allow_credentials is True -- confirm whether the frontend
# actually needs credentialed requests before tightening this.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load the trained Keras model once, at import time, so requests can reuse it.
# A failure here aborts startup instead of surfacing later on the first request.
try:
    model = tf.keras.models.load_model("trained_modela.keras")
except Exception as e:
    # Chain explicitly so the underlying loader error stays in the traceback
    # as the direct cause of the RuntimeError.
    raise RuntimeError(f"Failed to load model: {e}") from e
# Labels the API can return, in "<crop>___<condition>" form.
# The /predict handler indexes this list with the argmax of the model output,
# so the order must stay aligned with the model's output indices.
class_name = [
    # Apple
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    # Blueberry
    "Blueberry___healthy",
    # Cherry
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    # Corn
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    # Grape
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    # Orange
    "Orange___Haunglongbing_(Citrus_greening)",
    # Peach
    "Peach___Bacterial_spot",
    "Peach___healthy",
    # Bell pepper
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    # Potato
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    # Raspberry / Soybean / Squash
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    # Strawberry
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    # Tomato
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]
@app.get("/")
async def root():
    """Return the service name and version as a small JSON banner."""
    return {"message": "Plant Disease Detection API", "version": "1.0.0"}
@app.post("/predict")
async def predict_disease(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        A dict with ``success`` (always ``True``), ``disease`` (the entry of
        ``class_name`` at the highest-scoring index) and ``confidence`` (the
        model output at that index, as a float).

    Raises:
        HTTPException: 400 when the upload is not declared as an image or is
            empty; 500 when decoding or inference fails.
    """
    # content_type may be None when the client omits it for the uploaded part.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        # load_img is given a filesystem path, so spool the upload to disk.
        with tempfile.NamedTemporaryFile(suffix=".jpeg", delete=False) as tmp:
            temp_path = tmp.name
            tmp.write(contents)

        try:
            image = tf.keras.preprocessing.image.load_img(
                temp_path, target_size=(128, 128)
            )
            input_arr = tf.keras.preprocessing.image.img_to_array(image)
        finally:
            # delete=False leaves the file behind, so remove it explicitly;
            # otherwise every request leaks one temp file.
            os.remove(temp_path)

        # The model is called with a batch, so wrap the single image in one.
        batch = np.expand_dims(input_arr, axis=0)
        prediction = model.predict(batch)

        # Plain int so the list lookup does not depend on numpy scalar types.
        result_index = int(np.argmax(prediction))
        confidence = float(prediction[0][result_index])
        disease_name = class_name[result_index]

        return {
            "success": True,
            "disease": disease_name,
            "confidence": confidence,
        }
    except HTTPException:
        # Deliberate client errors pass through instead of being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}") from e
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally report a healthy status."""
    status_report = dict(status="healthy")
    return status_report
@app.get("/classes")
async def get_classes():
    """Return every label the API can report, under the "classes" key."""
    available_labels = class_name
    return {"classes": available_labels}
# Script entry point: serve the app directly when this file is executed.
if __name__ == "__main__":
    # Bind to all interfaces on port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )
|