vark101 committed on
Commit d120779 · verified · 1 Parent(s): cb11454

Upload multimodal_tools.py

Files changed (1)
  1. tools/multimodal_tools.py +177 -0
tools/multimodal_tools.py ADDED
@@ -0,0 +1,177 @@
+ import base64
+ import os
+ from langchain_core.messages import HumanMessage
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_core.tools import tool
+
+ api_key = os.getenv("GEMINI_API_KEY")
+
+ # Create the multimodal LLM client
+ vision_llm = ChatGoogleGenerativeAI(
+     model="gemini-2.5-flash-preview-05-20",
+     temperature=0,
+     max_retries=2,
+     google_api_key=api_key
+ )
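+
+ # NOTE: GEMINI_API_KEY must be set in the environment before this module is
+ # imported; otherwise api_key is None and model calls will fail unless
+ # credentials are provided some other way.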
+
+ @tool("extract_text_tool", parse_docstring=True)
+ def extract_text(img_path: str) -> str:
+     """Extract text from an image file using a multimodal model.
+
+     Args:
+         img_path (str): The path to the image file from which to extract text.
+
+     Returns:
+         str: The extracted text from the image, or an empty string if an error occurs.
+     """
+     all_text = ""
+     try:
+         # Read the image and encode it as base64
+         with open(img_path, "rb") as image_file:
+             image_bytes = image_file.read()
+
+         image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+
+         # Prepare the prompt, including the base64 image data
+         message = [
+             HumanMessage(
+                 content=[
+                     {
+                         "type": "text",
+                         "text": (
+                             "Extract all the text from this image. "
+                             "Return only the extracted text, no explanations."
+                         ),
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/png;base64,{image_base64}"
+                         },
+                     },
+                 ]
+             )
+         ]
+
+         # Call the vision-capable model
+         response = vision_llm.invoke(message)
+
+         # Append the extracted text
+         all_text += response.content + "\n\n"
+
+         return all_text.strip()
+     except Exception as e:
+         # Handle errors gracefully: log and return an empty string
+         error_msg = f"Error extracting text: {str(e)}"
+         print(error_msg)
+         return ""
+
+ @tool("analyze_image_tool", parse_docstring=True)
+ def analyze_image_tool(user_query: str, img_path: str) -> str:
+     """Answer a question by reasoning over the given image.
+
+     Args:
+         user_query (str): The question to be answered based on the image.
+         img_path (str): Path to the image file to be analyzed.
+
+     Returns:
+         str: The answer to the query based on the image content, or an empty string if an error occurs.
+     """
+     all_text = ""
+     try:
+         # Read the image and encode it as base64
+         with open(img_path, "rb") as image_file:
+             image_bytes = image_file.read()
+
+         image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+
+         # Prepare the prompt, including the base64 image data
+         message = [
+             HumanMessage(
+                 content=[
+                     {
+                         "type": "text",
+                         "text": f"User query: {user_query}",
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/png;base64,{image_base64}"
+                         },
+                     },
+                 ]
+             )
+         ]
+
+         # Call the vision-capable model
+         response = vision_llm.invoke(message)
+
+         # Append the model's answer
+         all_text += response.content + "\n\n"
+
+         return all_text.strip()
+     except Exception as e:
+         # Handle errors gracefully: log and return an empty string
+         error_msg = f"Error analyzing image: {str(e)}"
+         print(error_msg)
+         return ""
+
+ @tool("analyze_audio_tool", parse_docstring=True)
+ def analyze_audio_tool(user_query: str, audio_path: str) -> str:
+     """Answer a question by reasoning over the provided audio file.
+
+     Args:
+         user_query (str): The question to be answered based on the audio content.
+         audio_path (str): Path to the audio file (e.g., .mp3, .wav, .flac, .aac, .ogg).
+
+     Returns:
+         str: The answer to the query based on the audio content, or an error message/empty string if an error occurs.
+     """
+     try:
+         # Determine the MIME type from the file extension
+         _filename, file_extension = os.path.splitext(audio_path)
+         file_extension = file_extension.lower()
+
+         supported_formats = {
+             ".mp3": "audio/mp3", ".wav": "audio/wav", ".flac": "audio/flac",
+             ".aac": "audio/aac", ".ogg": "audio/ogg"
+         }
+
+         if file_extension not in supported_formats:
+             return (f"Error: Unsupported audio file format '{file_extension}'. "
+                     f"Supported extensions: {', '.join(supported_formats.keys())}.")
+         mime_type = supported_formats[file_extension]
+
+         # Read the audio file and encode it as base64
+         with open(audio_path, "rb") as audio_file:
+             audio_bytes = audio_file.read()
+         audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
+
+         # Prepare the prompt, including the base64 audio data
+         message = [
+             HumanMessage(
+                 content=[
+                     {
+                         "type": "text",
+                         "text": f"User query: {user_query}",
+                     },
+                     {
+                         "type": "audio",
+                         "source_type": "base64",
+                         "mime_type": mime_type,
+                         "data": audio_base64
+                     },
+                 ]
+             )
+         ]
+
+         # Call the audio-capable multimodal model
+         response = vision_llm.invoke(message)
+         return response.content.strip()
+     except Exception as e:
+         error_msg = f"Error analyzing audio: {str(e)}"
+         print(error_msg)
+         return ""
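
For quick verification, here is a minimal usage sketch, assuming GEMINI_API_KEY is set and that the sample files and queries (sample_receipt.png and sample_clip.mp3, both hypothetical) are supplied by the caller. LangChain @tool-decorated functions are invoked with a single dict keyed by parameter name:

from tools.multimodal_tools import extract_text, analyze_image_tool, analyze_audio_tool

# @tool-decorated functions take their arguments as a dict via .invoke()
print(extract_text.invoke({"img_path": "sample_receipt.png"}))

print(analyze_image_tool.invoke({
    "user_query": "What is the total amount shown on this receipt?",
    "img_path": "sample_receipt.png",
}))

print(analyze_audio_tool.invoke({
    "user_query": "Summarize what the speaker says.",
    "audio_path": "sample_clip.mp3",
}))

# For agent use, the same tools can be bound to a tool-calling chat model, e.g.:
# llm_with_tools = vision_llm.bind_tools([extract_text, analyze_image_tool, analyze_audio_tool])

The dict-based call signature matches how LangChain routes model-generated tool calls back into these functions, so the tools work both standalone and inside an agent loop.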