Upload sampling.py with huggingface_hub
sampling.py  +54 -16

sampling.py CHANGED
@@ -23,6 +23,7 @@ parser.add_argument("--out_path", type=str, required=True)
 parser.add_argument("--num_samples", type=int, required=False, default=100000)
 parser.add_argument("--max_new_tokens", type=int, required=True, help="number of tokens generated in each sample")
 parser.add_argument("--strategy",type=str, required=False,default='top_k',help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']")
+parser.add_argument("--beam_size",type=int, required=False,default=3,help="beam size for beam search")
 parser.add_argument("--temperature",type=float, required=False,default=1.0,help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions")
 parser.add_argument("--top_k",type=int, required=False,default=20,help="retain only the top_k most likely tokens, clamp others to have 0 probability")
 parser.add_argument("--ckpt_path",type=str, required=True,help="path to a checkpoint/model")
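Review note: --beam_size is only consumed by model.generate when strategy is 'beam_search'. generate itself is not part of this diff, so the following is only a rough sketch of what a beam of size beam_size means for a next-token model; it assumes model(ids) returns logits of shape (batch, seq, vocab), which may not match the real signature.

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def beam_search(model, x, max_new_tokens, beam_size=3):
        # x: (1, t) prompt ids; track the beam_size highest cumulative log-prob sequences
        beams = [(x, 0.0)]
        for _ in range(max_new_tokens):
            candidates = []
            for seq, score in beams:
                logp = F.log_softmax(model(seq)[:, -1, :], dim=-1)  # (1, vocab)
                topv, topi = torch.topk(logp, beam_size)
                for v, i in zip(topv[0], topi[0]):
                    candidates.append((torch.cat([seq, i.view(1, 1)], dim=1), score + v.item()))
            # keep only the beam_size best continuations
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        return beams[0][0]  # best sequence, shape (1, t + max_new_tokens)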
@@ -30,6 +31,7 @@ parser.add_argument("--tokenizer_path",type=str, required=True,help="path to a t
 parser.add_argument("--start",type=str, required=False,default="<|endoftext|>")
 parser.add_argument("--repetition_penalty",type=float, required=False,default=1.0)
 parser.add_argument("--shuffle_token", action='store_true', help="Enable shuffling of tokens before decoding")
+parser.add_argument("--fasta", action='store_true', default=True, help="Enable writing output in FASTA format")
 
 args = parser.parse_args()
 init_from = args.init_from
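Review note: action='store_true' already implies default=False; overriding the default to True makes the --fasta flag a no-op, so FASTA output can never be switched off from the command line. A minimal sketch of an actual on/off switch, assuming Python 3.9+ for argparse.BooleanOptionalAction:

    import argparse

    parser = argparse.ArgumentParser()
    # produces paired --fasta / --no-fasta flags, defaulting to on
    parser.add_argument("--fasta", action=argparse.BooleanOptionalAction, default=True,
                        help="write output in FASTA format (disable with --no-fasta)")
    args = parser.parse_args()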
@@ -37,17 +39,20 @@ out_path = args.out_path
 num_samples = args.num_samples
 max_new_tokens = args.max_new_tokens
 strategy = args.strategy
+assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
+beam_size = args.beam_size
 temperature = args.temperature
 top_k = args.top_k
 ckpt_path = args.ckpt_path
 tokenizer_path = args.tokenizer_path
 start = args.start
 repetition_penalty = args.repetition_penalty
+fasta = args.fasta
+
 
 # -----------------------------------------------------------------------------
 seed = random.randint(1,6666)
-
-device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
+device = 'cuda'
 dtype = 'float32'
 # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = False # use PyTorch 2.0 to compile the model to be faster
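Review note: the help strings above fix the intended sampling semantics: logits are scaled by 1/temperature, and all tokens outside the top_k most likely are clamped to zero probability before sampling. A minimal sketch of one decoding step under those semantics follows; model.generate is not part of this diff, and the repetition-penalty formulation shown is the common CTRL-style one, assumed here rather than taken from the source.

    import torch
    import torch.nn.functional as F

    def sample_next_token(logits, generated, temperature=1.0, top_k=20, repetition_penalty=1.0):
        # logits: (1, vocab) scores for the next token; generated: token ids emitted so far
        logits = logits / temperature
        # CTRL-style penalty: values > 1.0 discourage tokens that already appeared
        for t in set(generated):
            logits[0, t] = logits[0, t] / repetition_penalty if logits[0, t] > 0 else logits[0, t] * repetition_penalty
        # retain only the top_k most likely tokens, clamp the rest to 0 probability
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('inf')
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)  # (1, 1)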
@@ -91,20 +96,53 @@ load_meta = False
 encode = tokenizer.encode
 decode = tokenizer.decode
 
-
-
+fasta_out_path = os.path.splitext(out_path)[0] + ".fasta" if fasta else None
+
+if strategy in["sampling", "top_k"]:
+    start_ids = encode("".join(start))
+    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+
+    with open(out_path, 'a') as f:
+        with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
+            with torch.no_grad():
+                with ctx:
+                    for k in tqdm(range(num_samples), desc="Generating samples"):
+                        token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
+
+                        # Shuffle tokens if --shuffle_token is specified
+                        if args.shuffle_token:
+                            random.shuffle(token_sequence)
+
+                        y = decode(token_sequence).replace(' ', '')
+                        # y = decode(token_sequence).replace('\n', '').replace(' ', '') + '\n'
+                        f.write(y)
+                        f.flush()
+
+
+                        if fasta:
+                            fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
+                            fasta_f.write(fasta_entry.strip() + '\n')
+                            fasta_f.flush()
+
+
+elif strategy in ["beam_search", "greedy_search"]:
+    with open(out_path, 'a') as f:
+        with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f:
+            with torch.no_grad():
+                with ctx:
+                    start = '<|endoftext|>'
+                    start_ids = encode(start)
+                    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+
+                    token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, beam_size=beam_size)[0].tolist()
 
+                    y = decode(token_sequence).replace(' ', '')
+                    f.write(y)
+                    f.flush()
 
-with open(out_path, 'a') as f:
-    with torch.no_grad():
-        with ctx:
-            for k in tqdm(range(num_samples), desc="Generating samples"):
-                token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty)[0].tolist()
-
-                # Shuffle tokens if --shuffle_token is specified
-                if args.shuffle_token:
-                    random.shuffle(token_sequence)
 
-
-
-
+                    if fasta:
+                        fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
+                        fasta_f.write(fasta_entry.strip() + '\n')
+                        fasta_f.flush()
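Review note: the optional FASTA handle relies on contextlib.nullcontext, and no matching import appears in this diff, so it presumably already exists earlier in sampling.py (otherwise both new branches raise NameError at the with statement). The pattern itself, as a self-contained sketch with made-up file names:

    from contextlib import nullcontext

    fasta = False
    with open("samples.txt", "a") as f, \
         (open("samples.fasta", "a") if fasta else nullcontext()) as fasta_f:
        f.write("ACDEFGHIK\n")
        if fasta:  # fasta_f is bound to None when the nullcontext branch is taken
            fasta_f.write(">sample_0\nACDEFGHIK\n")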
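Review note: with fasta enabled, every sample is also written as a two-line FASTA record, a ">sample_{k}" header followed by the sequence; for k = 0 and an illustrative sequence it looks like:

    >sample_0
    MKTAYIAKQR

Two issues in the beam_search/greedy_search branch are worth flagging: fasta_entry interpolates k, which is only bound by the for-loop in the sampling/top_k branch, so the FASTA write raises NameError there unless k is defined elsewhere in the file; and the hardcoded start = '<|endoftext|>' silently overrides the --start argument for those two strategies.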