Spaces:
Running
Running
import logging
from pathlib import Path

import joblib
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Allow browser clients from any origin to call this API.
#
# ``allow_credentials`` is False on purpose: the CORS spec forbids a wildcard
# ``Access-Control-Allow-Origin`` on credentialed requests, and FastAPI's docs
# state that ``["*"]`` must not be combined with ``allow_credentials=True``
# (Starlette would instead reflect the caller's Origin with credentials
# allowed, i.e. trust every site).
# NOTE(review): nothing visible in this service reads cookies or auth headers.
# If a browser client really needs ``credentials: "include"``, replace the
# wildcard with an explicit list of trusted origins and re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Set up logging first, so records emitted while loading the model (including
# the confirmation below) use this configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Resolve the model file next to this module instead of against the process's
# current working directory, so the app starts no matter where it is launched.
# joblib.load unpickles the file: only ever point it at a trusted artifact.
MODEL_PATH = Path(__file__).resolve().with_name("ModelV2.joblib")
model = joblib.load(MODEL_PATH)
logger.info("Loaded model from %s", MODEL_PATH)
async def predict(data: dict):
    """Predict nitrogen, phosphorus and potassium needs for one request.

    Args:
        data: Request payload. Keys may be the API names in ``column_mapping``
            (e.g. ``"crop_name"``) or the model's column names themselves
            (e.g. ``"Crop Name"``). Keys that do not name a model feature
            are ignored.

    Returns:
        A dict with ``nitrogen_need``, ``phosphorus_need`` and
        ``potassium_need`` as floats, taken from columns 0, 1 and 2 of the
        first row returned by ``model.predict`` (assumes one row of three
        outputs per sample, in N/P/K order - confirm against the model).

    Raises:
        HTTPException: 400 when a required column is missing or any
            ``ValueError`` is raised while building the frame or predicting;
            500 for any other exception.
    """
    # NOTE(review): no route decorator is visible on this function. Unless it
    # is registered on ``app`` elsewhere (e.g. ``@app.post("/predict")``),
    # FastAPI will not serve it - confirm.
    # NOTE(review): ``model.predict`` runs synchronously inside ``async def``,
    # blocking the event loop while it runs; if predictions get slow, a plain
    # ``def`` endpoint lets FastAPI run it in its threadpool.

    # Map input keys to expected column names
    column_mapping = {
        "crop_name": "Crop Name",
        "target_yield": "Target Yield",
        "field_size": "Field Size",
        "ph": "pH (water)",
        "organic_carbon": "Organic Carbon",
        "nitrogen": "Total Nitrogen",
        "phosphorus": "Phosphorus (M3)",
        "potassium": "Potassium (exch.)",
        "soil_moisture": "Soil moisture",
    }
    required_columns = list(column_mapping.values())

    try:
        # Accept API keys, and pass through keys already given as column names.
        mapped_data = {column_mapping.get(k, k): v for k, v in data.items()}

        # Fix the column order instead of inheriting the client's JSON key
        # order: scikit-learn estimators fitted on a DataFrame (presumably
        # what ``model`` is) reject feature names that are unseen or out of
        # fit-time order. Use the fitted order when the model exposes it.
        feature_order = list(getattr(model, "feature_names_in_", required_columns))

        # Check if all required columns are present (sorted for a stable message)
        missing_columns = sorted(
            (set(required_columns) | set(feature_order)) - set(mapped_data)
        )
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # ``columns=`` both orders the single-row frame and drops unrelated keys.
        df = pd.DataFrame([mapped_data], columns=feature_order)

        # Make prediction
        prediction = model.predict(df)
        return {
            "nitrogen_need": float(prediction[0][0]),
            "phosphorus_need": float(prediction[0][1]),
            "potassium_need": float(prediction[0][2]),
        }
    except ValueError as ve:
        logger.error("ValueError in predict: %s", ve)
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Keep the full traceback in the server log; the client gets a
        # generic message only.
        logger.exception("Unexpected error in predict: %s", e)
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred"
        ) from e
async def root():
    """Return a fixed JSON banner naming this API."""
    # NOTE(review): no route decorator is visible here; if this is meant to be
    # served at "/", it needs ``@app.get("/")`` or equivalent - confirm.
    banner = "NPK Needs Prediction API"
    return dict(message=banner)