NLarchive committed on
Commit
0d98fee
·
verified ·
1 Parent(s): 670ede9

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +100 -0
inference.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from typing import Optional

from google import genai
from google.genai import types
4
+
5
# Module-level handle to the Google Generative AI client; set by initialize().
client = None
7
+
8
def initialize():
    """Create the module-level Google Generative AI client.

    Reads the API key from the environment, preferring GEMINI_API_KEY and
    falling back to GOOGLE_API_KEY.

    Raises:
        ValueError: if neither environment variable is set.
        Exception: re-raised from the client constructor when creation fails.
    """
    global client

    # GEMINI_API_KEY takes precedence; an empty value falls through to the fallback.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("Neither GEMINI_API_KEY nor GOOGLE_API_KEY environment variable is set.")

    try:
        client = genai.Client(api_key=api_key)
        print("Google Generative AI client initialized.")
    except Exception as e:
        print(f"Error initializing Google Generative AI client: {e}")
        raise
27
+
28
def generate_content(prompt: str, model_name: Optional[str] = None, allow_fallbacks: bool = True, generation_config: Optional[dict] = None) -> str:
    """Generate text content using the Google Generative AI model.

    Args:
        prompt: The prompt to send to the model.
        model_name: Model identifier (e.g., "gemini-2.0-flash",
            "gemini-1.5-flash"). When None, "gemini-2.0-flash-lite" is used.
        allow_fallbacks: Currently unused by the underlying API call; kept
            for interface compatibility with agent.py.
        generation_config: Optional dict of generation parameters. Only
            "temperature" and "max_output_tokens" are forwarded; other keys
            are silently ignored.

    Returns:
        The generated text content.

    Raises:
        RuntimeError: if the client is not initialized and cannot be.
        Exception: re-raised from the underlying API call on failure.
    """
    global client
    if client is None:
        # Lazy fallback; callers should normally invoke initialize() explicitly.
        print("Client not initialized. Attempting to initialize now...")
        initialize()
        if client is None:  # Defensive: initialize() should have set it or raised.
            raise RuntimeError("Google Generative AI client is not initialized. Call initialize() first.")

    # Default model if not specified (empty string also falls back).
    effective_model_name = model_name or "gemini-2.0-flash-lite"

    # Forward only the parameters this wrapper knows GenerateContentConfig
    # accepts; extend the tuple below as new parameters are needed.
    config_obj = None
    if generation_config:
        config_params = {
            key: generation_config[key]
            for key in ("temperature", "max_output_tokens")
            if key in generation_config
        }
        if config_params:
            config_obj = types.GenerateContentConfig(**config_params)

    try:
        response = client.models.generate_content(
            model=effective_model_name,
            contents=[prompt],  # Note: contents expects a list
            config=config_obj,
        )
        return response.text
    except Exception as e:
        print(f"Error during content generation: {e}")
        # Re-raise so callers (agent.py) can decide how to handle failures.
        raise
79
+
80
if __name__ == '__main__':
    # Smoke test for inference.py when run directly.
    # Requires GEMINI_API_KEY to be set in the environment beforehand,
    # e.g. in PowerShell: $env:GEMINI_API_KEY="YOUR_API_KEY"
    try:
        initialize()
        if client:
            first_prompt = "Explain how AI works in a few words"
            print(f"Sending prompt: '{first_prompt}'")
            text = generate_content(
                first_prompt,
                generation_config={'temperature': 0.7, 'max_output_tokens': 50},
            )
            print("\nGenerated text:")
            print(text)

            second_prompt = "What is the capital of France?"
            print(f"\nSending prompt: '{second_prompt}'")
            # Demonstrates selecting a model explicitly.
            text = generate_content(second_prompt, model_name="gemini-2.0-flash-lite")
            print("\nGenerated text:")
            print(text)
    except Exception as e:
        print(f"An error occurred: {e}")