diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 2aa34f878..971220b4e 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -403,6 +403,12 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai "{% if not has_system %}" "{{ '<|im_start|>system\nYou are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n' }}" "{% endif %}" + "{% set last_user_index = -1 %}" + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "{% set last_user_index = loop.index0 %}" + "{% endif %}" + "{% endfor %}" "{% for message in messages %}" "{% if message['role'] == 'system' %}" "{{ '<|im_start|>system\n' + message['content'] }}" @@ -418,10 +424,18 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" "{% endif %}" "{% elif message['role'] == 'assistant' %}" + "{% set assistant_content = message.get('content', '') %}" + "{% set reasoning_content = '' %}" + "{% if '' in assistant_content %}" + "{% set think_split = assistant_content.split('') %}" + "{% set reasoning_content = think_split[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}" + "{% set assistant_content = think_split[-1].lstrip('\\n') %}" + "{% endif %}" "{{ '<|im_start|>assistant\n' }}" - "{% if message.get('content', none) is not none %}" - "{{ message['content'] }}" + "{% if loop.index0 > last_user_index and reasoning_content.strip() %}" + "{{ '\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' }}" "{% endif %}" + "{{ assistant_content }}" "{% if message.get('function_calls', none) is not none %}" "{{ '' + message['function_calls'] + '' }}" "{% endif %}"