Varu96 committed on
Commit 1384720 · 1 Parent(s): 47a684d

Added the Front_End Jupyter Notebook

Files changed (1)
  1. Front_End (1).ipynb +193 -0
Front_End (1).ipynb ADDED
@@ -0,0 +1,193 @@
+ {
+   "nbformat": 4,
+   "nbformat_minor": 0,
+   "metadata": {
+     "colab": {
+       "provenance": []
+     },
+     "kernelspec": {
+       "name": "python3",
+       "display_name": "Python 3"
+     },
+     "language_info": {
+       "name": "python"
+     }
+   },
+   "cells": [
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {
+         "id": "XICISQU4VQ7j"
+       },
+       "outputs": [],
+       "source": []
+     },
+     {
+       "cell_type": "code",
+       "source": [],
+       "metadata": {
+         "id": "VsFOxVleVRpA"
+       },
+       "execution_count": null,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "source": [
+         "import json\n",
+         "import os\n",
+         "from transformers import AutoProcessor, AutoModelForVision2Seq\n",
+         "import torch\n",
+         "from PIL import Image\n",
+         "import gradio as gr\n",
+         "import subprocess\n",
+         "from llava.model.builder import load_pretrained_model\n",
+         "from llava.mm_utils import get_model_name_from_path\n",
+         "from llava.eval.run_llava import eval_model\n",
+         "\n",
+         "# Load the LLaVA model and processor\n",
+         "llava_model_path = \"/workspace/LLaVA/LLaVA/llava-fine_tune_model\"\n",
+         "\n",
+         "# Load the LLaVA-Med model and processor\n",
+         "llava_med_model_path = \"/workspace/LLaVA-Med/Model/fine_tuned-med-llava\"\n",
+         "\n",
+         "# Args class to store arguments for LLaVA models\n",
+         "class Args:\n",
+         "    def __init__(self, model_path, model_base, model_name, query, image_path, conv_mode, image_file, sep, temperature, top_p, num_beams, max_new_tokens):\n",
+         "        self.model_path = model_path\n",
+         "        self.model_base = model_base\n",
+         "        self.model_name = model_name\n",
+         "        self.query = query\n",
+         "        self.image_path = image_path\n",
+         "        self.conv_mode = conv_mode\n",
+         "        self.image_file = image_file\n",
+         "        self.sep = sep\n",
+         "        self.temperature = temperature\n",
+         "        self.top_p = top_p\n",
+         "        self.num_beams = num_beams\n",
+         "        self.max_new_tokens = max_new_tokens\n",
+         "\n",
+         "# Function to predict using Idefics2\n",
+         "def predict_idefics2(image, question, temperature, max_tokens):\n",
+         "    image = image.convert(\"RGB\")\n",
+         "    images = [image]\n",
+         "\n",
+         "    messages = [\n",
+         "        {\n",
+         "            \"role\": \"user\",\n",
+         "            \"content\": [\n",
+         "                {\"type\": \"image\"},\n",
+         "                {\"type\": \"text\", \"text\": question}\n",
+         "            ]\n",
+         "        }\n",
+         "    ]\n",
+         "    input_text = idefics2_processor.apply_chat_template(messages, add_generation_prompt=True).strip()\n",
+         "\n",
+         "    inputs = idefics2_processor(text=[input_text], images=images, return_tensors=\"pt\", padding=True).to(\"cuda:0\")\n",
+         "\n",
+         "    with torch.no_grad():\n",
+         "        outputs = idefics2_model.generate(**inputs, max_new_tokens=max_tokens, do_sample=True, temperature=temperature)\n",
+         "\n",
+         "    predictions = idefics2_processor.decode(outputs[0], skip_special_tokens=True)\n",
+         "\n",
+         "    return predictions\n",
+         "\n",
+         "# Function to predict using LLaVA\n",
+         "def predict_llava(image, question, temperature, max_tokens):\n",
+         "    # Save the image temporarily\n",
+         "    image.save(\"temp_image.jpg\")\n",
+         "\n",
+         "    # Setup evaluation arguments\n",
+         "    args = Args(\n",
+         "        model_path=llava_model_path,\n",
+         "        model_base=None,\n",
+         "        model_name=get_model_name_from_path(llava_model_path),\n",
+         "        query=question,\n",
+         "        image_path=\"temp_image.jpg\",\n",
+         "        conv_mode=None,\n",
+         "        image_file=\"temp_image.jpg\",\n",
+         "        sep=\",\",\n",
+         "        temperature=temperature,\n",
+         "        top_p=None,\n",
+         "        num_beams=1,\n",
+         "        max_new_tokens=max_tokens\n",
+         "    )\n",
+         "\n",
+         "    # Generate the answer using the selected model\n",
+         "    output = eval_model(args)\n",
+         "\n",
+         "    return output\n",
+         "\n",
+         "# Function to predict using LLaVA-Med\n",
+         "def predict_llava_med(image, question, temperature, max_tokens):\n",
+         "    # Save the image temporarily\n",
+         "    image_path = \"temp_image_med.jpg\"\n",
+         "    image.save(image_path)\n",
+         "\n",
+         "    # Command to run the LLaVA-Med model\n",
+         "    command = [\n",
+         "        \"python\", \"-m\", \"llava.eval.run_llava\",\n",
+         "        \"--model-name\", llava_med_model_path,\n",
+         "        \"--image-file\", image_path,\n",
+         "        \"--query\", question,\n",
+         "        \"--temperature\", str(temperature),\n",
+         "        \"--max-new-tokens\", str(max_tokens)\n",
+         "    ]\n",
+         "\n",
+         "    # Execute the command and capture the output\n",
+         "    result = subprocess.run(command, capture_output=True, text=True)\n",
+         "\n",
+         "    return result.stdout.strip()  # Return the output as text\n",
+         "\n",
+         "# Main prediction function\n",
+         "def predict(model_name, image, text, temperature, max_tokens):\n",
+         "    if model_name == \"LLaVA\":\n",
+         "        return predict_llava(image, text, temperature, max_tokens)\n",
+         "    elif model_name == \"LLaVA-Med\":\n",
+         "        return predict_llava_med(image, text, temperature, max_tokens)\n",
+         "\n",
+         "# Define the Gradio interface\n",
+         "interface = gr.Interface(\n",
+         "    fn=predict,\n",
+         "    inputs=[\n",
+         "        gr.Radio(choices=[\"LLaVA\", \"LLaVA-Med\"], label=\"Select Model\"),\n",
+         "        gr.Image(type=\"pil\", label=\"Input Image\"),\n",
+         "        gr.Textbox(label=\"Input Text\"),\n",
+         "        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label=\"Temperature\"),\n",
+         "        gr.Slider(minimum=1, maximum=512, value=256, label=\"Max Tokens\"),\n",
+         "    ],\n",
+         "    outputs=gr.Textbox(label=\"Output Text\"),\n",
+         "    title=\"Multimodal LLM Interface\",\n",
+         "    description=\"Switch between models and adjust parameters.\",\n",
+         ")\n",
+         "\n",
+         "# Launch the Gradio interface\n",
+         "interface.launch()\n"
+       ],
+       "metadata": {
+         "id": "pCJxQjryVRrh"
+       },
+       "execution_count": null,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "source": [],
+       "metadata": {
+         "id": "YBSsgQNwVRto"
+       },
+       "execution_count": null,
+       "outputs": []
+     },
+     {
+       "cell_type": "code",
+       "source": [],
+       "metadata": {
+         "id": "UjB_xxubVRu7"
+       },
+       "execution_count": null,
+       "outputs": []
+     }
+   ]
+ }
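
Note on the committed code: predict_idefics2 references idefics2_processor and idefics2_model, but the notebook never initializes them, and the Gradio radio button offers no "Idefics2" choice, so that code path is unreachable (and predict returns None for any unlisted model name). A minimal sketch of the missing wiring follows; the base checkpoint id HuggingFaceM4/idefics2-8b is an assumption, since no fine-tuned Idefics2 checkpoint is part of this commit.

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

# ASSUMPTION: base Idefics2 checkpoint id; replace with the project's own
# fine-tuned checkpoint path if one exists.
idefics2_model_path = "HuggingFaceM4/idefics2-8b"
idefics2_processor = AutoProcessor.from_pretrained(idefics2_model_path)
idefics2_model = AutoModelForVision2Seq.from_pretrained(
    idefics2_model_path,
    torch_dtype=torch.float16,
).to("cuda:0")

# Expose the third model in the UI, e.g.:
# gr.Radio(choices=["LLaVA", "LLaVA-Med", "Idefics2"], label="Select Model")
# and extend the dispatcher so the selection actually routes to it:
def predict(model_name, image, text, temperature, max_tokens):
    if model_name == "LLaVA":
        return predict_llava(image, text, temperature, max_tokens)
    elif model_name == "LLaVA-Med":
        return predict_llava_med(image, text, temperature, max_tokens)
    elif model_name == "Idefics2":
        return predict_idefics2(image, text, temperature, max_tokens)
    return "Unknown model selection."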