add model manager

inference_manager.py  CHANGED  (+84 -7)

@@ -84,12 +84,15 @@ class AuthHelper:
         raise Exception("wrong token")

 class InferenceManager:
-    def __init__(self,
-        self.
+    def __init__(self, config_path="config.json"):
+        self.cfg = {}
+        with open(config_path, "r", encoding="utf-8") as f:
+            self.cfg = json.load(f)
+        lora_options_path = self.cfg.get("loras", "")
         self.lora_load_options = self.load_json(lora_options_path)  # Load LoRA load options
         self.lora_models = self.load_index_file("index.json")  # Load index.json
         self.preloaded_loras = []  # Array to store preloaded LoRAs with name and weights
-        self.base_model_pipeline = self.load_base_model(
+        self.base_model_pipeline = self.load_base_model()  # Load the base model
         self.preload_loras()  # Preload LoRAs based on options

     def load_json(self, filepath):
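
The rewritten constructor is config-driven: __init__ now parses a JSON file into self.cfg and takes the LoRA options path from its "loras" key, while load_base_model (next hunk) reads "model_id" and "model_version" from the same dict. A minimal sketch of a config file that would satisfy those lookups — the key names come from this diff, but every value below is a placeholder, not part of the commit:

import json

# Hypothetical config.json for InferenceManager(config_path="config.json").
# Keys mirror the cfg lookups in the diff; the values are assumptions.
example_cfg = {
    "model_id": "runwayml/stable-diffusion-v1-5",  # passed to snapshot_download(repo_id=...)
    "model_version": "1.5",                        # selects the pipeline branch in load_base_model
    "loras": "lora_options.json",                  # path handed to self.load_json(...)
}

with open("config.json", "w", encoding="utf-8") as f:
    json.dump(example_cfg, f, indent=2)
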
@@ -132,13 +135,12 @@ class InferenceManager:
         self.base_model_pipeline = pipe
         print(f"OneDiff compile in {compile_time}s")

-    def load_base_model(self
+    def load_base_model(self):
         """Load base model and return the pipeline."""
         start = time.time()
-
-        cfg = json.load(f)
+        cfg = self.cfg

-        model_version = cfg.get("model_version", self.model_version)
+        model_version = self.model_version = cfg.get("model_version", getattr(self, "model_version", None))
         ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)

         if model_version == "1.5":
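
load_base_model now takes its settings from the already-parsed self.cfg instead of re-reading a file handle, resolves the checkpoint directory with huggingface_hub's snapshot_download, and then branches on the model version. A standalone sketch of that download step, with an assumed repo id:

from huggingface_hub import snapshot_download

# Download (or reuse from the local cache) the full model repository and
# return the local directory path. The repo id here is only an example.
ckpt_dir = snapshot_download(repo_id="runwayml/stable-diffusion-v1-5", local_files_only=False)
print(ckpt_dir)
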
@@ -257,6 +259,81 @@ class InferenceManager:
         torch.cuda.empty_cache()
         print("Memory released and cache cleared.")

+class ModelManager:
+    def __init__(self, config_urls):
+        """
+        Initialize the ModelManager by loading all models specified by the URLs.
+
+        :param config_urls: List of URLs pointing to model config files (e.g., ["model1/config.json", "model2/config.json"]).
+        """
+        self.models = {}
+        self.load_models(config_urls)
+
+    def load_models(self, config_urls):
+        """
+        Load and initialize InferenceManager instances for each config URL.
+
+        :param config_urls: List of config file URLs.
+        """
+        for url in config_urls:
+            model_name = self.get_model_name_from_url(url)
+            print(f"Initializing model: {model_name} from {url}")
+            try:
+                # Initialize InferenceManager for each model
+                self.models[model_name] = InferenceManager(config_path=url)
+            except Exception as e:
+                print(f"Failed to initialize model {model_name} from {url}: {e}")
+
+    def get_model_name_from_url(self, url):
+        """
+        Extract the model name from the config URL (filename without extension).
+
+        :param url: The URL of the configuration file.
+        :return: The model name (file name without extension).
+        """
+        filename = os.path.basename(url)
+        model_name, _ = os.path.splitext(filename)
+        return model_name
+
+    def get_model_pipeline(self, model_id, lora_list, sampler="DPM2 a", new_pipeline=False):
+        """
+        Build the pipeline with specific LoRAs for a model.
+
+        :param model_id: The model ID (the model name extracted from the config URL).
+        :param lora_list: List of LoRAs to be applied to the model pipeline.
+        :param sampler: The sampler to be used for the pipeline.
+        :param new_pipeline: Flag to indicate whether to create a new pipeline or reuse the existing one.
+        :return: The built pipeline with LoRAs applied.
+        """
+        model = self.models.get(model_id)
+        if not model:
+            print(f"Model {model_id} not found.")
+            return None
+        try:
+            print(f"Building pipeline with LoRAs for model {model_id}...")
+            return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
+        except Exception as e:
+            print(f"Failed to build pipeline for model {model_id}: {e}")
+            return None
+
+    def release_model(self, model_id):
+        """
+        Release resources and clear memory for a specific model.
+
+        :param model_id: The model ID (the model name extracted from the config URL).
+        """
+        model = self.models.get(model_id)
+        if not model:
+            print(f"Model {model_id} not found.")
+            return
+        try:
+            print(f"Releasing model {model_id}...")
+            model.release(model.base_model_pipeline)
+        except Exception as e:
+            print(f"Failed to release model {model_id}: {e}")
+
+
+

 # Hugging Face file download function - returns only file path
 def download_from_hf(filename, local_dir=None):
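
The new ModelManager keys each InferenceManager by its config file's base name (get_model_name_from_url strips the directory and extension), so every model needs a distinctly named config file. A usage sketch under assumptions — the config paths, LoRA entry format, and sampler value are placeholders, while the class and method names come from the diff above:

# Hypothetical usage of the new ModelManager; nothing below is part of the commit itself.
manager = ModelManager(["configs/sd15.json", "configs/sdxl.json"])

# get_model_pipeline delegates to InferenceManager.build_pipeline_with_lora.
pipe = manager.get_model_pipeline(
    "sd15",                                               # model id = config file name without extension
    lora_list=[{"name": "example_lora", "weight": 0.8}],  # assumed entry format
    sampler="DPM2 a",
    new_pipeline=False,
)

if pipe is not None:
    pass  # run inference with the returned pipeline here

# Release the base pipeline and clear CUDA memory for that model when done.
manager.release_model("sd15")

Because each InferenceManager eagerly loads its base model in __init__, constructing the manager front-loads the downloads and GPU memory for every listed config; per-model failures are caught and reported rather than aborting the whole manager.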
|