Maximofn committed on
Commit f222842 · verified · 1 Parent(s): 74c7c43

Create app.py

Files changed (1)
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ from langchain_core.messages import HumanMessage, AIMessage
+ from langgraph.checkpoint.memory import MemorySaver
+ from langgraph.graph import START, MessagesState, StateGraph
+
+ import os
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ # Initialize the model and tokenizer
+ print("Loading model and tokenizer...")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+
+ try:
+     # Load the model in BF16 format for better performance and lower memory usage
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     if device == "cuda":
+         print("Using GPU for the model...")
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             low_cpu_mem_usage=True
+         )
+     else:
+         print("Using CPU for the model...")
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             device_map={"": device},
+             torch_dtype=torch.float32
+         )
+
+     print(f"Model loaded successfully on: {device}")
+ except Exception as e:
+     print(f"Error loading the model: {str(e)}")
+     raise
+
+ # Define the function that calls the model
+ def call_model(state: MessagesState):
+     """
+     Call the model with the given messages.
+
+     Args:
+         state: MessagesState with the conversation so far
+
+     Returns:
+         dict: The updated list of messages, including the model's reply
+     """
+     # Convert LangChain messages to chat format
+     messages = [
+         {"role": "system", "content": "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."}
+     ]
+
+     for msg in state["messages"]:
+         if isinstance(msg, HumanMessage):
+             messages.append({"role": "user", "content": msg.content})
+         elif isinstance(msg, AIMessage):
+             messages.append({"role": "assistant", "content": msg.content})
+
+     # Prepare the input using the chat template
+     input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+     # Generate response
+     outputs = model.generate(
+         inputs,
+         max_new_tokens=512,  # Increase the number of tokens for longer responses
+         temperature=0.7,
+         top_p=0.9,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id
+     )
+
+     # Decode only the newly generated tokens (everything after the prompt)
+     response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
+     # Remove any leading/trailing whitespace from the assistant's reply
+     response = response.strip()
+
+     # Convert the response to LangChain format
+     ai_message = AIMessage(content=response)
+     return {"messages": state["messages"] + [ai_message]}
+
+ # Define the graph
+ workflow = StateGraph(state_schema=MessagesState)
+
+ # Define the node and the entry edge in the graph
+ workflow.add_node("model", call_model)
+ workflow.add_edge(START, "model")
+
+ # Add memory
+ memory = MemorySaver()
+ graph_app = workflow.compile(checkpointer=memory)
+
+ # Define the data model for the request
+ class QueryRequest(BaseModel):
+     query: str
+     thread_id: str = "default"
+
+ # Create the FastAPI application
+ app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
+
+ # Welcome endpoint
+ @app.get("/")
+ async def api_home():
+     """Welcome endpoint"""
+     return {"detail": "Welcome to the FastAPI, LangChain, Docker tutorial"}
+
+ # Generate endpoint
+ @app.post("/generate")
+ async def generate(request: QueryRequest):
+     """
+     Endpoint to generate text using the language model.
+
+     Args:
+         request: QueryRequest
+             query: str
+             thread_id: str = "default"
+
+     Returns:
+         dict: A dictionary containing the generated text and the thread ID
+     """
+     try:
+         # Configure the thread ID
+         config = {"configurable": {"thread_id": request.thread_id}}
+
+         # Create the input message
+         input_messages = [HumanMessage(content=request.query)]
+
+         # Invoke the graph
+         output = graph_app.invoke({"messages": input_messages}, config)
+
+         # Get the model response
+         response = output["messages"][-1].content
+
+         return {
+             "generated_text": response,
+             "thread_id": request.thread_id
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
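Usage note (not part of the committed file): a minimal client sketch, assuming the app is running locally on port 7860 as configured in the __main__ block above. The thread_id field selects the conversation checkpoint kept by MemorySaver, so repeated calls with the same value continue the same chat.

import requests

# Send a prompt to the /generate endpoint; reusing thread_id continues the conversation.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"query": "Hello, who are you?", "thread_id": "demo"},
    timeout=120,
)
print(resp.json())  # e.g. {"generated_text": "...", "thread_id": "demo"}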