nsfwalex committed
Commit 7df1bb1 · verified · 1 Parent(s): 1606e2d

add model manager

Files changed (1)
  1. inference_manager.py +84 -7
inference_manager.py CHANGED
@@ -84,12 +84,15 @@ class AuthHelper:
        raise Exception("wrong token")

class InferenceManager:
-    def __init__(self, model_version="xl", config_path="config.json", lora_options_path="loras.json"):
-        self.model_version = model_version
+    def __init__(self, config_path="config.json"):
+        self.cfg = {}
+        with open(config_path, "r", encoding="utf-8") as f:
+            self.cfg = json.load(f)
+        lora_options_path = self.cfg.get("loras", "")
        self.lora_load_options = self.load_json(lora_options_path) # Load LoRA load options
        self.lora_models = self.load_index_file("index.json") # Load index.json
        self.preloaded_loras = [] # Array to store preloaded LoRAs with name and weights
-        self.base_model_pipeline = self.load_base_model(config_path) # Load the base model
+        self.base_model_pipeline = self.load_base_model() # Load the base model
        self.preload_loras() # Preload LoRAs based on options

    def load_json(self, filepath):
@@ -132,13 +135,12 @@ class InferenceManager:
        self.base_model_pipeline = pipe
        print(f"OneDiff compile in {compile_time}s")

-    def load_base_model(self, config_path):
+    def load_base_model(self):
        """Load base model and return the pipeline."""
        start = time.time()
-        with open(config_path, "r", encoding="utf-8") as f:
-            cfg = json.load(f)
+        cfg = self.cfg

-        model_version = cfg.get("model_version", self.model_version)
+        model_version = self.model_version = cfg.get("model_version", "xl")
        ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)

        if model_version == "1.5":
@@ -257,6 +259,81 @@ class InferenceManager:
        torch.cuda.empty_cache()
        print("Memory released and cache cleared.")

+class ModelManager:
+    def __init__(self, config_urls):
+        """
+        Initialize the ModelManager by loading all models specified by the URLs.
+
+        :param config_urls: List of URLs pointing to model config files (e.g., ["model1/config.json", "model2/config.json"]).
+        """
+        self.models = {}
+        self.load_models(config_urls)
+
+    def load_models(self, config_urls):
+        """
+        Load and initialize InferenceManager instances for each config URL.
+
+        :param config_urls: List of config file URLs.
+        """
+        for url in config_urls:
+            model_name = self.get_model_name_from_url(url)
+            print(f"Initializing model: {model_name} from {url}")
+            try:
+                # Initialize InferenceManager for each model
+                self.models[model_name] = InferenceManager(config_path=url)
+            except Exception as e:
+                print(f"Failed to initialize model {model_name} from {url}: {e}")
+
+    def get_model_name_from_url(self, url):
+        """
+        Extract the model name from the config URL (filename without extension).
+
+        :param url: The URL of the configuration file.
+        :return: The model name (file name without extension).
+        """
+        filename = os.path.basename(url)
+        model_name, _ = os.path.splitext(filename)
+        return model_name
+
+    def get_model_pipeline(self, model_id, lora_list, sampler="DPM2 a", new_pipeline=False):
+        """
+        Build the pipeline with specific LoRAs for a model.
+
+        :param model_id: The model ID (the model name extracted from the config URL).
+        :param lora_list: List of LoRAs to be applied to the model pipeline.
+        :param sampler: The sampler to be used for the pipeline.
+        :param new_pipeline: Flag to indicate whether to create a new pipeline or reuse the existing one.
+        :return: The built pipeline with LoRAs applied.
+        """
+        model = self.models.get(model_id)
+        if not model:
+            print(f"Model {model_id} not found.")
+            return None
+        try:
+            print(f"Building pipeline with LoRAs for model {model_id}...")
+            return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
+        except Exception as e:
+            print(f"Failed to build pipeline for model {model_id}: {e}")
+            return None
+
+    def release_model(self, model_id):
+        """
+        Release resources and clear memory for a specific model.
+
+        :param model_id: The model ID (the model name extracted from the config URL).
+        """
+        model = self.models.get(model_id)
+        if not model:
+            print(f"Model {model_id} not found.")
+            return
+        try:
+            print(f"Releasing model {model_id}...")
+            model.release(model.base_model_pipeline)
+        except Exception as e:
+            print(f"Failed to release model {model_id}: {e}")
+
+
+

# Hugging Face file download function - returns only file path
def download_from_hf(filename, local_dir=None):
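With this change a single config file drives both the base model and the LoRA options. A minimal sketch of what such a config might contain, limited to the keys the new code actually reads ("model_id", "model_version", "loras"); the repo id and file names below are placeholders and are not part of this commit:

import json

# Hypothetical config for the refactored InferenceManager.__init__; only the
# "model_id", "model_version" and "loras" keys are read by the code above.
# The repo id and file names are placeholders, not part of this commit.
example_cfg = {
    "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
    "model_version": "xl",
    "loras": "loras.json",
}
with open("sdxl.json", "w", encoding="utf-8") as f:
    json.dump(example_cfg, f, indent=2)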
 
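A short usage sketch of the new ModelManager, assuming config files like the one above exist; the config paths, LoRA names, and the (name, weight) format of lora_list are illustrative assumptions, not part of the commit:

# Hypothetical usage of ModelManager; paths and LoRA names are placeholders,
# and the lora_list format is an assumption based on the "name and weights"
# comment in InferenceManager.__init__.
from inference_manager import ModelManager

manager = ModelManager(["configs/sdxl.json", "configs/sd15.json"])

# Model ids are derived from the config file names, e.g. "sdxl" and "sd15".
pipe = manager.get_model_pipeline("sdxl", lora_list=[("detail_lora", 0.8)],
                                  sampler="DPM2 a", new_pipeline=False)
if pipe is not None:
    pass  # run inference with the returned pipeline

# Release GPU memory for a model that is no longer needed.
manager.release_model("sd15")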