compose-rkllm_chat/rkllm_server/gradio_server.py

413 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import codecs
import ctypes
import os
import resource
import subprocess
import sys
import threading
import time

import gradio as gr
# System prompt shared by every chat template below.
PROMPT_TEXT_SYSTEM_COMMON = (
    "You are a helpful assistant who answers things succinctly and in steps. "
    "You listen to the user's question and answer accordingly, or point out factual "
    "errors in the user's claims."
)

# Active chat template: Phi-3.5.
# RKLLM.run() sends PROMPT_TEXT_PREFIX + <user prompt> + PROMPT_TEXT_POSTFIX.
PROMPT_TEXT_PREFIX = (
    "<|system|> {content} <|end|>\n<|user|>"
).format(content=PROMPT_TEXT_SYSTEM_COMMON)
PROMPT_TEXT_POSTFIX = "<|end|>\n<|assistant|>"

# To serve a different model, comment out the Phi-3.5 pair above and
# uncomment the matching pair below.
#
# Re: Qwen
# PROMPT_TEXT_PREFIX = (
#     "<|im_start|>system {content} <|im_end|>\n<|im_start|>user"
# ).format(content=PROMPT_TEXT_SYSTEM_COMMON)
# PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant"
#
# Re: TinyLlama
# PROMPT_TEXT_PREFIX = (
#     "[INST] <<SYS>>{content}<</SYS>>\n"
# ).format(content=PROMPT_TEXT_SYSTEM_COMMON)
# PROMPT_TEXT_POSTFIX = " [/INST]\n"
# Host/port for the Gradio server. launch() is called without explicit
# server_name/server_port, so Gradio falls back to these variables:
# listen on every interface, port 8080.
os.environ.update({
    "GRADIO_SERVER_NAME": "0.0.0.0",
    "GRADIO_SERVER_PORT": "8080",
})

# Load the RKLLM runtime. The "./" path is resolved against the current
# working directory, so the server must be started next to the lib/ folder.
rkllm_lib = ctypes.CDLL("./lib/librkllmrt.so")
# ctypes aliases for the opaque runtime handle and the C enums of the RKLLM API.
RKLLM_Handle_t = ctypes.c_void_p  # filled in by rkllm_init, then passed to the other rkllm_* calls
userdata = ctypes.c_void_p(None)  # NULL pointer; NOTE(review): not referenced elsewhere in the visible code


def _attach_constants(ctype, **constants):
    """Store each keyword argument as a class attribute of *ctype* and return *ctype*."""
    for const_name, const_value in constants.items():
        setattr(ctype, const_name, const_value)
    return ctype


# NOTE(review): the three "enum" names below are all the *same* class,
# ctypes.c_int. Every constant is therefore stored on ctypes.c_int itself and
# shares one namespace with the others, and with any other code in the process
# that uses ctypes.c_int.
LLMCallState = _attach_constants(
    ctypes.c_int,
    RKLLM_RUN_NORMAL=0,
    RKLLM_RUN_WAITING=1,
    RKLLM_RUN_FINISH=2,
    RKLLM_RUN_ERROR=3,
    RKLLM_RUN_GET_LAST_HIDDEN_LAYER=4,
)
RKLLMInputMode = _attach_constants(
    ctypes.c_int,
    RKLLM_INPUT_PROMPT=0,
    RKLLM_INPUT_TOKEN=1,
    RKLLM_INPUT_EMBED=2,
    RKLLM_INPUT_MULTIMODAL=3,
)
RKLLMInferMode = _attach_constants(
    ctypes.c_int,
    RKLLM_INFER_GENERATE=0,
    RKLLM_INFER_GET_LAST_HIDDEN_LAYER=1,
)
class RKLLMExtendParam(ctypes.Structure):
    """Extension block embedded at the end of ``RKLLMParam``.

    Only ``base_domain_id`` is assigned in this file. ``reserved`` is an
    opaque 112-byte array (presumably room for future fields; confirm
    against rkllm.h).
    """

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("reserved", ctypes.c_uint8 * 112),
    ]


class RKLLMParam(ctypes.Structure):
    """Model configuration passed by pointer to ``rkllm_init``.

    ctypes lays fields out in declaration order, so this list must mirror the
    C struct field for field. Do not reorder.
    """

    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]


class RKLLMLoraAdapter(ctypes.Structure):
    """LoRA adapter description passed by pointer to ``rkllm_load_lora``."""

    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float),
    ]
class RKLLMEmbedInput(ctypes.Structure):
    """Pre-computed embedding input: a float buffer plus a token count.

    NOTE(review): the buffer presumably holds n_tokens * embedding-size
    floats; confirm against rkllm.h.
    """

    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMTokenInput(ctypes.Structure):
    """Token-id input: an int32 id buffer plus a token count."""

    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t),
    ]


class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal input: prompt text plus an image-embedding float buffer."""

    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
    ]


class RKLLMInputUnion(ctypes.Union):
    """Payload of ``RKLLMInput``. The members overlap in memory.

    ``RKLLMInput.input_mode`` selects which member is meaningful. This file
    only fills ``prompt_input`` (with ``RKLLM_INPUT_PROMPT``).
    """

    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput),
    ]


class RKLLMInput(ctypes.Structure):
    """Input descriptor passed by pointer to ``rkllm_run``."""

    _fields_ = [
        ("input_mode", ctypes.c_int),
        ("input_data", RKLLMInputUnion),
    ]
class RKLLMLoraParam(ctypes.Structure):
    """Names the LoRA adapter to apply for one ``rkllm_run`` call.

    Referenced through ``RKLLMInferParam.lora_params``.
    """

    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p),
    ]


class RKLLMPromptCacheParam(ctypes.Structure):
    """Prompt-cache options referenced through ``RKLLMInferParam.prompt_cache_params``.

    This file never populates it; that pointer is left NULL.
    """

    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p),
    ]
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed by pointer to ``rkllm_run``.

    ``mode`` takes an ``RKLLMInferMode`` value. This file leaves the two
    pointer members NULL (``None``) when LoRA / prompt caching are unused.
    """

    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
    ]
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Hidden-state output reported with the GET_LAST_HIDDEN_LAYER state.

    ``callback_impl`` reads ``embd_size * num_tokens`` floats starting at
    ``hidden_states``.
    """

    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int),
    ]


class RKLLMResult(ctypes.Structure):
    """Result handed (by pointer) to ``callback_impl`` on each callback.

    ``text`` carries raw UTF-8 bytes, which may stop in the middle of a
    multi-byte character.
    """

    _fields_ = [
        ("text", ctypes.c_char_p),
        ("size", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
    ]
# Module-level state shared between callback_impl (invoked by the RKLLM
# runtime) and the Gradio handler that drains the generated text.
global_text = []  # decoded text chunks waiting to be shown in the UI
global_state = -1  # last LLMCallState value seen; -1 means "none yet"
split_byte_data = b""  # undecoded tail of a UTF-8 character split across callbacks


def _append_text_chunk(chunk):
    """Decode *chunk*, prefixed by any buffered bytes, and publish the text.

    Complete characters are appended to ``global_text`` and echoed to stdout.
    A trailing incomplete multi-byte sequence is kept in ``split_byte_data``
    until the next callback supplies the rest. Invalid bytes become U+FFFD
    instead of being buffered, so one bad byte cannot stall all later output.
    """
    global split_byte_data
    if chunk is None:  # a NULL char* from the runtime reads as None in ctypes
        return
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    text = decoder.decode(split_byte_data + chunk, final=False)
    # With final=False the decoder holds back an incomplete trailing sequence.
    split_byte_data = decoder.getstate()[0]
    if text:
        global_text.append(text)
        print(text, end='')


def _save_last_hidden_layer(layer):
    """Write the hidden states held by *layer* to ./last_hidden_layer.bin.

    Progress and error messages are printed and also appended to ``global_text``.
    """
    # If using the GET_LAST_HIDDEN_LAYER function, the callback provides the memory
    # pointer (hidden_states), the number of tokens (num_tokens) and the size of the
    # hidden layer (embd_size). The data must be read during the current callback;
    # otherwise the pointer is released by the next callback.
    if layer.hidden_states and layer.embd_size != 0 and layer.num_tokens != 0:
        data_size = layer.embd_size * layer.num_tokens * ctypes.sizeof(ctypes.c_float)
        print(f"data_size: {data_size}")
        global_text.append(f"data_size: {data_size}\n")
        output_path = os.getcwd() + "/last_hidden_layer.bin"
        with open(output_path, "wb") as out_file:
            # Copy the raw float bytes out now, while the pointer is still valid.
            out_file.write(ctypes.string_at(layer.hidden_states, data_size))
        print(f"Data saved to {output_path} successfully!")
        global_text.append(f"Data saved to {output_path} successfully!")
    else:
        # A NULL pointer or a zero dimension means there is nothing to read.
        print("Invalid hidden layer data.")
        global_text.append("Invalid hidden layer data.")


def callback_impl(result, userdata, state):
    """Handle one callback from the RKLLM runtime.

    Registered with ``rkllm_init`` through the ``callback`` CFUNCTYPE wrapper.

    Args:
        result: ``POINTER(RKLLMResult)``; fields are reached via ``.contents``.
        userdata: opaque pointer passed back by the runtime (unused).
        state: an ``LLMCallState`` value, stored in ``global_state``.
    """
    global global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        # A partial character left at the end of a reply can never complete;
        # drop it so it does not corrupt the next reply.
        split_byte_data = b""
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        split_byte_data = b""
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        # `result` is a pointer; the struct fields live on result.contents.
        _save_last_hidden_layer(result.contents.last_hidden_layer)
        global_state = state
        time.sleep(0.05)  # Delay for 0.05 seconds to wait for the output result
        sys.stdout.flush()
    else:
        # Save the running state, then publish the token text.
        global_state = state
        _append_text_chunk(result.contents.text)
        sys.stdout.flush()
# C-callable function-pointer type matching the callback rkllm_init expects:
# void (*)(RKLLMResult *result, void *userdata, int state).
callback_type = ctypes.CFUNCTYPE(
    None,
    ctypes.POINTER(RKLLMResult),
    ctypes.c_void_p,
    ctypes.c_int,
)
# Keep this wrapper referenced at module level for as long as the runtime may
# call it; if the CFUNCTYPE object is garbage-collected, the C side would call
# freed memory.
callback = callback_type(callback_impl)
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
class RKLLM:
    """Wrapper around one RKLLM runtime instance: init, run and destroy.

    Generated text is not returned by :meth:`run`. The runtime delivers it
    through the module-level ``callback`` registered with ``rkllm_init``.
    """

    # Name under which an optional LoRA adapter is registered with
    # rkllm_load_lora and later selected in run().
    LORA_ADAPTER_NAME = "test"

    def __init__(self, model_path, lora_model_path=None, prompt_cache_path=None):
        """Initialize the model and optionally load a LoRA adapter and a prompt cache.

        Args:
            model_path: path to the converted RKLLM model file.
            lora_model_path: optional LoRA adapter to load and apply in run().
            prompt_cache_path: optional prompt-cache file to load.

        Raises:
            RuntimeError: if rkllm_init, rkllm_load_lora or
                rkllm_load_prompt_cache reports a non-zero status.
        """
        self.handle = RKLLM_Handle_t()
        self._bind_core_functions()
        rkllm_param = self._default_param(model_path)
        ret = self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
        self._check_status(ret, "rkllm_init")

        self.lora_adapter_path = None
        self.lora_adapter_name = None
        self.lora_model_name = None
        if lora_model_path:
            self._load_lora(lora_model_path)

        self.prompt_cache_path = None
        if prompt_cache_path:
            self._load_prompt_cache(prompt_cache_path)

    @staticmethod
    def _check_status(ret, func_name):
        """Raise RuntimeError when an rkllm_* call returns a non-zero status.

        NOTE(review): assumes the RKLLM convention of 0 = success (as in
        Rockchip's C demo); confirm against rkllm.h for the runtime in use.
        """
        if ret != 0:
            raise RuntimeError(f"{func_name} failed with status {ret}")

    @staticmethod
    def _default_param(model_path):
        """Return the RKLLMParam this server initializes the model with."""
        param = RKLLMParam()
        param.model_path = model_path.encode('utf-8')
        param.max_context_len = 4096
        param.max_new_tokens = 256
        param.skip_special_token = True
        param.top_k = 40
        param.top_p = 0.9
        param.temperature = 0.5
        param.repeat_penalty = 1.1
        param.frequency_penalty = 0.0
        param.presence_penalty = 0.0
        param.mirostat = 0
        param.mirostat_tau = 5.0
        param.mirostat_eta = 0.1
        # NOTE(review): the Gradio handler treats the end of the run() thread as
        # the end of generation; confirm rkllm_run still blocks until generation
        # completes when is_async is True.
        param.is_async = True
        param.img_start = b""
        param.img_end = b""
        param.img_content = b""
        param.extend_param.base_domain_id = 0
        return param

    def _bind_core_functions(self):
        """Declare ctypes signatures for rkllm_init, rkllm_run and rkllm_destroy."""
        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int

        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

    def _load_lora(self, lora_model_path):
        """Register the LoRA adapter with the runtime and select it for run()."""
        adapter = RKLLMLoraAdapter()
        ctypes.memset(ctypes.byref(adapter), 0, ctypes.sizeof(RKLLMLoraAdapter))
        adapter.lora_adapter_path = lora_model_path.encode('utf-8')
        adapter.lora_adapter_name = self.LORA_ADAPTER_NAME.encode('utf-8')
        adapter.scale = 1.0

        rkllm_load_lora = rkllm_lib.rkllm_load_lora
        rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)]
        rkllm_load_lora.restype = ctypes.c_int
        self._check_status(rkllm_load_lora(self.handle, ctypes.byref(adapter)), "rkllm_load_lora")

        self.lora_adapter_path = lora_model_path
        self.lora_adapter_name = self.LORA_ADAPTER_NAME
        # run() reads lora_model_name to decide whether to apply an adapter, so it
        # must carry the registered name or the loaded adapter is never used.
        self.lora_model_name = self.LORA_ADAPTER_NAME

    def _load_prompt_cache(self, prompt_cache_path):
        """Load a prompt-cache file into the runtime."""
        rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
        rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
        rkllm_load_prompt_cache.restype = ctypes.c_int
        ret = rkllm_load_prompt_cache(self.handle, prompt_cache_path.encode('utf-8'))
        self._check_status(ret, "rkllm_load_prompt_cache")
        self.prompt_cache_path = prompt_cache_path

    def run(self, prompt):
        """Generate a reply to *prompt*, wrapped in the module's chat template.

        Generated text arrives through ``callback``; it is not returned.

        Returns:
            The status code returned by ``rkllm_run``.
        """
        lora_param = None
        if self.lora_model_name:
            lora_param = RKLLMLoraParam()
            lora_param.lora_adapter_name = self.lora_model_name.encode('utf-8')

        infer_param = RKLLMInferParam()
        ctypes.memset(ctypes.byref(infer_param), 0, ctypes.sizeof(RKLLMInferParam))
        infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        # A struct field needs a real pointer object: ctypes.byref() results are
        # only valid as direct foreign-function arguments.
        infer_param.lora_params = ctypes.pointer(lora_param) if lora_param is not None else None

        rkllm_input = RKLLMInput()
        rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
        full_prompt = PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX
        rkllm_input.input_data.prompt_input = full_prompt.encode('utf-8')
        return self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(infer_param), None)

    def release(self):
        """Destroy the runtime handle created in __init__."""
        self.rkllm_destroy(self.handle)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
args = parser.parse_args()
if not os.path.exists(args.rkllm_model_path):
print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
sys.stdout.flush()
exit()
if not (args.target_platform in ["rk3588", "rk3576"]):
print("Error: Please specify the correct target platform: rk3588/rk3576.")
sys.stdout.flush()
exit()
if args.lora_model_path:
if not os.path.exists(args.lora_model_path):
print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.")
sys.stdout.flush()
exit()
if args.prompt_cache_path:
if not os.path.exists(args.prompt_cache_path):
print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.")
sys.stdout.flush()
exit()
# Fix frequency
# command = "sudo bash fix_freq_{}.sh".format(args.target_platform)
# subprocess.run(command, shell=True)
# Set resource limit
resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))
# Initialize RKLLM model
print("=========init....===========")
sys.stdout.flush()
model_path = args.rkllm_model_path
rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path)
print("RKLLM Model has been initialized successfully")
print("==============================")
sys.stdout.flush()
# Record the user's input prompt
def get_user_input(user_message, history):
history = history + [[user_message, None]]
return "", history
# Retrieve the output from the RKLLM model and print it in a streaming manner
def get_RKLLM_output(history):
# Link global variables to retrieve the output information from the callback function
global global_text, global_state
global_text = []
global_state = -1
# Create a thread for model inference
model_thread = threading.Thread(target=rkllm_model.run, args=(history[-1][0],))
model_thread.start()
# history[-1][1] represents the current dialogue
history[-1][1] = ""
# Wait for the model to finish running and periodically check the inference thread of the model
model_thread_finished = False
while not model_thread_finished:
while len(global_text) > 0:
history[-1][1] += global_text.pop(0)
time.sleep(0.005)
# Gradio automatically pushes the result returned by the yield statement when calling the then method
yield history
model_thread.join(timeout=0.005)
model_thread_finished = not model_thread.is_alive()
# Create a Gradio interface
with gr.Blocks(title="Chat with RKLLM") as chatRKLLM:
gr.Markdown("<div align='center'><font size='70'> Chat with RKLLM </font></div>")
gr.Markdown("### Enter your question in the inputTextBox and press the Enter key to chat with the RKLLM model.")
# Create a Chatbot component to display conversation history
rkllmServer = gr.Chatbot(height=600)
# Create a Textbox component for user message input
msg = gr.Textbox(placeholder="Please input your question here...", label="inputTextBox")
# Create a Button component to clear the chat history.
clear = gr.Button("Clear")
# Submit the user's input message to the get_user_input function and immediately update the chat history.
# Then call the get_RKLLM_output function to further update the chat history.
# The queue=False parameter ensures that these updates are not queued, but executed immediately.
msg.submit(get_user_input, [msg, rkllmServer], [msg, rkllmServer], queue=False).then(get_RKLLM_output, rkllmServer, rkllmServer)
# When the clear button is clicked, perform a no-operation (lambda: None) and immediately clear the chat history.
clear.click(lambda: None, None, rkllmServer, queue=False)
# Enable the event queue system.
chatRKLLM.queue()
# Start the Gradio application.
chatRKLLM.launch()
print("====================")
print("RKLLM model inference completed, releasing RKLLM model resources...")
rkllm_model.release()
print("====================")