tripo-custom / app.py
ashh757's picture
Update app.py
56c4234 verified
from fastapi import FastAPI, File, UploadFile, Form, Query
from fastapi.responses import FileResponse
import os
from main import load_model, generate_mesh
# Application instance; the route decorators in this module register on it.
app: FastAPI = FastAPI()

# Load the model once at import time; the request handlers reuse this single
# instance rather than loading it on every call.
model = load_model()
@app.get("/")
def home():
    """Return a fixed greeting payload for requests to the root URL."""
    payload = dict(message="Hello World")
    return payload
# Directory where uploaded images are staged and generated assets are written.
_OUTPUT_DIR = "tmp/output/"


def _safe_upload_name(filename, fallback="input_image.png"):
    """Return a bare filename that is safe to join onto ``_OUTPUT_DIR``.

    ``UploadFile.filename`` is supplied by the client, so any directory
    components are stripped to prevent path traversal (e.g. ``../../app.py``)
    out of the staging directory. Empty, ``.`` and ``..`` names are replaced
    by *fallback*.
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        return fallback
    return name


def _bundle_outputs(mesh_path, video_path):
    """Zip the mesh and the video into one archive and return its path.

    An HTTP response has a single body, so both generated files are packaged
    into one downloadable ``.zip``. The archive is written next to the mesh
    file, reusing its stem, so it is exactly as unique as the mesh path.
    """
    import zipfile

    archive_path = os.path.splitext(mesh_path)[0] + ".zip"
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        archive.write(mesh_path, arcname="output_mesh.obj")
        archive.write(video_path, arcname="output_video.mp4")
    return archive_path


# Handle POST requests to `/generate`: upload an image, receive the generated mesh.
@app.post("/generate")
async def generate(
    image: UploadFile = File(...),
    no_remove_bg: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    render: bool = Form(False),
    mc_resolution: int = Form(256),
    bake_texture_flag: bool = Form(False),
    texture_resolution: int = Form(2048),
    bucket_name: str = Form('BUCKET_NAME'),
    input_folder: str = Form('INPUT_IMAGES'),
    output_folder: str = Form('OUTPUT_MESHES'),
    input_s3_id: str = Form('input_image.png'),
    output_s3_id: str = Form('output_mesh.obj'),
    output_video_s3_id: str = Form('output_video.mp4'),
):
    """Generate a 3D mesh from an uploaded image.

    The uploaded image is staged under ``_OUTPUT_DIR`` and passed, together
    with every form field unchanged, to ``main.generate_mesh`` along with the
    module-level ``model``. The bucket/folder/``*_s3_id`` fields are only
    forwarded here; presumably ``generate_mesh`` uses them for S3 storage —
    confirm in ``main``.

    Returns:
        A ``FileResponse`` named ``output_mesh.obj`` when ``generate_mesh``
        returns no video path; otherwise a ``FileResponse`` named
        ``output.zip`` containing ``output_mesh.obj`` and ``output_video.mp4``.
    """
    os.makedirs(_OUTPUT_DIR, exist_ok=True)

    # Read the whole upload before opening the destination file, so there is
    # no await point while the file is open (another request with the same
    # filename could otherwise interleave its writes with ours).
    contents = await image.read()
    temp_image_path = os.path.join(_OUTPUT_DIR, _safe_upload_name(image.filename))
    with open(temp_image_path, "wb") as f:
        f.write(contents)

    # NOTE(review): generate_mesh is synchronous and is called directly inside
    # this async handler, so it blocks the event loop (and every other request)
    # until it returns. Consider fastapi.concurrency.run_in_threadpool once the
    # shared `model` is confirmed safe for concurrent use.
    output_file_path, output_video_path = generate_mesh(
        image_path=temp_image_path,
        output_dir=_OUTPUT_DIR,
        no_remove_bg=no_remove_bg,
        foreground_ratio=foreground_ratio,
        render=render,
        mc_resolution=mc_resolution,
        bake_texture_flag=bake_texture_flag,
        texture_resolution=texture_resolution,
        model=model,
        bucket_name=bucket_name,
        input_folder=input_folder,
        output_folder=output_folder,
        input_s3_id=input_s3_id,
        output_s3_id=output_s3_id,
        output_video_s3_id=output_video_s3_id,
    )

    if output_video_path is None:
        # Only a mesh was produced: send it back as a file download.
        return FileResponse(output_file_path, media_type='application/octet-stream', filename="output_mesh.obj")

    # A response can carry only one file; returning a tuple of FileResponses
    # is not sent as files by FastAPI, so bundle mesh + video into one zip.
    archive_path = _bundle_outputs(output_file_path, output_video_path)
    return FileResponse(archive_path, media_type='application/zip', filename="output.zip")