Source code for lmflow.utils.conversation_template.llama

#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
from collections.abc import Sequence
from typing import Optional

from transformers import PreTrainedTokenizer

from lmflow.utils.constants import CONVERSATION_ROLE_NAMES

from .base import ConversationTemplate, ConversationTemplateForTool, StringFormatter, TemplateComponent

logger = logging.getLogger(__name__)

class Llama2ConversationTemplate(ConversationTemplate):
    def _encode(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **kwargs,
    ) -> Sequence[tuple[list[int], list[int]]]:
        if tools:
            logger.warning(
                "Formatted tools are not supported in Llama2, thus tools will be ignored. "
                "If this is intended, please include tools in the system message manually."
            )

        res_all = []

        system_formatted = self.system_formatter.format(content=system) if system else []
        system_formatted_text = "".join(
            [component.content for component in system_formatted if component.type == "string"]
        )  # HACK: keep only string components of the formatted system message

        # Messages are assumed to alternate strictly as (user, assistant) pairs.
        for i in range(0, len(messages), 2):
            user_message = messages[i]
            assistant_message = messages[i + 1]

            # The formatted system text is prepended to the first user turn only.
            user_content = system_formatted_text + user_message["content"] if i == 0 else user_message["content"]
            user_formatted = self.user_formatter.format(content=user_content)
            assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])

            user_encoded = self._encode_template(user_formatted, tokenizer)
            assistant_encoded = self._encode_template(assistant_formatted, tokenizer)

            res_all.append((user_encoded, assistant_encoded))

        return res_all

class Llama2ConversationTemplateForTool(Llama2ConversationTemplate):
    def _encode(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **kwargs,
    ) -> Sequence[tuple[list[int], ...]]:
        if tools:
            # Unlike the base Llama2 template, formatted tools are not dropped here;
            # they are appended to the system message instead.
            system = system + tools if system else tools

        res_all = []

        system_formatted = self.system_formatter.format(content=system) if system else []
        system_formatted_text = "".join(
            [component.content for component in system_formatted if component.type == "string"]
        )  # HACK: keep only string components of the formatted system message

        # Each tuple in res_all covers the turns up to and including an assistant reply.
        ls_for_save = []
        for i in range(len(messages)):
            if messages[i]["role"] == CONVERSATION_ROLE_NAMES["user"]:
                user_message = messages[i]
                if i == 0:
                    # The formatted system text is prepended to the first user turn only.
                    user_content = system_formatted_text + user_message["content"]
                else:
                    user_content = user_message["content"]
                user_formatted = self.user_formatter.format(content=user_content)
                user_encoded = self._encode_template(user_formatted, tokenizer)
                ls_for_save.append(user_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["function"]:
                # Function calls emitted by the model are encoded with the assistant formatter.
                function_message = messages[i]
                function_formatted = self.assistant_formatter.format(content=function_message["content"])
                function_encoded = self._encode_template(function_formatted, tokenizer)
                ls_for_save.append(function_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["observation"]:
                # Tool observations returned to the model are encoded with the user formatter.
                observation_message = messages[i]
                observation_formatted = self.user_formatter.format(content=observation_message["content"])
                observation_encoded = self._encode_template(observation_formatted, tokenizer)
                ls_for_save.append(observation_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["assistant"]:
                assistant_message = messages[i]
                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
                ls_for_save.append(assistant_encoded)
                # An assistant reply closes the current group.
                res_all.append(tuple(ls_for_save))
                ls_for_save = []

        if ls_for_save:
            res_all.append(tuple(ls_for_save))

        return res_all
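
# Illustration of the grouping above (role names are whatever CONVERSATION_ROLE_NAMES
# maps to, not assumed here): a round such as
#   user -> function -> observation -> assistant
# is encoded into a single 4-tuple of token-id lists, and the next user message
# starts a new tuple.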

LLAMA3_TEMPLATE = ConversationTemplate(
    template_name="llama3",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    special_starter=TemplateComponent(type="token", content="bos_token"),
)

LLAMA3_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
    template_name="llama3_for_tool",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    function_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    observation_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    special_starter=TemplateComponent(type="token", content="bos_token"),
)

LLAMA2_TEMPLATE = Llama2ConversationTemplate(
    template_name="llama2",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type="token", content="bos_token"),
            TemplateComponent(type="string", content="[INST] {{content}} [/INST]"),
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type="string", content="{{content}}"),
            TemplateComponent(type="token", content="eos_token"),
        ]
    ),
    system_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<<SYS>>\n{{content}}\n<</SYS>>\n\n")]
    ),
)
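
# With Llama2ConversationTemplate._encode above, the first (user, assistant) pair of
# LLAMA2_TEMPLATE renders roughly as follows (token components shown by name):
#
#   <bos_token>[INST] <<SYS>>
#   {system}
#   <</SYS>>
#
#   {user} [/INST]{assistant}<eos_token>
#
# Later pairs omit the <<SYS>> block, since the system text is only prepended to the
# first user message.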

LLAMA2_TEMPLATE_FOR_TOOL = Llama2ConversationTemplateForTool(
    template_name="llama2_for_tool",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type="token", content="bos_token"),
            TemplateComponent(type="string", content="[INST] {{content}} [/INST]"),
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type="string", content="{{content}}"),
            TemplateComponent(type="token", content="eos_token"),
        ]
    ),
    system_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<<SYS>>\n{{content}}\n<</SYS>>\n\n")]
    ),
)
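
# A minimal usage sketch, not part of the library: the tokenizer checkpoint below is an
# assumption (any Llama-2 tokenizer available locally should do), and _encode returns one
# tuple of token-id lists per encoded round.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    messages = [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
    encoded_rounds = LLAMA2_TEMPLATE._encode(
        tokenizer=tokenizer,
        messages=messages,
        system="You are a helpful assistant.",
    )
    user_ids, assistant_ids = encoded_rounds[0]
    print(len(user_ids), len(assistant_ids))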