Mitchins commited on
Commit
6332194
·
verified ·
1 Parent(s): d8645be

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. demo.py +86 -0
demo.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ T5 Prompt Enhancer V0.3 Demo Script
4
+ Quick test of all four instruction types
5
+ """
6
+
7
+ import torch
8
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
9
+
10
+ def load_model():
11
+ """Load the T5 V0.3 model"""
12
+ print("🤖 Loading T5 Prompt Enhancer V0.3...")
13
+
14
+ tokenizer = T5Tokenizer.from_pretrained(".")
15
+ model = T5ForConditionalGeneration.from_pretrained(".")
16
+
17
+ if torch.cuda.is_available():
18
+ model = model.cuda()
19
+ print("✅ Model loaded on GPU")
20
+ else:
21
+ print("✅ Model loaded on CPU")
22
+
23
+ return model, tokenizer
24
+
25
+ def enhance_prompt(model, tokenizer, text, style="clean"):
26
+ """Generate enhanced prompt with style control"""
27
+
28
+ style_prompts = {
29
+ "clean": f"Enhance this prompt (no lora): {text}",
30
+ "technical": f"Enhance this prompt (with lora): {text}",
31
+ "simplify": f"Simplify this prompt: {text}",
32
+ "standard": f"Enhance this prompt: {text}"
33
+ }
34
+
35
+ prompt = style_prompts[style]
36
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)
37
+
38
+ if torch.cuda.is_available():
39
+ inputs = {k: v.cuda() for k, v in inputs.items()}
40
+
41
+ with torch.no_grad():
42
+ outputs = model.generate(
43
+ **inputs,
44
+ max_length=80,
45
+ num_beams=2,
46
+ repetition_penalty=2.0,
47
+ no_repeat_ngram_size=3,
48
+ pad_token_id=tokenizer.pad_token_id
49
+ )
50
+
51
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+
53
+ def main():
54
+ """Demo all four instruction types"""
55
+
56
+ # Load model
57
+ model, tokenizer = load_model()
58
+
59
+ # Test prompts
60
+ test_prompts = [
61
+ "woman in red dress",
62
+ "cat on chair",
63
+ "cyberpunk cityscape",
64
+ "masterpiece, best quality, ultra-detailed render of a fantasy dragon with golden scales"
65
+ ]
66
+
67
+ styles = ["standard", "clean", "technical", "simplify"]
68
+
69
+ print("\n🎨 T5 Prompt Enhancer V0.3 Demo")
70
+ print("="*60)
71
+
72
+ for prompt in test_prompts:
73
+ print(f"\n📝 Input: '{prompt}'")
74
+ print("-" * 40)
75
+
76
+ for style in styles:
77
+ try:
78
+ result = enhance_prompt(model, tokenizer, prompt, style)
79
+ print(f"{style:>10}: {result}")
80
+ except Exception as e:
81
+ print(f"{style:>10}: ERROR - {e}")
82
+
83
+ print()
84
+
85
+ if __name__ == "__main__":
86
+ main()