from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

app = FastAPI()

# Load the model only if CUDA is available; otherwise leave it unset so the
# /chat endpoint can report that no GPU backend is present.
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model = None
    tokenizer = None

MAX_INPUT_TOKEN_LENGTH = 4096


def generate_response(message: str, history: list) -> str:
    # Build the conversation and tokenize it with the model's chat template.
    conversation = history + [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")

    # Keep only the most recent tokens if the prompt exceeds the context budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    # Run generation in a background thread and collect the streamed tokens here.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": 1024,
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 50,
        "temperature": 0.6,
        "num_beams": 1,
        "repetition_penalty": 1.2,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response_text = ""
    for text in streamer:
        response_text += text
    t.join()
    return response_text


@app.post("/chat")
async def chat_endpoint(request: Request):
    if model is None or tokenizer is None:
        # No CUDA device was found at startup, so the model was never loaded.
        return JSONResponse(
            {"error": "Model unavailable: CUDA is required."}, status_code=503
        )
    data = await request.json()
    message = data.get("message", "")
    # An empty history is used here to keep the example simple.
    response_text = generate_response(message, history=[])
    return JSONResponse({"response": response_text})


@app.get("/", response_class=HTMLResponse)
async def root():
    # Serve the static front-end page shipped alongside this script.
    with open("index.html", "r", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
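
# A minimal sketch of how a client might call the /chat endpoint once the
# server is running. It assumes the default host/port from __main__ above and
# the third-party `requests` package; it is not part of the server itself:
#
#   import requests
#   r = requests.post("http://localhost:8000/chat", json={"message": "Hello"})
#   print(r.json()["response"])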