Source code for lmflow.utils.conversation_template.llama
#!/usr/bin/env python
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
from collections.abc import Sequence
from typing import Optional
from transformers import PreTrainedTokenizer
from lmflow.utils.constants import CONVERSATION_ROLE_NAMES
from .base import ConversationTemplate, ConversationTemplateForTool, StringFormatter, TemplateComponent

logger = logging.getLogger(__name__)

class Llama2ConversationTemplate(ConversationTemplate):
    def _encode(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **kwargs,
    ) -> Sequence[tuple[list[int], list[int]]]:
        if tools:
            logger.warning(
                "Formatted tools are not supported in Llama2, thus tools will be ignored. "
                "If this is intended, please include tools in the system message manually."
            )

        res_all = []

        # In the Llama2 chat format the system prompt lives inside the first [INST]
        # block, so it is rendered to plain text here and prepended to the first
        # user message below.
        system_formatted = self.system_formatter.format(content=system) if system else []
        system_formatted_text = "".join(
            [component.content for component in system_formatted if component.type == "string"]
        )  # HACK

        # Messages are assumed to alternate strictly as (user, assistant) pairs.
        for i in range(0, len(messages), 2):
            user_message = messages[i]
            assistant_message = messages[i + 1]

            user_content = system_formatted_text + user_message["content"] if i == 0 else user_message["content"]
            user_formatted = self.user_formatter.format(content=user_content)
            assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])

            user_encoded = self._encode_template(user_formatted, tokenizer)
            assistant_encoded = self._encode_template(assistant_formatted, tokenizer)

            res_all.append((user_encoded, assistant_encoded))

        return res_all
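
# Usage sketch (illustrative, not part of the library): encoding one
# user/assistant exchange with LLAMA2_TEMPLATE (defined later in this module).
# The tokenizer checkpoint and the direct call to the internal `_encode` method
# are assumptions made for demonstration only; in practice the template is
# normally driven by LMFlow's dataset tokenization pipeline.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical choice
#     pairs = LLAMA2_TEMPLATE._encode(
#         tokenizer,
#         messages=[
#             {"role": "user", "content": "What is 2 + 2?"},
#             {"role": "assistant", "content": "4."},
#         ],
#         system="You are a concise assistant.",
#     )
#     # pairs holds one (user_token_ids, assistant_token_ids) tuple; the system
#     # prompt is folded into the first user segment, which decodes roughly to
#     # "<s>[INST] <<SYS>>\nYou are a concise assistant.\n<</SYS>>\n\nWhat is 2 + 2? [/INST]".
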
class Llama2ConversationTemplateForTool(Llama2ConversationTemplate):
    def _encode(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        **kwargs,
    ) -> Sequence[tuple[list[int], list[int]]]:
        if tools:
            # Llama2 has no dedicated tool field; append the formatted tool
            # description to the system message (guarding against a missing
            # system prompt).
            system = (system or "") + tools

        res_all = []

        system_formatted = self.system_formatter.format(content=system) if system else []
        system_formatted_text = "".join(
            [component.content for component in system_formatted if component.type == "string"]
        )  # HACK

        # Collect encoded segments until an assistant message closes the round,
        # then store the whole round as one tuple in res_all.
        ls_for_save = []
        for i in range(len(messages)):
            if messages[i]["role"] == CONVERSATION_ROLE_NAMES["user"]:
                user_message = messages[i]
                if i == 0:
                    user_content = system_formatted_text + user_message["content"]
                else:
                    user_content = user_message["content"]
                user_formatted = self.user_formatter.format(content=user_content)
                user_encoded = self._encode_template(user_formatted, tokenizer)
                ls_for_save.append(user_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["function"]:
                # Function calls emitted by the model are formatted as assistant output.
                function_message = messages[i]
                function_formatted = self.assistant_formatter.format(content=function_message["content"])
                function_encoded = self._encode_template(function_formatted, tokenizer)
                ls_for_save.append(function_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["observation"]:
                # Tool results returned to the model are formatted as user input.
                observation_message = messages[i]
                observation_formatted = self.user_formatter.format(content=observation_message["content"])
                observation_encoded = self._encode_template(observation_formatted, tokenizer)
                ls_for_save.append(observation_encoded)
            elif messages[i]["role"] == CONVERSATION_ROLE_NAMES["assistant"]:
                assistant_message = messages[i]
                assistant_formatted = self.assistant_formatter.format(content=assistant_message["content"])
                assistant_encoded = self._encode_template(assistant_formatted, tokenizer)
                ls_for_save.append(assistant_encoded)
                res_all.append(tuple(ls_for_save))
                ls_for_save = []

        if ls_for_save:
            res_all.append(tuple(ls_for_save))

        return res_all
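
# Illustrative sketch of the message layout this method expects for one
# tool-use round. The role strings come from CONVERSATION_ROLE_NAMES; the
# literal values shown here ("user", "function", "observation", "assistant")
# are assumptions for illustration only.
#
#     messages = [
#         {"role": "user", "content": "What's the weather in Paris?"},
#         {"role": "function", "content": '{"name": "get_weather", "arguments": {"city": "Paris"}}'},
#         {"role": "observation", "content": '{"temperature": "18C"}'},
#         {"role": "assistant", "content": "It is 18C in Paris."},
#     ]
#
# Function calls are encoded with the assistant formatter, observations with
# the user formatter, and every segment up to and including the assistant
# reply is grouped into a single tuple of token-id lists in the result.
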
LLAMA3_TEMPLATE = ConversationTemplate(
    template_name="llama3",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    special_starter=TemplateComponent(type="token", content="bos_token"),
)
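
# For reference, one system/user/assistant round encoded with LLAMA3_TEMPLATE
# decodes roughly to the text below (segments shown on separate lines for
# readability; no extra newlines are inserted between them). The leading token
# comes from the tokenizer's bos_token, e.g. <|begin_of_text|> for Llama-3
# tokenizers.
#
#     <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>
#     <|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>
#     <|start_header_id|>assistant<|end_header_id|>\n\n{assistant}<|eot_id|>
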
LLAMA3_TEMPLATE_FOR_TOOL = ConversationTemplateForTool(
    template_name="llama3_for_tool",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    function_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    observation_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>assistant<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            TemplateComponent(
                type="string", content="<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"
            )
        ]
    ),
    special_starter=TemplateComponent(type="token", content="bos_token"),
)
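
# Compared to LLAMA3_TEMPLATE, this variant adds formatters for tool use:
# model-issued function calls reuse the assistant header, while tool results
# (observations) are wrapped in a "tool" header, e.g. (sketch):
#
#     <|start_header_id|>tool<|end_header_id|>\n\n{"temperature": "18C"}<|eot_id|>
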
LLAMA2_TEMPLATE = Llama2ConversationTemplate(
    template_name="llama2",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type="token", content="bos_token"),
            TemplateComponent(type="string", content="[INST] {{content}} [/INST]"),
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type="string", content="{{content}}"),
            TemplateComponent(type="token", content="eos_token"),
        ]
    ),
    system_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<<SYS>>\n{{content}}\n<</SYS>>\n\n")]
    ),
)
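
# For reference, one round encoded with LLAMA2_TEMPLATE decodes roughly to:
#
#     <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]{assistant}</s>
#
# where <s> and </s> come from the bos_token/eos_token components, and the
# <<SYS>> block appears only in the first user turn because
# Llama2ConversationTemplate._encode prepends the formatted system message there.
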
LLAMA2_TEMPLATE_FOR_TOOL = Llama2ConversationTemplateForTool(
    template_name="llama2_for_tool",
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type="token", content="bos_token"),
            TemplateComponent(type="string", content="[INST] {{content}} [/INST]"),
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type="string", content="{{content}}"),
            TemplateComponent(type="token", content="eos_token"),
        ]
    ),
    system_formatter=StringFormatter(
        template=[TemplateComponent(type="string", content="<<SYS>>\n{{content}}\n<</SYS>>\n\n")]
    ),
)
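
# Because Llama2ConversationTemplateForTool appends the formatted tool
# description to the system message, the first user turn of a tool-enabled
# conversation renders roughly as (sketch):
#
#     <s>[INST] <<SYS>>\n{system}{tools}\n<</SYS>>\n\n{user} [/INST]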