diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 2a7d575ce..17b5140e2 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -5,6 +5,7 @@ import json import ctypes import dataclasses +import datetime import random import string @@ -23,6 +24,7 @@ ) import jinja2 +import jinja2.ext as jinja2_ext from jinja2.sandbox import ImmutableSandboxedEnvironment import numpy as np @@ -208,11 +210,20 @@ def __init__( set(stop_token_ids) if stop_token_ids is not None else None ) - self._environment = ImmutableSandboxedEnvironment( - loader=jinja2.BaseLoader(), + # self._environment = ImmutableSandboxedEnvironment( + # loader=jinja2.BaseLoader(), + # trim_blocks=True, + # lstrip_blocks=True, + # ).from_string(self.template) + + environment = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, - ).from_string(self.template) + extensions=[jinja2_ext.loopcontrols], + ) + environment.filters["tojson"] = lambda x, indent=None, separators=None, sort_keys=False: json.dumps(x, indent=indent, separators=separators, sort_keys=sort_keys, ensure_ascii=False) + environment.globals["strftime_now"] = lambda format: datetime.datetime.now().strftime(format) + self._environment = environment.from_string(self.template) def __call__( self,