From c84d5ff99999640e9da3754d0f31ca805c6c2545 Mon Sep 17 00:00:00 2001
From: Peter-Jan Karens
Date: Sun, 24 Dec 2023 00:45:40 +0100
Subject: [PATCH 1/2] wip: chatml function calling

---
 llama_cpp/llama_chat_format.py | 53 ++++++++++++++++++----------------
 llama_cpp/llama_types.py       |  1 +
 2 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 037f96a2dd..79d829f728 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -5,6 +5,7 @@
 import ctypes
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
+import uuid
 
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
@@ -704,7 +705,7 @@ def format_openchat(
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
 
-@register_chat_completion_handler("functionary")
+@register_chat_completion_handler("chatml-function-calling")
 def functionary_chat_handler(
     llama: llama.Llama,
     messages: List[llama_types.ChatCompletionRequestMessage],
@@ -871,40 +872,40 @@ def prepare_messages_for_inference(
 
         def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
             if msg["role"] == "system":
-                return f"system:\n{msg['content']}\n"
+                return f"system\n{msg['content']}<|im_end|>\n"
 
-            elif msg["role"] == "function" and "name" in msg:
-                return f"function name={msg['name']}:\n{msg['content']}\n"
-            elif msg["role"] == "function" and "function_call" in msg:
-                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
+            # elif msg["role"] == "function" and "name" in msg:
+            #     return f"function name={msg['name']}:\n{msg['content']}<|im_end|>\n"
+            # elif msg["role"] == "function" and "function_call" in msg:
+            #     return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}<|im_end|>\n"
             elif msg["role"] == "tool":
                 if msg["content"] is not None:
-                    return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
+                    return f"function id={msg['tool_call_id']} name=functions.{msg['name']}\n{msg['content']}<|im_end|>\n"
                 else:
-                    return f"function name={msg['tool_call_id']}\n"
+                    return f"function id={msg['tool_call_id']} name=functions.{msg['name']}<|im_end|>\n"
             elif msg["role"] == "user":
                 if msg["content"] is None:
-                    return "user:\n\n"
+                    return "user\n<|im_end|>\n"
                 else:
-                    return f"user:\n{msg['content']}\n"
+                    return f"user\n{msg['content']}<|im_end|>\n"
            elif msg["role"] == "assistant":
-                if msg["content"] is not None and "function_call" in msg:
-                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
-                elif "function_call" in msg:
-                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
-                elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
+                # if msg["content"] is not None and "function_call" in msg:
+                #     return f"assistant\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}<|im_end|>\n"
+                # elif "function_call" in msg:
+                #     return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}<|im_end|>\n"
+                if "tool_calls" in msg and len(msg["tool_calls"]) > 0:
                     for tool_call in msg[
                         "tool_calls"
                     ]:  # NOTE: probably doesn't work with the functionary model
-                        return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}\n"
+                        return f"assistant id={tool_call['id']} to=functions.{tool_call['function']['name']}\n{tool_call['function']['arguments']}<|im_end|>\n"
                 elif msg["content"] is None:
                     return "assistant"
                 else:
-                    return f"assistant:\n{msg['content']}\n"
             else:
                 raise ValueError(f"Unsupported role: {msg['role']}")
 
-        return "".join([message_to_str(msg) for msg in all_messages])
+        return "".join(["<|im_start|>" + message_to_str(msg) for msg in all_messages])
 
     if tools is not None:
         functions = [tool["function"] for tool in tools if tool["type"] == "function"]
@@ -918,14 +919,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
 
     if function_call is None and (functions is None or len(functions) == 0):
         completion_or_completion_chunks = llama.create_completion(
-            prompt=prompt + ":\n",
+            prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
             stream=stream,
-            stop=["user:", "</s>"],
+            stop=["user:", "</s>", "<|im_end|>"],
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -944,20 +945,22 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         isinstance(function_call, str) and function_call == "auto"
     ):
         stop = "\n"
+        call_id = str(uuid.uuid4()).split("-")[0]
+        prompt = prompt + " id=call_" + call_id + " name=functions."
         completion: llama_types.Completion = llama.create_completion(
             prompt=prompt, stop=stop, stream=False
         )  # type: ignore
         completion_text = completion["choices"][0]["text"]
         # strip " to=functions." and ending ":"
-        function_call = completion_text.split(".")[-1][:-1]
+        function_call = completion_text  # .split(".")[-1][:-1]
         new_prompt = prompt + completion_text + stop
     elif isinstance(function_call, str) and function_call != "none":
-        new_prompt = prompt + f":\n"
+        new_prompt = prompt
     elif isinstance(function_call, dict):
-        new_prompt = prompt + f" to=functions.{function_call['name']}:\n"
+        new_prompt = prompt + f" to=functions.{function_call['name']}\n"
         function_call = function_call["name"]
     else:
-        new_prompt = prompt + f":\n"
+        new_prompt = prompt
 
     function_body = None
     for function in functions or []:
@@ -995,7 +998,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
 
     completion: llama_types.Completion = llama.create_completion(
         prompt=new_prompt,
-        stop=["user:", "</s>"],
+        stop=["user:", "</s>", "<|im_end|>", "\n\n\n\n"],
         stream=False,
         grammar=grammar,
         max_tokens=max_tokens,
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 5b51e98ce4..0ebe91e215 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -219,6 +219,7 @@ class ChatCompletionRequestToolMessage(TypedDict):
     role: Literal["tool"]
     content: Optional[str]
     tool_call_id: str
+    name: str
 
 
 class ChatCompletionRequestFunctionMessage(TypedDict):

From d9eb95987bf87774941f9a5870b03e3ec7e8e952 Mon Sep 17 00:00:00 2001
From: Peter-Jan Karens
Date: Sun, 24 Dec 2023 01:05:35 +0100
Subject: [PATCH 2/2] chore: small fix to prevent query params in function name

---
 llama_cpp/llama_chat_format.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 79d829f728..6deb372b99 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -948,7 +948,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
         call_id = str(uuid.uuid4()).split("-")[0]
         prompt = prompt + " id=call_" + call_id + " name=functions."
         completion: llama_types.Completion = llama.create_completion(
-            prompt=prompt, stop=stop, stream=False
+            prompt=prompt, stop=[stop, ' '], stream=False
         )  # type: ignore
         completion_text = completion["choices"][0]["text"]
         # strip " to=functions." and ending ":"
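
Below is a minimal usage sketch of the renamed "chatml-function-calling" handler, assuming both commits above are applied. It is not part of the patches: the model path, tool schema, and messages are illustrative placeholders, not taken from this PR.

# Usage sketch: exercise the "chatml-function-calling" handler via the
# high-level API. Assumes a local GGUF model; the path is a placeholder.
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", chat_format="chatml-function-calling")

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Amsterdam?"},
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    # "auto" takes the patched path: the handler appends
    # " id=call_<uuid> name=functions." and lets the model complete the name.
    tool_choice="auto",
)
print(response["choices"][0]["message"])

With tool_choice="auto", prepare_messages_for_inference renders prior turns in ChatML form, so a completed tool call would look roughly like the following (the call id is illustrative; it is the first segment of a uuid4):

<|im_start|>user
What is the weather in Amsterdam?<|im_end|>
<|im_start|>assistant id=call_1a2b3c4d to=functions.get_weather
{"city": "Amsterdam"}<|im_end|>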