Makhinur committed
Commit 90fd37f · verified · 1 Parent(s): a071191

Update main.py

Files changed (1):
  1. main.py +350 -81
main.py CHANGED
@@ -1,101 +1,370 @@
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
- from fastapi.responses import HTMLResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.templating import Jinja2Templates
- from fastapi.responses import FileResponse
- import requests
- import base64
- from typing import Iterator
  import os
- from text_generation import Client
  from deep_translator import GoogleTranslator

  app = FastAPI()

- model_id = 'codellama/CodeLlama-34b-Instruct-hf'
-
- API_URL = "https://api-inference.huggingface.co/models/" + model_id
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
-
- client = Client(
-     API_URL,
-     headers={"Authorization": f"Bearer {HF_TOKEN}"},
- )
- EOS_STRING = "</s>"
- EOT_STRING = "<EOT>"
-
-
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
-                system_prompt: str) -> str:
-     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-     do_strip = False
-     for user_input, response in chat_history:
-         user_input = user_input.strip() if do_strip else user_input
-         do_strip = True
-         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-     message = message.strip() if do_strip else message
-     texts.append(f'{message} [/INST]')
-     return ''.join(texts)
-
-
- def run(message: str,
-         chat_history: list[tuple[str, str]],
-         system_prompt: str,
-         max_new_tokens: int = 1024,
-         temperature: float = 0.1,
-         top_p: float = 0.9,
-         top_k: int = 50) -> Iterator[str]:
-     prompt = get_prompt(message, chat_history, system_prompt)
-
-     generate_kwargs = dict(
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-     )
-     stream = client.generate_stream(prompt, **generate_kwargs)
-     output = ""
-     for response in stream:
-         if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
-             yield output
-             output = ""
-         else:
-             output += response.token.text

- def generate_image_caption(image_data):
-     image_base64 = base64.b64encode(image_data).decode('utf-8')
-     payload = {"data": ["data:image/jpeg;base64," + image_base64]}
-     response = requests.post("https://makhinur-image-to-text-salesforce-blip-image-cap-c0a9076.hf.space/run/predict", json=payload)
-     if response.status_code == 200:
-         caption = response.json()["data"][0]
          return caption
-     else:
-         return "Error: Unable to generate caption"

- import random

- from fastapi import Query
- from deep_translator import GoogleTranslator
- from deep_translator.exceptions import InvalidSourceOrTargetLanguage

- from fastapi import Query

  @app.post("/generate-story/")
- async def generate_story(image_file: UploadFile = File(...), language: str = Form(...)):
-     image_data = await image_file.read()
-     system_prompt = f"write an attractive story in 300 words about {random.choice(['an adventurous journey', 'a mysterious encounter', 'a heroic quest', 'a magical adventure', 'a thrilling escape', 'an unexpected discovery', 'a dangerous mission', 'a romantic escapade', 'an epic battle', 'a journey into the unknown'])}"

-     caption = generate_image_caption(image_data)
-     if caption.startswith("Error"):
          raise HTTPException(status_code=500, detail=caption)
-     ai_response = next(run(caption, [], system_prompt))

-     if language != "english":
-         translator = GoogleTranslator(source='english', target=language)
-         ai_response = translator.translate(ai_response)

-     return {"story": ai_response}

  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+ # Keep these if you use them elsewhere in your app (HTML, static files)
+ # from fastapi.responses import HTMLResponse
+ # from fastapi.staticfiles import StaticFiles
+ # from fastapi.templating import Jinja2Templates
+ # from fastapi.responses import FileResponse
+
+ # Removed 'requests' as we are using gradio_client
+ # import requests
+ import base64 # Keep if needed elsewhere (not strictly needed for this version)
  import os
+ import random
+ # Removed unused IO import
+ # from typing import IO
+
+ # Import necessary classes from transformers (keeping only AutoTokenizer)
+ from transformers import AutoTokenizer
+
+ # Import necessary modules for llama-cpp-python and for downloading from the Hub
+ from llama_cpp import Llama # The core Llama class
+ from huggingface_hub import hf_hub_download # For downloading GGUF files
+
+
+ # Import the Gradio Client and handle_file
+ from gradio_client import Client, handle_file
+
+ # Import necessary modules for temporary file handling
+ import tempfile
+ # shutil is not strictly necessary for this version; os.remove is sufficient
+ # import shutil
+
+
  from deep_translator import GoogleTranslator
+ from deep_translator.exceptions import InvalidSourceOrTargetLanguage
+

  app = FastAPI()

+ # --- Llama.cpp Language Model Setup (Local CPU Inference) ---
+ # Repository on the Hugging Face Hub containing the Qwen1.5 0.5B GGUF files
+ # Using the official Qwen 0.5B GGUF repository:
+ LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF" # Updated to official 0.5B repo
+
+ # Specify the filename for a Q4_K_M quantized version (good balance of speed/quality on CPU)
+ LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf" # Exact filename in the 0.5B repo
+
+ # Original model name for the tokenizer (needed by transformers)
+ # This points to the base model repository for the tokenizer files.
+ ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat" # Updated to the 0.5B Chat model
+
+ tokenizer = None # transformers tokenizer (loaded at startup)
+ llm_model = None # This will hold the llama_cpp.Llama instance
+
+
+ # --- Hugging Face Gradio Space Client Setup (For External Image Captioning) ---
+ # Global Gradio client for captioning
+ caption_client = None
+ # The URL of the external Gradio Space for image captioning
+ CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
+
+
+ # Function to load the language model (GGUF via llama.cpp) and its tokenizer (from transformers)
+ def load_language_model():
+     global tokenizer, llm_model
+     print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
+     try:
+         # --- Load Tokenizer (using transformers) ---
+         # Load the tokenizer from the original model repo
+         print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
+         tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)
+
+         # Set pad_token if not already defined; often necessary for correct batching/generation behavior
+         # Qwen tokenizers should already have a pad_token, but this check is robust
+         if tokenizer.pad_token is None:
+             if tokenizer.eos_token is not None:
+                 tokenizer.pad_token = tokenizer.eos_token
+             elif tokenizer.unk_token is not None:
+                 tokenizer.pad_token = tokenizer.unk_token
+             else:
+                 # Fallback if neither exists (very rare)
+                 print("Warning: Neither EOS nor UNK token found for tokenizer. Setting pad_token to None.")
+                 tokenizer.pad_token = None
+
+         # --- Download the GGUF model file (using huggingface_hub) ---
+         print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
+         model_path = hf_hub_download(
+             repo_id=LLM_MODEL_REPO,
+             filename=LLM_MODEL_FILE,
+             # cache_dir="/tmp/hf_cache" # Optional: specify a custom cache directory
+         )
+         print(f"GGUF model downloaded to: {model_path}")
+
+         # --- Load the GGUF model (using llama-cpp-python) ---
+         print("Loading GGUF model into llama_cpp...")
+         # n_gpu_layers=0: crucial for forcing CPU-only inference
+         # n_ctx: context window size (tokens the model can consider); keep within the model's supported context
+         # n_threads: number of CPU threads to use; set to the available vCPU count for better performance
+         llm_model = Llama(
+             model_path=model_path,
+             n_gpu_layers=0, # Explicitly use CPU
+             n_ctx=4096, # Context window size (4096 is a common safe value)
+             n_threads=2 # Use 2 CPU threads
+         )
+         print("Llama.cpp model loaded successfully.")
+
+     except Exception as e:
+         print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
+         tokenizer = None
+         llm_model = None # Ensure the model is None if loading fails
+
+
+ # Function to initialize the Gradio client for the captioning Space
+ def initialize_caption_client():
+     global caption_client
+     print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
+     try:
+         # If the target Gradio Space requires authentication (e.g., it is private),
+         # store HF_TOKEN as a Space secret and uncomment these lines.
+         # HF_TOKEN = os.environ.get("HF_TOKEN")
+         # if HF_TOKEN:
+         #     print("Using HF_TOKEN for Gradio client.")
+         #     caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
+         # else:
+         #     print("HF_TOKEN not found. Initializing public Gradio client.")
+         #     caption_client = Client(CAPTION_SPACE_URL)
+
+         # Assuming the caption Space is public
+         caption_client = Client(CAPTION_SPACE_URL)
+         print("Gradio client initialized successfully.")
+     except Exception as e:
+         print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
+         # Set the client to None so the endpoint can check and return an error
+         caption_client = None
+

+ # Load models and initialize clients when the app starts
+ @app.on_event("startup")
+ async def startup_event():
+     # Load the language model (Qwen1.5 0.5B GGUF via llama.cpp)
+     load_language_model()
+     # Initialize the client for the external captioning Space
+     initialize_caption_client()

+
+ # --- Image Captioning Function (Using gradio_client and temporary file) ---
+ def generate_image_caption(image_file: UploadFile):
+     """
+     Generates a caption for the uploaded image using the external Gradio Space API.
+     Reads the uploaded file's content, saves it to a temporary file,
+     and uses the temporary file's path with handle_file for the API call.
+     """
+     if caption_client is None:
+         # If the client failed to initialize at startup
+         error_msg = "Gradio caption client not initialized. Cannot generate caption."
+         print(error_msg)
+         return f"Error: {error_msg}"
+
+     temp_file_path = None # Variable to store the path of the temporary file
+
+     try:
+         print(f"Attempting to generate caption for file: {image_file.filename}")
+
+         # Read the content of the uploaded file
+         # Seek to the beginning just in case the file-like object's pointer was moved
+         image_file.file.seek(0)
+         image_bytes = image_file.file.read()
+
+         # Create a temporary file on the local filesystem
+         # delete=False ensures the file persists after closing the handle
+         # suffix helps hint at the file type for the Gradio API
+         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(image_file.filename)[1] or '.jpg')
+         temp_file.write(image_bytes)
+         temp_file.close() # Close the file handle so gradio_client can access the file
+         temp_file_path = temp_file.name # Get the full path to the temporary file
+
+         print(f"Saved uploaded file temporarily to: {temp_file_path}")
+
+         # Use handle_file() with the path string to the temporary file.
+         # This correctly prepares the file for the Gradio API input.
+         prepared_input = handle_file(temp_file_path)
+
+         # Call the predict method on the initialized gradio_client
+         # api_name="/predict" matches the endpoint specified in the Gradio API docs
+         caption = caption_client.predict(img=prepared_input, api_name="/predict")
+
+         print("Caption generated successfully.")
+         # Return the caption string received from the API
          return caption
+
+     except Exception as e:
+         # Catch any exceptions that occur during reading, writing, or the API call
+         print(f"Error during caption generation API call: {e}") # Log the error details server-side
+         # Return a structured error string including the exception type and message
+         return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"

+     finally:
+         # Clean up the temporary file regardless of whether the process succeeded or failed
+         if temp_file_path and os.path.exists(temp_file_path):
+             print(f"Cleaning up temporary file: {temp_file_path}")
+             try:
+                 os.remove(temp_file_path) # Delete the file using its path
+             except OSError as e:
+                 print(f"Error removing temporary file {temp_file_path}: {e}") # Log cleanup errors

+ # --- Language Model Story Generation Function (Qwen1.5 0.5B via llama.cpp) ---
+ # Renamed function to reflect the model being used
+ def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 50) -> str:
+     """
+     Generates text using the loaded Qwen1.5 0.5B model via llama.cpp.
+     Builds the chat messages and calls llama.cpp's chat completion,
+     which applies the model's chat template internally.
+     """
+     # Check that the language model was loaded successfully at startup
+     # Check for both tokenizer and llm_model (the llama.cpp instance)
+     if tokenizer is None or llm_model is None:
+         # Raise a RuntimeError, which is caught by the calling endpoint
+         raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")
+
+     # Construct the messages list following the chat format for Qwen1.5 Chat
+     # Qwen models use a standard ChatML-like format.
+     messages = [
+         # A system message is optional but can help guide the model's persona/style
+         # {"role": "system", "content": "You are a helpful and creative assistant."}
+         {"role": "user", "content": prompt_text}
+     ]
+
+     try:
+         print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
+         # Call the create_chat_completion method on the llama_cpp.Llama instance
+         # This method handles the chat templating internally for models like Qwen.
+         # max_tokens is the maximum number of tokens to generate;
+         # temperature and top_p control sampling. top_k might not be a direct parameter.
+         response = llm_model.create_chat_completion(
+             messages=messages,
+             max_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             # top_k is sometimes supported, but check the llama-cpp-python docs if needed
+             # top_k=top_k,
+             stream=False # We want the full response at once
+         )
+         print("Llama.cpp completion received for Qwen 0.5B.")
+
+         # Parse the response to get the generated text content
+         # The response structure is typically like OpenAI's chat API response
+         if response and response.get('choices') and len(response['choices']) > 0:
+             story = response['choices'][0].get('message', {}).get('content', '')
+         else:
+             # Handle cases where the response is empty or has an unexpected structure
+             print("Warning: Llama.cpp Qwen 0.5B response structure unexpected or content missing.")
+             story = "" # Return an empty string if content is not found
+
+     except Exception as e:
+         # Catch any exception that occurs during the llama.cpp inference process
+         print(f"Llama.cpp Qwen 0.5B inference failed: {e}") # Log the error server-side
+         # Re-raise as a RuntimeError to indicate failure to the endpoint
+         raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")
+
+     # Return the generated story text, removing leading/trailing whitespace
+     return story.strip()

+
+ # --- FastAPI Endpoint for Story Generation ---
  @app.post("/generate-story/")
+ async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
+     # Choose a random theme for the story prompt
+     story_theme = random.choice([
+         'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
+         'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
+         'a dangerous mission', 'a romantic escapade', 'an epic battle',
+         'a journey into the unknown'
+     ])
+
+     # Step 1: Get the image caption from the external Gradio API via gradio_client
+     # Pass the UploadFile object directly to the captioning function
+     caption = generate_image_caption(image_file)

+     # Check if caption generation returned an error string
+     if caption.startswith("Error:"):
+         print(f"Caption generation failed: {caption}") # Log the error detail server-side
+         # Raise an HTTPException with a 500 status code and the error message
          raise HTTPException(status_code=500, detail=caption)

+     # Step 2: Construct the prompt text for the language model
+     # This prompt instructs the model on what to write and incorporates the caption.
+     prompt_text = f"Write an attractive story of around 300 words about {story_theme}. Incorporate the following details from an image description into the story: {caption}\n\nStory:"
+
+     # Step 3: Generate the story using the local language model (Qwen 0.5B via llama.cpp)
+     try:
+         # Call the Qwen 0.5B story generation function
+         story = generate_story_qwen_0_5b(
+             prompt_text,
+             max_new_tokens=300, # Request ~300 new tokens
+             temperature=0.7, # Sampling parameters
+             top_p=0.9,
+             top_k=50 # Note: top_k may not be directly used by llama_cpp.create_chat_completion
+         )
+         story = story.strip() # Basic cleanup of generated story text
+
+     except RuntimeError as e:
+         # Catch the RuntimeError raised by generate_story_qwen_0_5b if LLM loading or inference fails
+         print(f"Language model generation error: {e}") # Log the error server-side
+         # Return a 503 Service Unavailable error if the LLM is not available or failed
+         raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")
+     except Exception as e:
+         # Catch any other unexpected errors during story generation
+         print(f"An unexpected error occurred during story generation: {e}") # Log server-side
+         raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}")
+
+
+     # Step 4: Translate the generated story if the target language is not English
+     # Check if language is provided and not English (case-insensitive)
+     if language and language.lower() != "english":
+         try:
+             # Initialize GoogleTranslator with English source and the requested target language
+             translator = GoogleTranslator(source='english', target=language.lower())
+             # Perform the translation
+             translated_story = translator.translate(story)
+
+             # Check if translation returned None or an empty string (indicates failure)
+             if translated_story is None or translated_story == "":
+                 print(f"Translation returned None or empty string for language: {language}") # Log failure
+                 # If translation fails, return the original English story with a warning
+                 return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
+
+             # If translation was successful, use the translated text
+             story = translated_story
+
+         except InvalidSourceOrTargetLanguage:
+             print(f"Invalid target language requested: {language}") # Log invalid language
+             raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
+         except Exception as e:
+             # Catch any other errors during translation (e.g., network issues, API problems)
+             print(f"Translation failed for language {language}: {e}") # Log server-side
+             raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")
+
+     # Step 5: Return the final generated (and potentially translated) story as a JSON response
+     return {"story": story}

+ # --- Optional: Serve a simple HTML form for testing ---
+ # To use this, uncomment the imports related to HTMLResponse, StaticFiles, Jinja2Templates, Request
+ # at the top of the file, and create a 'templates' directory with an 'index.html' file.
+ # from fastapi import Request
+ # from fastapi.templating import Jinja2Templates
+ # from fastapi.staticfiles import StaticFiles
+ # templates = Jinja2Templates(directory="templates")
+ # app.mount("/static", StaticFiles(directory="static"), name="static")
+ # @app.get("/", response_class=HTMLResponse)
+ # async def read_root(request: Request):
+ #     # Simple HTML form to upload an image and specify language
+ #     html_content = """
+ #     <!DOCTYPE html>
+ #     <html>
+ #     <head><title>Story Generator</title></head>
+ #     <body>
+ #         <h1>Generate a Story from an Image</h1>
+ #         <form action="/generate-story/" method="post" enctype="multipart/form-data">
+ #             <input type="file" name="image_file" accept="image/*" required><br><br>
+ #             Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
+ #             <button type="submit">Generate Story</button>
+ #         </form>
+ #     </body>
+ #     </html>
+ #     """
+ #     # If using templates: return templates.TemplateResponse("index.html", {"request": request})
+ #     return HTMLResponse(content=html_content) # Using direct HTML for simplicity if templates not set up
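
For quick manual testing outside the optional HTML form above, a small client script can exercise the new /generate-story/ endpoint. This is a minimal sketch and not part of the commit: the base URL and image filename are placeholder assumptions to adapt to your own deployment.

import requests

BASE_URL = "http://localhost:7860"  # placeholder; use your own Space or server URL

# Upload an image and request a story in English from the /generate-story/ endpoint
with open("example.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        f"{BASE_URL}/generate-story/",
        files={"image_file": ("example.jpg", f, "image/jpeg")},
        data={"language": "english"},
    )

resp.raise_for_status()
print(resp.json()["story"])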