Spaces: Runtime error
from fastapi import FastAPI, HTTPException
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
from pydantic import BaseModel
import os
import tarfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Debug environment variables, masking anything that looks like a secret
logger.info(
    "Environment variables: %s",
    {k: "****" if "TOKEN" in k or k == "granite" else v for k, v in os.environ.items()},
)
app = FastAPI()

model_tarball = "/app/granite-8b-finetuned-ascii.tar.gz"
model_path = "/app/granite-8b-finetuned-ascii"

# Extract the tarball if the model directory doesn't exist yet
if not os.path.exists(model_path):
    logger.info(f"Extracting model tarball: {model_tarball}")
    try:
        with tarfile.open(model_tarball, "r:gz") as tar:
            tar.extractall(path="/app")
        logger.info("Model tarball extracted successfully")
    except Exception as e:
        logger.error(f"Failed to extract model tarball: {str(e)}")
        # HTTPException only makes sense inside a request handler; raised here
        # at import time it is just an unhandled exception, so fail loudly
        # with a RuntimeError instead.
        raise RuntimeError(f"Model tarball extraction failed: {str(e)}") from e
try:
    logger.info("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.padding_side = "right"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {str(e)}")
    # Same reasoning as above: this runs at import time, not in a request.
    raise RuntimeError(f"Model initialization failed: {str(e)}") from e
class EditRequest(BaseModel):
    text: str


# Without a route decorator this function is never reachable; "/" matches the
# stock FastAPI Space template this handler appears to come from.
@app.get("/")
def greet_json():
    return {"status": "Model is ready", "model": model_path}
# The original file also left this handler unregistered; the "/generate" path
# is an assumption.
@app.post("/generate")
async def generate(request: EditRequest):
    try:
        prompt = f"Edit this AsciiDoc sentence: {request.text}"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # max_new_tokens caps the completion itself; max_length=200 would also
        # count the prompt tokens and can cut generation short.
        outputs = model.generate(**inputs, max_new_tokens=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Generated response for prompt: {prompt}")
        return {"response": response}
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
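For quick verification once the Space is up, the minimal sketch below exercises both routes. It assumes the app is reachable locally on port 7860 (the port Hugging Face Spaces expect by convention) and that the routes are "/" and "/generate" as registered above; both paths are assumptions, since the original file never attached the handlers to the app.

import requests

# Hypothetical base URL: a Docker Space conventionally serves on port 7860.
BASE_URL = "http://localhost:7860"

# The root route reports readiness without touching the model.
print(requests.get(f"{BASE_URL}/", timeout=10).json())

# POST a sentence to the assumed /generate route.
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"text": "The the server restarts automatical."},
    timeout=120,  # generation with an 8B model can be slow, especially on CPU
)
resp.raise_for_status()
print(resp.json()["response"])

In a Docker Space the app is typically launched with uvicorn app:app --host 0.0.0.0 --port 7860 from the Dockerfile's CMD; if either the tarball extraction or the model load raises during import, the process exits before uvicorn can bind the port, which would surface as exactly the Runtime error banner shown above.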