# mjarrett
# updated for 8B model
# 29969bf
import logging
import os
import shutil
import tarfile
import threading

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Substrings that mark an environment variable as sensitive. They are matched
# case-insensitively against the variable NAME.
SENSITIVE_ENV_MARKERS = (
    "TOKEN",
    "KEY",
    "SECRET",
    "PASSWORD",
    "PASSWD",
    "CREDENTIAL",
    "AUTH",
)


def redact_env(environ=None):
    """Return a copy of *environ* with sensitive values replaced by ``"****"``.

    A variable is treated as sensitive when its name contains any of
    ``SENSITIVE_ENV_MARKERS`` (case-insensitive) or is exactly ``"granite"``.
    Over-redaction is harmless here. Leaking a credential into the logs is not.

    Args:
        environ: Mapping of variable names to values. Defaults to ``os.environ``.

    Returns:
        A new ``dict`` with the same keys, in the same order.
    """
    if environ is None:
        environ = os.environ
    redacted = {}
    for name, value in environ.items():
        upper_name = name.upper()
        sensitive = name == "granite" or any(
            marker in upper_name for marker in SENSITIVE_ENV_MARKERS
        )
        redacted[name] = "****" if sensitive else value
    return redacted


# Debug environment variables. Values of anything that looks like a credential
# are masked so API keys, passwords and tokens never reach the logs.
logger.info("Environment variables: %s", redact_env())
app = FastAPI()

# Directory holding the fine-tuned Granite 8B checkpoint, and the gzipped
# tarball it is unpacked from when that directory is missing.
model_path = "/app/granite-8b-finetuned-ascii"
model_tarball = f"{model_path}.tar.gz"
# Extract tarball if model directory doesn't exist
if not os.path.exists(model_path):
logger.info(f"Extracting model tarball: {model_tarball}")
try:
with tarfile.open(model_tarball, "r:gz") as tar:
tar.extractall(path="/app")
logger.info("Model tarball extracted successfully")
except Exception as e:
logger.error(f"Failed to extract model tarball: {str(e)}")
raise HTTPException(status_code=500, detail=f"Model tarball extraction failed: {str(e)}")
# Load the tokenizer and model once, at import time. Every request handler
# below reuses these module-level objects.
try:
    logger.info("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        # Let transformers/accelerate decide device placement of the weights.
        device_map="auto",
        # NOTE(review): this executes any custom modelling code shipped inside
        # model_path. That is presumably acceptable because the checkpoint is
        # bundled under /app; confirm its provenance.
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.exception("Failed to load model or tokenizer: %s", e)
    # HTTPException only has meaning inside a request handler. At import time,
    # raise a plain error so the server fails to start with a clear traceback.
    raise RuntimeError(f"Model initialization failed: {e}") from e
class EditRequest(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    # The sentence to edit. It is interpolated verbatim into the model prompt.
    text: str
@app.get("/")
def greet_json():
    """Report readiness and the checkpoint directory being served.

    This route only becomes reachable once the module-level model load has
    completed, so a successful response implies the model is loaded.
    """
    readiness = dict(status="Model is ready", model=model_path)
    return readiness
# The tokenizer and model are single shared instances. Generation now runs in
# FastAPI's worker threads, so serialize access: requests are processed one at
# a time instead of calling the shared tokenizer/model concurrently.
_generate_lock = threading.Lock()


@app.post("/generate")
def generate(request: EditRequest):
    """Ask the fine-tuned model to edit an AsciiDoc sentence.

    This is declared as a plain ``def`` (not ``async def``) on purpose. FastAPI
    runs synchronous path operations in a worker thread pool, so the long,
    blocking ``model.generate`` call does not stall the event loop, and other
    endpoints such as ``/`` keep responding during generation.

    Args:
        request: Body whose ``text`` is the sentence to edit.

    Returns:
        ``{"response": <decoded text>}``. The whole generated sequence is
        decoded, so the text typically begins with the prompt itself.

    Raises:
        HTTPException: status 500 if tokenization, generation or decoding fails.
    """
    prompt = f"Edit this AsciiDoc sentence: {request.text}"
    try:
        with _generate_lock:
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            # max_new_tokens bounds only the generated continuation.
            # max_length would also count the prompt tokens, leaving long
            # inputs with little or no room to generate.
            outputs = model.generate(**inputs, max_new_tokens=200)
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        logger.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
    logger.info("Generated response for prompt: %s", prompt)
    return {"response": response}