zhb10086 committed on
Commit
501f4d0
·
verified ·
1 Parent(s): 99e26f0

Update preprocessing_molmo.py

Browse files
Files changed (1) hide show
  1. preprocessing_molmo.py +10 -16
preprocessing_molmo.py CHANGED
@@ -97,7 +97,7 @@ class MolmoProcessor(ProcessorMixin):
97
  self._special_tokens = get_special_token_ids(self.tokenizer)
98
  return self._special_tokens
99
 
100
- def get_tokens_input(self, prompt, message_format, always_start_with_space):
101
  if message_format == "none" or message_format is None:
102
  pass
103
  elif message_format == "role":
@@ -107,21 +107,9 @@ class MolmoProcessor(ProcessorMixin):
107
 
108
  if always_start_with_space:
109
  prompt = " " + prompt
110
-
111
- tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
112
-
113
- return tokens
114
-
115
- def get_tokens_input_for_logits(self, prompt, pred, message_format, always_start_with_space):
116
- if message_format == "none" or message_format is None:
117
- pass
118
- elif message_format == "role":
119
- prompt = "User: " + prompt + " Assistant: " + pred
120
- else:
121
- raise NotImplementedError(f"Message format {message_format} not implemented")
122
-
123
- if always_start_with_space:
124
- prompt = " " + prompt
125
 
126
  tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
127
 
@@ -131,8 +119,10 @@ class MolmoProcessor(ProcessorMixin):
131
  self,
132
  text: TextInput = None,
133
  images: ImageInput = None,
 
134
  *,
135
  tokens: Optional[PreTokenizedInput] = None,
 
136
  **kwargs: Unpack[MolmoProcessorKwargs],
137
  ):
138
  output_kwargs = self._merge_kwargs(
@@ -146,8 +136,12 @@ class MolmoProcessor(ProcessorMixin):
146
  text,
147
  output_kwargs["text_kwargs"]["message_format"],
148
  output_kwargs["text_kwargs"]["always_start_with_space"],
 
149
  )
150
 
 
 
 
151
  image_token_id = self.special_token_ids[IMAGE_PROMPT]
152
 
153
  if images is not None:
 
97
  self._special_tokens = get_special_token_ids(self.tokenizer)
98
  return self._special_tokens
99
 
100
+ def get_tokens_input(self, prompt, message_format, always_start_with_space, out_text=None):
101
  if message_format == "none" or message_format is None:
102
  pass
103
  elif message_format == "role":
 
107
 
108
  if always_start_with_space:
109
  prompt = " " + prompt
110
+
111
+ if out_text is not None:
112
+ prompt = " ".join([prompt, out_text])
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
115
 
 
119
  self,
120
  text: TextInput = None,
121
  images: ImageInput = None,
122
+ out_text: TextInput = None,
123
  *,
124
  tokens: Optional[PreTokenizedInput] = None,
125
+ out_tokens: Optional[PreTokenizedInput] = None,
126
  **kwargs: Unpack[MolmoProcessorKwargs],
127
  ):
128
  output_kwargs = self._merge_kwargs(
 
136
  text,
137
  output_kwargs["text_kwargs"]["message_format"],
138
  output_kwargs["text_kwargs"]["always_start_with_space"],
139
+ out_text
140
  )
141
 
142
+ if out_tokens is not None:
143
+ tokens = torch.cat([tokens, out_tokens], dim=0).tolist()
144
+
145
  image_token_id = self.special_token_ids[IMAGE_PROMPT]
146
 
147
  if images is not None: