# Source code for lmflow.utils.conversation_template.gemma
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
import logging
from dataclasses import dataclass
from .base import StringFormatter, TemplateComponent, ConversationTemplate
logger = logging.getLogger(__name__)
@dataclass
class GemmaConversationTemplate(ConversationTemplate):
    """Conversation template for Gemma chat models.

    Gemma's official chat template defines no system role.  This subclass
    only adds a warning when a system message is supplied; the base
    :class:`ConversationTemplate` then emits the system content verbatim
    (see the plain ``{{content}}`` system formatter of ``GEMMA_TEMPLATE``).
    """

    def encode_conversation(self, *args, **kwargs):
        """Encode a conversation, warning if a system message is present.

        Since Gemma has no official system-message support, the system
        text is placed right after the bos token, before the first user
        turn, without any special formatting.  All other behavior is
        delegated to the base class.
        """
        if kwargs.get('system'):
            # Warn on every call that passes a non-empty system message,
            # so users know the text is inserted without Gemma-specific markup.
            logger.warning(
                'As of now, Gemma does not support system messages officially. '
                'ConversationTemplate will add your system messages right after '
                'the bos token and before the user message without any special formatting. '
                'For more details, please refer to the [official template]'
                '(https://huggingface.co/google/gemma-1.1-2b-it/blob/bf4924f313df5166dee1467161e886e55f2eb4d4/tokenizer_config.json#L1507).'
            )
        return super().encode_conversation(*args, **kwargs)
# Ready-to-use template instance for Gemma instruction-tuned models.
# Turn markup follows the official Gemma chat template:
#   <start_of_turn>user\n...<end_of_turn>\n  /  <start_of_turn>model\n...
# The conversation is prefixed with the tokenizer's bos token; system
# messages (unofficial for Gemma) are inserted verbatim after it.
GEMMA_TEMPLATE = GemmaConversationTemplate(
    template_name='gemma',
    user_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<start_of_turn>user\n{{content}}<end_of_turn>\n')
        ]
    ),
    assistant_formatter=StringFormatter(
        template=[
            TemplateComponent(type='string', content='<start_of_turn>model\n{{content}}<end_of_turn>\n')
        ]
    ),
    system_formatter=StringFormatter(
        template=[
            # No Gemma-specific markup: content passes through unchanged.
            TemplateComponent(type='string', content='{{content}}')
        ]
    ),
    # Prepend the model's bos token once at the start of the conversation.
    special_starter=TemplateComponent(type='token', content='bos_token')
)