Spaces:
Runtime error
Runtime error
Commit
·
8c28418
1
Parent(s):
2833bac
Update app.py
Browse files
app.py
CHANGED
@@ -39,16 +39,16 @@ def generate(
|
|
39 |
ctx,
|
40 |
image_features,
|
41 |
token_count=200,
|
42 |
-
temperature=0
|
43 |
top_p=0.3,
|
44 |
presencePenalty = 0.1,
|
45 |
countPenalty = 0.1,
|
46 |
):
|
47 |
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
ctx = ctx.strip()
|
53 |
all_tokens = []
|
54 |
out_last = 0
|
@@ -56,9 +56,11 @@ def generate(
|
|
56 |
occurrence = {}
|
57 |
for i in range(int(token_count)):
|
58 |
if i == 0:
|
|
|
|
|
59 |
input_ids = pipeline.encode(ctx)
|
60 |
text_embs = model.w['emb.weight'][input_ids]
|
61 |
-
input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
|
62 |
out, state = model.forward(embs=input_embs, state=None)
|
63 |
else:
|
64 |
input_ids = [token]
|
|
|
39 |
ctx,
|
40 |
image_features,
|
41 |
token_count=200,
|
42 |
+
temperature=1.0,
|
43 |
top_p=0.3,
|
44 |
presencePenalty = 0.1,
|
45 |
countPenalty = 0.1,
|
46 |
):
|
47 |
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
|
48 |
+
alpha_frequency = countPenalty,
|
49 |
+
alpha_presence = presencePenalty,
|
50 |
+
token_ban = [], # ban the generation of some tokens
|
51 |
+
token_stop = [0, 261]) # stop generation whenever you see any token here
|
52 |
ctx = ctx.strip()
|
53 |
all_tokens = []
|
54 |
out_last = 0
|
|
|
56 |
occurrence = {}
|
57 |
for i in range(int(token_count)):
|
58 |
if i == 0:
|
59 |
+
prefix_ids = pipeline.encode("User: ")
|
60 |
+
prefix_embs = model.w['emb.weight'][prefix_ids]
|
61 |
input_ids = pipeline.encode(ctx)
|
62 |
text_embs = model.w['emb.weight'][input_ids]
|
63 |
+
input_embs = torch.cat((prefix_embs, image_features, text_embs), dim=0)[-ctx_limit:]
|
64 |
out, state = model.forward(embs=input_embs, state=None)
|
65 |
else:
|
66 |
input_ids = [token]
|