import logging

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# CORS is left fully open: every origin, HTTP method and request header is
# accepted, and credentialed requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials lets any
# website call this API from a visitor's browser; narrow allow_origins before
# exposing this service publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Load the serialized model and the crop-name label encoder once, at import
# time. Both paths are relative, so they resolve against the process's current
# working directory rather than this module's directory.
# NOTE(review): joblib.load unpickles the file — only ever load trusted artifacts.
_MODEL_FILE = "soil_npk_joblib_model.joblib"
_ENCODER_FILE = "label_encoder.joblib"
model = joblib.load(_MODEL_FILE)
le = joblib.load(_ENCODER_FILE)

class InputData(BaseModel):
    """Request body for ``POST /predict``: a crop plus one field's soil readings.

    The declaration order of these fields matches the column order of the
    single-row DataFrame that the predict endpoint builds for the model.
    NOTE(review): no units or value ranges are enforced here — confirm them
    against the data the model was trained on.
    """

    crop_name: str  # checked against the label encoder's known classes by predict
    target_yield: float
    field_size: float
    ph: float
    organic_carbon: float
    nitrogen: float
    phosphorus: float
    potassium: float
    soil_moisture: float

logger = logging.getLogger(__name__)


@app.post("/predict")
async def predict(data: InputData):
    """Predict nutrient and amendment needs for one field from its soil readings.

    The crop name is encoded with ``le``. The resulting single-row frame is
    passed to ``model.predict``, and the first five outputs of the first
    prediction row are returned under fixed keys.

    Args:
        data: Validated request body.

    Returns:
        dict with float values for ``nitrogen_need``, ``phosphorus_need``,
        ``potassium_need``, ``organic_matter_need`` and ``lime_need``.

    Raises:
        HTTPException: 400 if ``crop_name`` is not one of the encoder's known
            classes; 500 (with the error text as detail) for any failure while
            encoding or predicting.
    """
    # An unknown crop is a client mistake, so answer 400 instead of letting it
    # fall into the generic 500 handler below.
    if data.crop_name not in le.classes_:
        raise HTTPException(
            status_code=400, detail=f"Invalid crop_name: {data.crop_name}"
        )

    # NOTE(review): model.predict runs synchronously inside an ``async def``,
    # so it blocks the event loop while it runs; declaring this endpoint with a
    # plain ``def`` would let FastAPI execute it in its threadpool instead.
    try:
        # Single-row frame; column order mirrors the InputData field order.
        input_data = pd.DataFrame(
            {
                "crop_name": [data.crop_name],
                "target_yield": [data.target_yield],
                "field_size": [data.field_size],
                "ph": [data.ph],
                "organic_carbon": [data.organic_carbon],
                "nitrogen": [data.nitrogen],
                "phosphorus": [data.phosphorus],
                "potassium": [data.potassium],
                "soil_moisture": [data.soil_moisture],
            }
        )

        # Encode the crop name with the same encoder used for validation above.
        input_data["crop_name"] = le.transform(input_data["crop_name"])

        # Guard against a model/feature-count mismatch before predicting.
        expected_features = model.n_features_in_
        if input_data.shape[1] != expected_features:
            raise ValueError(
                f"Input shape mismatch. Expected {expected_features} features, "
                f"got {input_data.shape[1]}"
            )

        prediction = model.predict(input_data)
        row = prediction[0]  # outputs for the single input row
        return {
            "nitrogen_need": float(row[0]),
            "phosphorus_need": float(row[1]),
            "potassium_need": float(row[2]),
            "organic_matter_need": float(row[3]),
            "lime_need": float(row[4]),
        }
    except Exception as e:
        # Previously this called logging.error without importing logging, so the
        # handler itself crashed with NameError. logger.exception also records
        # the traceback.
        logger.exception("Error in predict function: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/")
async def root():
    """Return a fixed message identifying this service."""
    greeting = dict(message="NPK Needs Prediction API")
    return greeting