# Author: ak0601 — commit d45bd2e "Update app.py" (verified).
import io
import pickle

import pandas as pd
from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
# FastAPI application; the route handlers below register on it.
app = FastAPI()
# DataFrame from the most recent successful /upload, or None before any upload.
# Shared module state: rebound by /upload, read by /train.
data = None
def train_aut(data, model_path="model.pkl"):
    """Fit a logistic-regression downtime classifier and pickle it to disk.

    ``data`` must be a pandas DataFrame with ``Temperature`` and ``Run_Time``
    feature columns and a ``Downtime_Flag`` label column holding the strings
    ``'Yes'`` / ``'No'``. The input DataFrame is not modified.

    Args:
        data: Training dataset.
        model_path: File the fitted model is pickled to. Defaults to
            ``"model.pkl"``.

    Returns:
        ``(accuracy, f1)`` measured on a held-out 20% split
        (``random_state=42``); ``f1`` is the F1 score of the ``'Yes'`` (1) class.

    Raises:
        ValueError: If a required column is missing or ``Downtime_Flag``
            holds a value other than ``'Yes'`` / ``'No'``.
    """
    required = ["Temperature", "Run_Time", "Downtime_Flag"]
    missing = [col for col in required if col not in data.columns]
    if missing:
        raise ValueError(f"Dataset is missing required column(s): {', '.join(missing)}")

    # Encode into a new Series rather than overwriting the column: the caller
    # passes the shared uploaded DataFrame, and an in-place rewrite would turn
    # every label into NaN on the next training call (1/0 are not 'Yes'/'No').
    y = data["Downtime_Flag"].map({"Yes": 1, "No": 0})
    unknown = y.isna()
    if unknown.any():
        bad = data.loc[unknown, "Downtime_Flag"].unique().tolist()
        raise ValueError(f"Downtime_Flag must be 'Yes' or 'No'; found {bad!r}")
    y = y.astype(int)
    X = data[["Temperature", "Run_Time"]]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = LogisticRegression()
    model.fit(X_train, y_train)
    with open(model_path, "wb") as file:
        pickle.dump(model, file)

    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Unlike indexing classification_report(...)['1'], this cannot KeyError
    # when the test split and predictions contain no positive rows;
    # zero_division=0 yields 0.0 in that case.
    f1 = f1_score(y_test, y_pred, pos_label=1, zero_division=0)
    return float(accuracy), float(f1)
def predict_aut(temp, run_time, model_path="model.pkl"):
    """Classify one (temperature, run time) reading with the pickled model.

    Args:
        temp: Temperature reading.
        run_time: Run-time reading.
        model_path: Pickle to load the model from. Defaults to
            ``"model.pkl"``.

    Returns:
        ``'Yes'`` if the model predicts class 1, otherwise ``'No'``.

    Raises:
        HTTPException: 400 if the model file does not exist yet.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file. This
    # is acceptable only because model_path is written by this service; never
    # point it at a user-supplied file.
    try:
        with open(model_path, "rb") as file:
            model = pickle.load(file)
    except FileNotFoundError:
        raise HTTPException(status_code=400, detail="Model not trained. Please upload data and train the model first.") from None

    # Use a DataFrame with the training column names: the model is fitted on a
    # named DataFrame, so a bare nested list triggers a feature-name warning.
    input_data = pd.DataFrame([[temp, run_time]], columns=["Temperature", "Run_Time"])
    y_pred = model.predict(input_data)
    return "Yes" if y_pred[0] == 1 else "No"
# Request-body schema for POST /predict.
class PredictionInput(BaseModel):
    """One reading to classify.

    Field names mirror the feature columns ``train_aut`` fits on; the
    ``/predict`` handler passes them to ``predict_aut`` in this order.
    """
    Temperature: float  # passed to predict_aut as `temp`
    Run_Time: float  # passed to predict_aut as `run_time`
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Parse an uploaded CSV into the module-level ``data`` DataFrame.

    The bytes are decoded as UTF-8. A leading byte-order mark, which some
    tools prepend to UTF-8 CSVs, is stripped so it is not glued onto the first
    column name (which would break the later column lookups in training).
    ``data`` is only replaced when parsing succeeds.

    Raises:
        HTTPException: 400 if the file cannot be read, decoded, or parsed.
    """
    global data
    try:
        contents = await file.read()
        # "utf-8-sig" decodes plain UTF-8 identically and also drops a BOM.
        data = pd.read_csv(io.StringIO(contents.decode("utf-8-sig")))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}")
    return {"message": "File uploaded successfully."}
@app.post("/train")
def train():
    """Train the downtime model on the most recently uploaded dataset.

    Raises:
        HTTPException: 400 if nothing has been uploaded, or if the dataset
            cannot be trained on (``KeyError`` / ``ValueError`` raised during
            training, e.g. missing columns or bad labels); 500 for any other
            failure.
    """
    # `data` is only read here, so no `global` declaration is needed.
    if data is None:
        raise HTTPException(status_code=400, detail="No data uploaded. Please upload a dataset first.")
    try:
        accuracy, f1 = train_aut(data)
    except (KeyError, ValueError) as e:
        # Problems with the uploaded data are client errors, not server faults.
        raise HTTPException(status_code=400, detail=f"Error during training: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}")
    # NOTE: the metrics response is deliberately disabled by the owner; the
    # enabled reply was:
    # {"message": "Model trained successfully.", "accuracy": accuracy, "f1_score": f1}
    return {"message": "Please Contact the owner to switch this space on."}
@app.post("/predict")
def predict(input_data: PredictionInput):
    """Classify one reading with the trained model.

    Raises:
        HTTPException: 400 (raised by ``predict_aut``) if no model has been
            trained yet; 500 for any other failure.
    """
    try:
        result = predict_aut(input_data.Temperature, input_data.Run_Time)
    except HTTPException:
        # Already carries the intended status (e.g. 400 "Model not trained");
        # the generic handler below would otherwise re-wrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
    # NOTE: the prediction response is deliberately disabled by the owner; the
    # enabled reply was: {"Downtime": result}
    return {"message": "Please Contact the owner to switch this space on."}
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    from uvicorn import run

    # Listen on all interfaces, port 8000.
    run(app, host="0.0.0.0", port=8000)