imlixinyang committed
Commit 65b6110 · 1 Parent(s): c7ea8b7
Files changed (3):
  1. README.md +1 -1
  2. app_gradio.py +571 -0
  3. index.html +189 -207
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: blue
colorTo: indigo
sdk: gradio
sdk_version: 5.49.1
- app_file: app.py
+ app_file: app_gradio.py
pinned: false
license: cc-by-nc-sa-4.0
models:
app_gradio.py ADDED
@@ -0,0 +1,571 @@
+ try:
+     import spaces
+     GPU = spaces.GPU
+     print("spaces GPU is available")
+ except ImportError:
+     def GPU(func):
+         return func
+
+ import os
+ import subprocess
+
+ # def install_cuda_toolkit():
+ #     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+ #     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
+ #     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+ #     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+ #     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+ #     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+ #     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+ #     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+ #     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+ #         os.environ["CUDA_HOME"],
+ #         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+ #     )
+ #     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+ #     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+
+ #     print("Successfully installed CUDA toolkit at: ", os.environ["CUDA_HOME"])
+
+ # subprocess.call('rm /usr/bin/gcc', shell=True)
+ # subprocess.call('rm /usr/bin/g++', shell=True)
+ # subprocess.call('rm /usr/local/cuda/bin/gcc', shell=True)
+ # subprocess.call('rm /usr/local/cuda/bin/g++', shell=True)
+
+ # subprocess.call('ln -s /usr/bin/gcc-11 /usr/bin/gcc', shell=True)
+ # subprocess.call('ln -s /usr/bin/g++-11 /usr/bin/g++', shell=True)
+
+ # subprocess.call('ln -s /usr/bin/gcc-11 /usr/local/cuda/bin/gcc', shell=True)
+ # subprocess.call('ln -s /usr/bin/g++-11 /usr/local/cuda/bin/g++', shell=True)
+
+ # subprocess.call('gcc --version', shell=True)
+ # subprocess.call('g++ --version', shell=True)
+
+ # install_cuda_toolkit()
+
+ # subprocess.run('pip install git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712 --no-build-isolation --use-pep517', env={'CUDA_HOME': "/usr/local/cuda", "TORCH_CUDA_ARCH_LIST": "8.0;8.6"}, shell=True)
+
+ import gradio as gr
+ import base64
+ import io
+ from PIL import Image
+ import torch
+ import numpy as np
+ import os
+ import argparse
+ import imageio
+ import json
+ import time
+ import tempfile
+ import shutil
+
+ from huggingface_hub import hf_hub_download
+
+ import einops
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ import imageio
+
+ from models import *
+ from utils import *
+
+ from transformers import T5TokenizerFast, UMT5EncoderModel
+
+ from diffusers import FlowMatchEulerDiscreteScheduler
+
+ class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+
+         return torch.argmin(
+             (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
+
+ class GenerationSystem(nn.Module):
+     def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
+         super().__init__()
+         self.device = device
+         self.offload_t5 = offload_t5
+         self.offload_vae = offload_vae
+
+         self.latent_dim = 48
+         self.temporal_downsample_factor = 4
+         self.spatial_downsample_factor = 16
+
+         self.feat_dim = 1024
+
+         self.latent_patch_size = 2
+
+         self.denoising_steps = [0, 250, 500, 750]
+
+         model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+
+         self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()
+
+         from models.autoencoder_kl_wan import WanCausalConv3d
+         with torch.no_grad():
+             for name, module in self.vae.named_modules():
+                 if isinstance(module, WanCausalConv3d):
+                     time_pad = module._padding[4]
+                     module.padding = (0, module._padding[2], module._padding[0])
+                     module._padding = (0, 0, 0, 0, 0, 0)
+                     module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())
+
+         self.vae.requires_grad_(False)
+
+         self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+         self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+
+         self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
+         self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean
+
+         self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
+
+         self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")
+
+         self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
+
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
+         # self.transformer.rope.freqs_f[:] = self.transformer.rope.freqs_f[:1]
+
+         weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
+         bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
+
+         extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
+         extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
+
+         self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
+         self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())
+
+         self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)
+
+         self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
+
+         self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))
+
+         self.transformer.disable_gradient_checkpointing()
+         self.transformer.gradient_checkpointing = False
+
+         self.add_feedback_for_transformer()
+
+         if ckpt_path is not None:
+             state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+             self.transformer.load_state_dict(state_dict["transformer"])
+             self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
+             print(f"Loaded {ckpt_path}.")
+
+         from quant import FluxFp8GeMMProcessor
+
+         FluxFp8GeMMProcessor(self.transformer)
+
+         del self.vae.post_quant_conv, self.vae.decoder
+         self.vae.to(self.device if not self.offload_vae else "cpu")
+
+         self.transformer.to(self.device)
+
+     def add_feedback_for_transformer(self):
+         self.use_feedback = True
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
+
+     def encode_text(self, texts):
+         max_sequence_length = 512
+
+         text_inputs = self.tokenizer(
+             texts,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         if getattr(self, "offload_t5", False):
+             text_input_ids = text_inputs.input_ids.to("cpu")
+             mask = text_inputs.attention_mask.to("cpu")
+         else:
+             text_input_ids = text_inputs.input_ids.to(self.device)
+             mask = text_inputs.attention_mask.to(self.device)
+         seq_lens = mask.gt(0).sum(dim=1).long()
+
+         if getattr(self, "offload_t5", False):
+             with torch.no_grad():
+                 text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
+         else:
+             text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
+         text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
+         text_embeds = torch.stack(
+             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
+         )
+         return text_embeds.float()
+
+     def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
+
+         out = self.transformer(
+             hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
+             timestep=t,
+             encoder_hidden_states=text_embeds,
+             return_dict=False,
+         )[0]
+
+         v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+         sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+         latents_pred_2d = noisy_latents - sigma * v_pred
+
+         if need_3d_mode:
+             scene_params = self.recon_decoder(
+                 einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 cameras
+             ).flatten(1, -2)
+
+             images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+             latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
+                 einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+             ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)
+
+         return {
+             '2d': latents_pred_2d,
+             '3d': latents_pred_3d if need_3d_mode else None,
+             'rgb_3d': images_pred if need_3d_mode else None,
+             'scene': scene_params if need_3d_mode else None,
+             'feat': feats
+         }
+
+     @torch.no_grad()
+     @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
+     def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+         with torch.no_grad():
+             batch_size = 1
+
+             cameras = cameras.to(self.device).unsqueeze(0)
+
+             if cameras.shape[1] != n_frame:
+                 render_cameras = cameras.clone()
+                 cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
+             else:
+                 render_cameras = cameras
+
+             cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
+
+             render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)
+
+             text = "[Static] " + text
+
+             text_embeds = self.encode_text([text])
+             # neg_text_embeds = self.encode_text([""]).repeat(batch_size, 1, 1)
+
+             masks = torch.zeros(batch_size, n_frame, device=self.device)
+
+             condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             if image is not None:
+                 image = image.to(self.device)
+
+                 latent = self.latent_scale_fn(self.vae.encode(
+                     image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+                 ).latent_dist.sample().to(self.device)).squeeze(2)
+
+                 masks[:, image_index] = 1
+                 condition_latents[:, :, image_index] = latent
+
+             raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
+             raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)
+
+             noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             noisy_latents = noise
+
+             torch.cuda.empty_cache()
+
+             if self.use_feedback:
+                 prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+                 prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+
+             for i in range(len(self.denoising_steps)):
+                 t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
+
+                 t = self.timesteps[t_ids]
+
+                 if self.use_feedback:
+                     _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
+                 else:
+                     _condition_latents = condition_latents
+
+                 if i < len(self.denoising_steps) - 1:
+                     out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
+
+                     latents_pred = out["3d"]
+
+                     if self.use_feedback:
+                         prev_latents_pred = latents_pred
+                         prev_feats = out['feat']
+
+                     noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
+
+                 else:
+                     out = self.transformer(
+                         hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
+                         timestep=t,
+                         encoder_hidden_states=text_embeds,
+                         return_dict=False,
+                     )[0]
+
+                     v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+
+                     sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+                     latents_pred = noisy_latents - sigma * v_pred
+
+             scene_params = self.recon_decoder(
+                 einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 cameras
+             ).flatten(1, -2)
+
+             if video_output_path is not None:
+                 interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+
+                 interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
+
+                 interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]
+
+                 imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)
+
+             scene_params = scene_params[0]
+
+             scene_params = scene_params.detach().cpu()
+
+         return scene_params, ref_w2c, T_norm
+
+ def process_generation_request(data, generation_system, cache_dir):
+     """
+     Process the generation request with the same logic as the Flask version
+     """
+     try:
+         image_prompt = data.get('image_prompt', None)
+         text_prompt = data.get('text_prompt', "")
+         cameras = data.get('cameras')
+         resolution = data.get('resolution')
+         image_index = data.get('image_index', 0)
+
+         n_frame, image_height, image_width = resolution
+
+         if not image_prompt and text_prompt == "":
+             return {'error': 'No Prompts provided'}
+
+         if image_prompt:
+             # image_prompt can be a file path or base64-encoded data
+             if os.path.exists(image_prompt):
+                 image_prompt = Image.open(image_prompt)
+             else:
+                 # image_prompt may be a data URL like "data:image/png;base64,...."
+                 if ',' in image_prompt:
+                     image_prompt = image_prompt.split(',', 1)[1]
+
+                 try:
+                     image_bytes = base64.b64decode(image_prompt)
+                     image_prompt = Image.open(io.BytesIO(image_bytes))
+                 except Exception as img_e:
+                     return {'error': f'Image decode error: {str(img_e)}'}
+
+             image = image_prompt.convert('RGB')
+
+             w, h = image.size
+
+             # center crop
+             if image_height / h > image_width / w:
+                 scale = image_height / h
+             else:
+                 scale = image_width / w
+
+             new_h = int(image_height / scale)
+             new_w = int(image_width / scale)
+
+             image = image.crop(((w - new_w) // 2, (h - new_h) // 2,
+                                 new_w + (w - new_w) // 2, new_h + (h - new_h) // 2)).resize((image_width, image_height))
+
+             for camera in cameras:
+                 camera['fx'] = camera['fx'] * scale
+                 camera['fy'] = camera['fy'] * scale
+                 camera['cx'] = (camera['cx'] - (w - new_w) // 2) * scale
+                 camera['cy'] = (camera['cy'] - (h - new_h) // 2) * scale
+
+             image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0 * 2 - 1
+         else:
+             image = None
+
+         cameras = torch.stack([
+             torch.from_numpy(np.array([camera['quaternion'][0], camera['quaternion'][1], camera['quaternion'][2], camera['quaternion'][3], camera['position'][0], camera['position'][1], camera['position'][2], camera['fx'] / image_width, camera['fy'] / image_height, camera['cx'] / image_width, camera['cy'] / image_height], dtype=np.float32))
+             for camera in cameras
+         ], dim=0)
+
+         file_id = str(int(time.time() * 1000))
+
+         start_time = time.time()
+         scene_params, ref_w2c, T_norm = generation_system.generate(cameras, n_frame, image, text_prompt, image_index, image_height, image_width, video_output_path=os.path.join(cache_dir, f'{file_id}.mp4'))
+         end_time = time.time()
+         print(f'Generation time: {end_time - start_time} s')
+
+         with open(os.path.join(cache_dir, f'{file_id}.json'), 'w') as f:
+             json.dump(data, f)
+
+         splat_path = os.path.join(cache_dir, f'{file_id}.ply')
+
+         export_ply_for_gaussians(splat_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
+
+         if not os.path.exists(splat_path):
+             return {'error': f'{splat_path} not found'}
+
+         file_size = os.path.getsize(splat_path)
+
+         response_data = {
+             'success': True,
+             'file_id': file_id,
+             'file_path': splat_path,
+             'file_size': file_size,
+             'download_url': f'/download/{file_id}',
+             'generation_time': end_time - start_time,
+         }
+
+         return response_data
+
+     except Exception as e:
+         return {'error': f'Processing error: {str(e)}'}
+
+ def gradio_generate(json_input, generation_system, cache_dir):
+     """
+     Gradio interface function that processes JSON input and returns JSON output
+     """
+     try:
+         # Parse JSON input
+         if isinstance(json_input, str):
+             data = json.loads(json_input)
+         else:
+             data = json_input
+
+         # Process the request
+         result = process_generation_request(data, generation_system, cache_dir)
+
+         # Return JSON response
+         return json.dumps(result, indent=2)
+
+     except Exception as e:
+         error_response = {'error': f'JSON processing error: {str(e)}'}
+         return json.dumps(error_response, indent=2)
+
+ def download_file(file_id, cache_dir):
+     """
+     Download a generated PLY file
+     """
+     file_path = os.path.join(cache_dir, f'{file_id}.ply')
+
+     if not os.path.exists(file_path):
+         return None
+
+     return file_path
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--port', type=int, default=7860)
+     parser.add_argument("--ckpt", default=None)
+     parser.add_argument("--gpu", type=int, default=0)
+     parser.add_argument("--cache_dir", type=str, default="./tmpfiles")
+     parser.add_argument("--offload_t5", action="store_true")
+     parser.add_argument("--max_concurrent", type=int, default=1, help="Maximum concurrent generation tasks")
+     args, _ = parser.parse_known_args()
+
+     # Ensure model.ckpt exists, download if not present
+     if args.ckpt is None:
+         from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+         ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
+         if not os.path.exists(ckpt_path):
+             hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt", local_dir_use_symlinks=False)
+     else:
+         ckpt_path = args.ckpt
+
+     # Create cache directory
+     os.makedirs(args.cache_dir, exist_ok=True)
+
+     # Initialize GenerationSystem
+     device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+     generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device)
+
+     # Create Gradio interface
+     with gr.Blocks(title="FlashWorld Backend") as demo:
+         gr.Markdown("# FlashWorld Generation Backend")
+         gr.Markdown("This backend processes JSON requests for 3D scene generation.")
+
+         with gr.Row():
+             with gr.Column():
+                 json_input = gr.Textbox(
+                     label="JSON Input",
+                     placeholder="Enter JSON request here...",
+                     lines=10,
+                     value='{"image_prompt": null, "text_prompt": "A beautiful landscape", "cameras": [...], "resolution": [16, 480, 704], "image_index": 0}'
+                 )
+
+                 generate_btn = gr.Button("Generate", variant="primary")
+
+             with gr.Column():
+                 json_output = gr.Textbox(
+                     label="JSON Output",
+                     lines=10,
+                     interactive=False
+                 )
+
+         # File download section
+         gr.Markdown("## File Download")
+         with gr.Row():
+             file_id_input = gr.Textbox(
+                 label="File ID",
+                 placeholder="Enter file ID to download..."
+             )
+             download_btn = gr.Button("Download PLY File")
+             download_output = gr.File(label="Downloaded File")
+
+         # Event handlers
+         generate_btn.click(
+             fn=lambda json_input: gradio_generate(json_input, generation_system, args.cache_dir),
+             inputs=[json_input],
+             outputs=[json_output]
+         )
+
+         download_btn.click(
+             fn=lambda file_id: download_file(file_id, args.cache_dir),
+             inputs=[file_id_input],
+             outputs=[download_output]
+         )
+
+         # Example JSON format
+         gr.Markdown("""
+         ## Example JSON Input Format:
+         ```json
+         {
+             "image_prompt": null,
+             "text_prompt": "A beautiful landscape with mountains and trees",
+             "cameras": [
+                 {
+                     "quaternion": [0, 0, 0, 1],
+                     "position": [0, 0, 5],
+                     "fx": 500,
+                     "fy": 500,
+                     "cx": 240,
+                     "cy": 240
+                 }
+             ],
+             "resolution": [16, 480, 704],
+             "image_index": 0
+         }
+         ```
+         """)
+
+     # Launch the interface
+     demo.launch(
+         allowed_paths=[args.cache_dir]
+     )
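
Because `gradio_generate` accepts and returns plain JSON strings, this backend can also be driven headlessly over Gradio's two-step HTTP API, exactly as the updated `index.html` does: POST the request to `/gradio_api/call/gradio_generate` to receive an `event_id`, then GET the SSE stream for that event. A minimal Python client sketch (hypothetical, not part of this commit; it assumes the app is reachable at `http://localhost:7860` and that `requests` is installed):

```python
# Hypothetical client sketch, not part of the committed files. It mirrors the
# two-step call the updated index.html makes against the Gradio API.
import json
import requests

BASE = "http://localhost:7860"  # assumed local address of the Space

request_data = {
    "image_prompt": None,
    "text_prompt": "A beautiful landscape with mountains and trees",
    "cameras": [{"quaternion": [0, 0, 0, 1], "position": [0, 0, 5],
                 "fx": 500, "fy": 500, "cx": 240, "cy": 240}],
    "resolution": [16, 480, 704],
    "image_index": 0,
}

# Step 1: POST the request; Gradio replies with an event_id.
resp = requests.post(
    f"{BASE}/gradio_api/call/gradio_generate",
    json={"data": [json.dumps(request_data)]},
)
event_id = resp.json()["event_id"]

# Step 2: GET the SSE stream for that event and keep the last `data:` payload.
# The payload is a one-element list holding the JSON string gradio_generate returned.
sse = requests.get(f"{BASE}/gradio_api/call/gradio_generate/{event_id}")
data_line = [l for l in sse.text.splitlines() if l.startswith("data: ")][-1]
result = json.loads(json.loads(data_line[len("data: "):])[0])
print(result)  # e.g. {"success": true, "file_id": "...", "download_url": "..."}
```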
index.html CHANGED
@@ -685,151 +685,7 @@
  if (progressText) progressText.textContent = text;
}

- // ==============
- // Queue handling
- // ==============
- let queuePollTimer = null;
- let currentTaskId = null;
- let initialQueuePosition = null;
- let latestGenerationTime = null;
- let lastDownloadPct = 0;
- let lastDownloadUpdateTs = 0;
-
- function showQueueWaiting(position, runningCount, queuedCount) {
-   // Use only the progress bar to show queue progress (from initial position to 0)
-   showDownloadProgress();
-   if (initialQueuePosition === null) {
-     // Initialize from first seen position; ensure >= 1 so 0 -> 100%
-     const initPos = (typeof position === 'number') ? position : 0;
-     initialQueuePosition = Math.max(initPos, 1);
-   }
-   const percent = initialQueuePosition && initialQueuePosition > 0
-     ? Math.max(0, Math.min(100, ((initialQueuePosition - (position || 0)) / initialQueuePosition) * 100))
-     : 0;
-   updateProgressBar(percent);
-   const totalWaiting = (position || 0) + (queuedCount || 0);
-   if (position !== null && position !== undefined) {
-     const pctText = `${Math.round(percent)}%`;
-     if (totalWaiting > 0) {
-       setProgressLabel(`Queued ${position}/${totalWaiting} (${pctText})`);
-     } else {
-       setProgressLabel(`Queued ${position} (${pctText})`);
-     }
-   } else {
-     setProgressLabel('Queued');
-   }
- }
-
- async function pollTaskUntilReady(taskId) {
-   currentTaskId = taskId;
-   initialQueuePosition = null;
-   if (queuePollTimer) {
-     clearInterval(queuePollTimer);
-     queuePollTimer = null;
-   }
-   const queueStartTs = Date.now();
-
-   const pollOnce = async () => {
-     try {
-       const resp = await fetch(`${guiOptions.BackendAddress}/task/${taskId}`);
-       if (!resp.ok) return;
-       const info = await resp.json();
-       if (!info || !info.success) return;
-
-       const pos = info.queue && typeof info.queue.position === 'number' ? info.queue.position : 0;
-       const running = info.queue ? info.queue.running_count : 0;
-       const queued = info.queue ? info.queue.queued_count : 0;
-       if (info.status === 'queued' || info.status === 'running') {
-         // Only progress bar; set stage label
-         if (info.status === 'queued') {
-           showQueueWaiting(pos, running, queued);
-         } else {
-           // Transitioned to running: finalize queue progress visually
-           updateProgressBar(100);
-           showDownloadProgress();
-           setProgressLabel('Generating...');
-         }
-       }
-
-       if (info.status === 'completed' && info.download_url) {
-         clearInterval(queuePollTimer);
-         queuePollTimer = null;
-         latestGenerationTime = typeof info.generation_time === 'number' ? info.generation_time : null;
-         // Proceed to download the generated file like the normal path
-         updateStatus('Downloading generated scene...', cameraParams.length);
-         const response = await fetch(guiOptions.BackendAddress + info.download_url);
-         if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
-         const contentLength = response.headers.get('content-length');
-         const total = parseInt(contentLength || '0', 10);
-         // Show generation info immediately once we know it and total size from headers
-         showGenerationInfo(latestGenerationTime || 0, total);
-         let loaded = 0;
-         const reader = response.body.getReader();
-         const chunks = [];
-         updateProgressBar(0);
-         setProgressLabel('Downloading 0%');
-         lastDownloadPct = 0;
-         lastDownloadUpdateTs = 0;
-         while (true) {
-           const { done, value } = await reader.read();
-           if (done) break;
-           chunks.push(value);
-           loaded += value.length;
-           if (total) {
-             const pct = Math.min(100, (loaded / total) * 100);
-             const now = Date.now();
-             const rounded = Math.round(pct);
-             // Throttle and enforce monotonic increase
-             if (rounded > Math.round(lastDownloadPct) || (now - lastDownloadUpdateTs) > 200) {
-               lastDownloadPct = Math.max(lastDownloadPct, pct);
-               updateProgressBar(lastDownloadPct);
-               setProgressLabel(`Downloading ${Math.round(lastDownloadPct)}%`);
-               lastDownloadUpdateTs = now;
-             }
-           }
-         }
-
-         if (instructionSplat) {
-           scene.remove(instructionSplat);
-           console.log('Instruction splat removed');
-           instructionSplat = null;
-         }
-
-         const blob = new Blob(chunks);
-         const url = URL.createObjectURL(blob);
-         // Continue to load the splat
-         updateStatus('Loading generated scene...', cameraParams.length);
-
-         const GeneratedSplat = new SplatMesh({ url });
-         scene.add(GeneratedSplat);
-         currentGeneratedSplat = GeneratedSplat;
-         updateStatus('Scene generated successfully!', cameraParams.length);
-         // Show generation time and total file size (MB)
-         showGenerationInfo(latestGenerationTime || 0, total || blob.size);
-         // Notify backend to delete the server file after client has downloaded it
-         try {
-           if (info.file_id) {
-             const resp = await fetch(`${guiOptions.BackendAddress}/delete/${info.file_id}`, { method: 'POST' });
-             if (!resp.ok) console.warn('Delete notify failed');
-           }
-         } catch (e) {
-           console.warn('Delete notify error', e);
-         }
-         hideDownloadProgress();
-         showLoading(false);
-       } else if (info.status === 'failed') {
-         clearInterval(queuePollTimer);
-         queuePollTimer = null;
-         throw new Error(info.error || 'Generation failed');
-       }
-     } catch (e) {
-       console.debug('Polling error:', e);
-     }
-   };
-
-   await pollOnce();
-   queuePollTimer = setInterval(pollOnce, 2000);
- }
+ // Gradio handles concurrency automatically, no need for queue polling

// Hide download progress
function hideDownloadProgress() {

@@ -885,7 +741,7 @@

// GUI Options - declare early
const guiOptions = {
-   // Backend address; defaults to this page's IP
+   // Gradio backend address; defaults to this page's IP on port 7860
  BackendAddress: `${window.location.protocol}//${window.location.hostname}:7860`,
  FOV: 60,
  LoadFromJson: () => {

@@ -1057,82 +913,208 @@
console.log('Interpolated cameras:', interpolatedCameras.length);
updateStatus('Sending request to backend...', cameraParams.length);

- // Choose the request path according to the backend type
- let requestUrl, requestBody;
-
- if (true) {
-   // Flask backend: POST directly to /generate
-   requestUrl = guiOptions.BackendAddress + '/generate';
-   requestBody = JSON.stringify({
-     image_prompt: inputImageBase64 ? inputImageBase64 : "",
-     text_prompt: guiOptions.inputTextPrompt,
-     image_index: 0,
-     resolution: [
-       parseInt(guiOptions.Resolution.split('x')[0]),
-       parseInt(guiOptions.Resolution.split('x')[1]),
-       parseInt(guiOptions.Resolution.split('x')[2])
-     ],
-     cameras: interpolatedCameras.map(cam => ({
-       position: [cam.position.x, cam.position.y, cam.position.z],
-       quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z],
-       fx: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
-       fy: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
-       cx: inputImageBase64 && inputImageResolution
-         ? 0.5 * inputImageResolution.width
-         : 0.5 * parseInt(guiOptions.Resolution.split('x')[2]),
-       cy: inputImageBase64 && inputImageResolution
-         ? 0.5 * inputImageResolution.height
-         : 0.5 * parseInt(guiOptions.Resolution.split('x')[1]),
-     }))
-   });
- } else {
-
- }
+ // Gradio backend: use the Gradio API
+ const requestData = {
+   image_prompt: inputImageBase64 ? inputImageBase64 : "",
+   text_prompt: guiOptions.inputTextPrompt,
+   image_index: 0,
+   resolution: [
+     parseInt(guiOptions.Resolution.split('x')[0]),
+     parseInt(guiOptions.Resolution.split('x')[1]),
+     parseInt(guiOptions.Resolution.split('x')[2])
+   ],
+   cameras: interpolatedCameras.map(cam => ({
+     position: [cam.position.x, cam.position.y, cam.position.z],
+     quaternion: [cam.quaternion.w, cam.quaternion.x, cam.quaternion.y, cam.quaternion.z],
+     fx: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
+     fy: 0.5 / Math.tan(0.5 * cam.fov * Math.PI / 180) * parseInt(guiOptions.Resolution.split('x')[1]),
+     cx: inputImageBase64 && inputImageResolution
+       ? 0.5 * inputImageResolution.width
+       : 0.5 * parseInt(guiOptions.Resolution.split('x')[2]),
+     cy: inputImageBase64 && inputImageResolution
+       ? 0.5 * inputImageResolution.height
+       : 0.5 * parseInt(guiOptions.Resolution.split('x')[1]),
+   }))
+ };

- // Ask the backend to generate (async: returns a task_id, then poll the queue)
- fetch(requestUrl, {
+ // Ask the Gradio backend to generate
+ fetch(guiOptions.BackendAddress + '/gradio_api/call/gradio_generate', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  mode: 'cors',
-   body: requestBody
- })
- .then(response => {
-   const contentType = response.headers.get('content-type');
-   if (contentType && contentType.includes('application/json')) {
-     return response.json();
-   } else {
-     return response.blob().then(blob => {
-       const url = URL.createObjectURL(blob);
-       return { url };
-     });
-   }
+   body: JSON.stringify({
+     data: [JSON.stringify(requestData)]
+   })
})
+ .then(response => response.json())
.then(data => {
-   console.log(data);
-   {
-     // Async queue protocol: the backend returns task_id + queue info (202)
-     if (data && data.success && data.task_id) {
-       updateStatus('Queued request submitted. Waiting in queue...', cameraParams.length);
-       showQueueWaiting(data.queue?.position || 0, data.queue?.running_count || 0, data.queue?.queued_count || 0);
-       // Poll until the task finishes, then download
-       return pollTaskUntilReady(data.task_id).then(() => ({ url: null }));
-     }
-     // Backward compatibility with the old direct-file response format
-     if (data && data.url) {
-       updateStatus('Loading generated scene...', cameraParams.length);
-       return Promise.resolve(data);
-     }
-     throw new Error('Invalid Flask response (expected task_id)');
+   console.log('Gradio response:', data);
+
+   // Gradio always returns an event_id; use it to fetch the generation result
+   if (data.event_id) {
+     console.log('Got EVENT_ID from generation call:', data.event_id);
+
+     // Fetch the generation result via the EVENT_ID (SSE format)
+     return fetch(guiOptions.BackendAddress + `/gradio_api/call/gradio_generate/${data.event_id}`)
+       .then(response => {
+         if (!response.ok) {
+           throw new Error(`HTTP error! status: ${response.status}`);
+         }
+         return response.text();
+       })
+       .then(sseText => {
+         console.log('SSE response:', sseText);
+
+         // Parse the SSE-formatted response
+         const lines = sseText.split('\n');
+         let eventType = null;
+         let dataContent = null;
+
+         for (const line of lines) {
+           if (line.startsWith('event: ')) {
+             eventType = line.substring(7);
+           } else if (line.startsWith('data: ')) {
+             dataContent = line.substring(6);
+           }
+         }
+
+         console.log('Event type:', eventType, 'Data:', dataContent);
+
+         if (eventType === 'complete' && dataContent) {
+           // Parse the JSON payload
+           const resultData = JSON.parse(dataContent);
+           console.log('Generation result:', resultData);
+
+           // Unpack the generation result
+           if (resultData && resultData.length > 0) {
+             const responseData = JSON.parse(resultData[0]);
+             console.log('Gradio generation successful:', responseData);
+
+             if (responseData.success && responseData.download_url) {
+               console.log('Generation time:', responseData.generation_time, 'seconds');
+               console.log('File size:', responseData.file_size, 'bytes');
+
+               // Show generation info
+               showGenerationInfo(responseData.generation_time, responseData.file_size);
+               showDownloadProgress();
+               updateStatus('Downloading generated scene...', cameraParams.length);
+
+               // Downloading also takes two steps: get a download EVENT_ID, then fetch the file
+               return fetch(guiOptions.BackendAddress + '/gradio_api/call/download_file', {
+                 method: 'POST',
+                 headers: { 'Content-Type': 'application/json' },
+                 body: JSON.stringify({
+                   data: [responseData.file_id]
+                 })
+               })
+               .then(response => response.json())
+               .then(downloadEventData => {
+                 console.log('Download EVENT_ID:', downloadEventData.event_id);
+
+                 // Fetch the file info via the download EVENT_ID (SSE format)
+                 return fetch(guiOptions.BackendAddress + `/gradio_api/call/download_file/${downloadEventData.event_id}`)
+                   .then(response => {
+                     if (!response.ok) {
+                       throw new Error(`HTTP error! status: ${response.status}`);
+                     }
+                     return response.text();
+                   })
+                   .then(sseText => {
+                     console.log('Download SSE response:', sseText);
+
+                     // Parse the SSE-formatted response
+                     const lines = sseText.split('\n');
+                     let eventType = null;
+                     let dataContent = null;
+
+                     for (const line of lines) {
+                       if (line.startsWith('event: ')) {
+                         eventType = line.substring(7);
+                       } else if (line.startsWith('data: ')) {
+                         dataContent = line.substring(6);
+                       }
+                     }
+
+                     console.log('Download event type:', eventType, 'Data:', dataContent);
+
+                     if (eventType === 'complete' && dataContent) {
+                       // Parse the file info
+                       const fileData = JSON.parse(dataContent);
+                       console.log('File data:', fileData);
+
+                       if (fileData && fileData.length > 0 && fileData[0].url) {
+                         const fileUrl = fileData[0].url;
+                         console.log('File URL:', fileUrl);
+
+                         // Download the actual file from the returned URL
+                         return fetch(fileUrl)
+                           .then(response => {
+                             if (!response.ok) {
+                               throw new Error(`HTTP error! status: ${response.status}`);
+                             }
+
+                             const contentLength = response.headers.get('content-length');
+                             const total = parseInt(contentLength, 10);
+                             let loaded = 0;
+
+                             const reader = response.body.getReader();
+                             const chunks = [];
+
+                             function pump() {
+                               return reader.read().then(({ done, value }) => {
+                                 if (done) {
+                                   return new Blob(chunks);
+                                 }
+
+                                 chunks.push(value);
+                                 loaded += value.length;
+
+                                 if (total) {
+                                   const percentage = (loaded / total) * 100;
+                                   updateProgressBar(percentage);
+                                 }
+
+                                 return pump();
+                               });
+                             }
+
+                             return pump().then(blob => {
+                               const url = URL.createObjectURL(blob);
+                               return { url };
+                             });
+                           });
+                       } else {
+                         throw new Error('Invalid file data format from Gradio');
+                       }
+                     } else {
+                       throw new Error('Gradio download SSE response not complete or missing data');
+                     }
+                   });
+               });
+             } else {
+               throw new Error('Gradio generation failed: ' + (responseData.error || 'Unknown error'));
+             }
+           } else {
+             throw new Error('Invalid Gradio generation result format');
+           }
+         } else {
+           throw new Error('Gradio SSE response not complete or missing data');
+         }
+       });
+   } else {
+     throw new Error('Invalid Gradio response format - no event_id');
  }
})
.then(data => {
  if (data.url) {
    updateStatus('Loading 3D scene...', cameraParams.length);
+
    // Remove the instruction splat when generation is complete
    if (instructionSplat) {
      scene.remove(instructionSplat);
      console.log('Instruction splat removed');
    }
+
    const GeneratedSplat = new SplatMesh({ url: data.url });
    scene.add(GeneratedSplat);
    currentGeneratedSplat = GeneratedSplat; // keep a reference to the newly generated scene

@@ -1517,7 +1499,7 @@

// Step 1: Configure Generation Settings
const step1Folder = gui.addFolder('1. Configure Settings');
- step1Folder.add(guiOptions, "BackendAddress").name("Backend Address");
+ step1Folder.add(guiOptions, "BackendAddress").name("Gradio Backend Address");

// FOV and Resolution controllers; enabled initially
const fovController = step1Folder.add(guiOptions, "FOV", 0, 120, 1).name("FOV").onChange((value) => {
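
The PLY download in the new `index.html` follows the same two-step pattern against `download_file`, with the SSE payload serializing the `gr.File` output as a dict carrying a `url` field. Continuing the client sketch above (again hypothetical, same assumptions, reusing `BASE`, `requests`, `json`, and the `result` dict from the generation call):

```python
# Continuation of the hypothetical client sketch above: fetch the generated
# PLY through the download_file endpoint, mirroring the index.html flow.
resp = requests.post(
    f"{BASE}/gradio_api/call/download_file",
    json={"data": [result["file_id"]]},
)
event_id = resp.json()["event_id"]

# The SSE `data:` payload is a one-element list; gr.File outputs serialize
# to a dict whose `url` field points at the actual file.
sse = requests.get(f"{BASE}/gradio_api/call/download_file/{event_id}")
data_line = [l for l in sse.text.splitlines() if l.startswith("data: ")][-1]
file_info = json.loads(data_line[len("data: "):])

with open("scene.ply", "wb") as f:
    f.write(requests.get(file_info[0]["url"]).content)
```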