# Source code for lmflow.utils.conversation_template.hymba

#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
from typing import Optional

from .base import ConversationTemplateForTool, StringFormatter, TemplateComponent

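# NOTE: 'contexts' are not used in SFT.
# Reference Jinja chat template for Hymba; the formatters in HYMBA_TEMPLATE
# below follow the same role-header layout: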
# {{'<extra_id_0>System'}}
# {% for message in messages %}
#   {% if message['role'] == 'system' %}
#       {{'\n' + message['content'].strip()}}
#       {% if tools %}
#           {{'\n'}}
#       {% endif %}
#   {% endif %}
# {% endfor %}
# {% if tools %}
#   {% for tool in tools %}
#       {{ '\n<tool> ' + tool|tojson + ' </tool>' }}
#   {% endfor %}
# {% endif %}
# {{'\n\n'}}
# {% for message in messages %}
#   {% if message['role'] == 'user' %}
#       {{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}
#   {% elif message['role'] == 'assistant' %}
#       {{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}
#   {% elif message['role'] == 'tool' %}
#       {{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}
#   {% endif %}
# {% endfor %}
# {%- if add_generation_prompt %}
#   {{'<extra_id_1>Assistant\n'}}
# {%- endif %}


class HymbaConversationTemplate(ConversationTemplateForTool):
    """Conversation template for Hymba models with tool-use support.

    Only the rendering of the tool list is customized here; all other
    behavior is inherited from ``ConversationTemplateForTool``.
    """

    def _handle_tools(self, tools: Optional[list[str]]) -> str:
        """Render tool descriptions as ``<tool> ... </tool>`` lines.

        Each tool becomes one line of the form ``<tool> {tool} </tool>``,
        preceded by a newline, so the result can be appended directly after
        the ``<extra_id_0>System`` header.

        Unlike the reference Jinja template at the top of this module (which
        applies ``tojson`` to each tool), tools are inserted verbatim here, so
        callers presumably pass already-serialized (e.g. JSON) strings.

        Args:
            tools: Tool descriptions as strings. ``None`` or an empty list
                means there are no tools.

        Returns:
            The concatenated tool lines, or ``""`` when there are no tools.

        Raises:
            TypeError: If an element of ``tools`` is not a ``str``.
        """
        if not tools:
            return ""
        # Plain ``+`` (not an f-string) keeps non-str tools an error instead of
        # silently inserting their Python repr.
        return "".join("\n<tool> " + tool + " </tool>" for tool in tools)
# Conversation template instance for Hymba models, registered under the name
# "hymba". Role headers use the <extra_id_1>{Role}\n ... \n layout shown in
# the reference Jinja template at the top of this module.
HYMBA_TEMPLATE = HymbaConversationTemplate(
    template_name="hymba",
    user_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<extra_id_1>User\n{{content}}\n")]
    ),
    assistant_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<extra_id_1>Assistant\n{{content}}\n")]
    ),
    # Function calls are rendered with the same Assistant header as plain replies.
    function_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<extra_id_1>Assistant\n{{content}}\n")]
    ),
    # Tool results ("observations") are rendered under the Tool role.
    observation_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<extra_id_1>Tool\n{{content}}\n")]
    ),
    # No newline after "System": each tool line from _handle_tools starts with
    # "\n" on its own.
    # NOTE(review): the reference Jinja also emits "\n" before the system text
    # and an extra "\n" between system text and tools. Confirm that the base
    # class supplies these when it builds the system content.
    system_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<extra_id_0>System{{content}}\n\n")]
    ),
    # NOTE(review): token id 13 is presumably the newline token in Hymba's
    # tokenizer. Confirm against the tokenizer vocabulary.
    separator=TemplateComponent(type="token_id", content=13),
    # Presumably drops the separator after the final turn. TODO confirm in the
    # base class.
    remove_last_sep=True,
    special_stopper=TemplateComponent(type="token", content="eos_token"),
    # The reference Jinja always emits the "<extra_id_0>System" header, even
    # without a system message. This flag presumably enforces the same
    # behavior (confirm in the base class).
    force_system=True,
)