Create inference.py
Browse files- inference.py +100 -0
inference.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from google import genai
|
2 |
+
from google.genai import types
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Global variable to hold the client
# Module-level genai.Client instance. Stays None until initialize() succeeds;
# generate_content() reads it (and lazily calls initialize() if still None).
client = None
|
7 |
+
|
8 |
+
def initialize():
    """Create the module-level Google Generative AI client.

    The API key is read from the GEMINI_API_KEY environment variable,
    falling back to GOOGLE_API_KEY.

    Raises:
        ValueError: If neither environment variable is set.
        Exception: Re-raised if the genai.Client construction fails.
    """
    global client
    # Prefer GEMINI_API_KEY; fall back to the alternative GOOGLE_API_KEY name.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")

    if not api_key:
        raise ValueError("Neither GEMINI_API_KEY nor GOOGLE_API_KEY environment variable is set.")

    try:
        client = genai.Client(api_key=api_key)
        print("Google Generative AI client initialized.")
    except Exception as exc:
        # Surface the failure before propagating so direct runs show context.
        print(f"Error initializing Google Generative AI client: {exc}")
        raise
|
27 |
+
|
28 |
+
def generate_content(prompt: str, model_name: str = None, allow_fallbacks: bool = True, generation_config: dict = None) -> str:
    """
    Generates content using the Google Generative AI model.

    Args:
        prompt: The prompt to send to the model.
        model_name: The name of the model to use (e.g., "gemini-2.0-flash",
            "gemini-1.5-flash"). If None, "gemini-2.0-flash-lite" is used.
        allow_fallbacks: (Currently not directly used by
            genai.Client.models.generate_content, but kept for compatibility
            with agent.py structure.)
        generation_config: Optional dict of generation parameters. Supported
            keys: 'temperature', 'max_output_tokens', 'top_p', 'top_k',
            'stop_sequences', 'candidate_count'. Unsupported keys are
            silently ignored for backward compatibility.

    Returns:
        The generated text content.

    Raises:
        RuntimeError: If the client is not initialized and lazy
            initialization leaves it unset.
        Exception: Re-raised from initialize() or the underlying API call.
    """
    global client
    if client is None:
        # Lazy fallback as a convenience; callers should ideally call
        # initialize() explicitly at startup.
        print("Client not initialized. Attempting to initialize now...")
        initialize()
        if client is None:  # Defensive: initialize() should set it or raise.
            raise RuntimeError("Google Generative AI client is not initialized. Call initialize() first.")

    # Default model if not specified.
    effective_model_name = model_name if model_name else "gemini-2.0-flash-lite"

    # Translate the plain-dict config into the typed API config object.
    # Previously only 'temperature' and 'max_output_tokens' were forwarded;
    # the whitelist now covers the common generation parameters while still
    # ignoring unknown keys (backward-compatible).
    config_obj = None
    if generation_config:
        supported_keys = (
            'temperature',
            'max_output_tokens',
            'top_p',
            'top_k',
            'stop_sequences',
            'candidate_count',
        )
        config_params = {
            key: generation_config[key]
            for key in supported_keys
            if key in generation_config
        }
        if config_params:
            config_obj = types.GenerateContentConfig(**config_params)

    try:
        response = client.models.generate_content(
            model=effective_model_name,
            contents=[prompt],  # Note: contents expects a list
            config=config_obj,
        )
        return response.text
    except Exception as e:
        print(f"Error during content generation: {e}")
        # Re-raise so callers (e.g. agent.py) can decide how to recover.
        raise
|
79 |
+
|
80 |
+
if __name__ == '__main__':
    # Smoke test for running inference.py directly.
    # Export your key first, e.g. in PowerShell:
    #   $env:GEMINI_API_KEY="YOUR_API_KEY"
    try:
        initialize()
        if client:
            first_prompt = "Explain how AI works in a few words"
            print(f"Sending prompt: '{first_prompt}'")
            first_config = {'temperature': 0.7, 'max_output_tokens': 50}
            first_text = generate_content(first_prompt, generation_config=first_config)
            print("\nGenerated text:")
            print(first_text)

            second_prompt = "What is the capital of France?"
            print(f"\nSending prompt: '{second_prompt}'")
            # Demonstrate passing an explicit model name.
            second_text = generate_content(second_prompt, model_name="gemini-2.0-flash-lite")
            print("\nGenerated text:")
            print(second_text)
    except Exception as e:
        print(f"An error occurred: {e}")
|