import ctypes import sys import os import subprocess import resource import threading import time import gradio as gr import argparse PROMPT_TEXT_SYSTEM_COMMON = ( "You are a helpful assistant who answers things succintly and in steps. " "You listen to the user's question and answer accordingly, or point out factual " "errors in the user's claims." ) # Re: phi3.5 PROMPT_TEXT_PREFIX = ( "<|system|> {content} <|end|>\n<|user|>" ).format(content = PROMPT_TEXT_SYSTEM_COMMON) PROMPT_TEXT_POSTFIX = "<|end|>\n<|assistant|>" # Re: Qwen # PROMPT_TEXT_PREFIX = ( # "<|im_start|>system {content} <|im_end|>\n<|im_start|>user" # ).format(content = PROMPT_TEXT_SYSTEM_COMMON) # PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant" # Re: TinyLlama # PROMPT_TEXT_PREFIX = ( # "[INST] <>{content}<>\n" # ).format(content = PROMPT_TEXT_SYSTEM_COMMON) # PROMPT_TEXT_POSTFIX = " [/INST]\n" # Set environment variables os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0" os.environ["GRADIO_SERVER_PORT"] = "8080" # Set the dynamic library path rkllm_lib = ctypes.CDLL('./lib/librkllmrt.so') # Define the structures from the library RKLLM_Handle_t = ctypes.c_void_p userdata = ctypes.c_void_p(None) LLMCallState = ctypes.c_int LLMCallState.RKLLM_RUN_NORMAL = 0 LLMCallState.RKLLM_RUN_WAITING = 1 LLMCallState.RKLLM_RUN_FINISH = 2 LLMCallState.RKLLM_RUN_ERROR = 3 LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4 RKLLMInputMode = ctypes.c_int RKLLMInputMode.RKLLM_INPUT_PROMPT = 0 RKLLMInputMode.RKLLM_INPUT_TOKEN = 1 RKLLMInputMode.RKLLM_INPUT_EMBED = 2 RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3 RKLLMInferMode = ctypes.c_int RKLLMInferMode.RKLLM_INFER_GENERATE = 0 RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1 class RKLLMExtendParam(ctypes.Structure): _fields_ = [ ("base_domain_id", ctypes.c_int32), ("reserved", ctypes.c_uint8 * 112) ] class RKLLMParam(ctypes.Structure): _fields_ = [ ("model_path", ctypes.c_char_p), ("max_context_len", ctypes.c_int32), ("max_new_tokens", ctypes.c_int32), ("top_k", ctypes.c_int32), ("top_p", ctypes.c_float), ("temperature", ctypes.c_float), ("repeat_penalty", ctypes.c_float), ("frequency_penalty", ctypes.c_float), ("presence_penalty", ctypes.c_float), ("mirostat", ctypes.c_int32), ("mirostat_tau", ctypes.c_float), ("mirostat_eta", ctypes.c_float), ("skip_special_token", ctypes.c_bool), ("is_async", ctypes.c_bool), ("img_start", ctypes.c_char_p), ("img_end", ctypes.c_char_p), ("img_content", ctypes.c_char_p), ("extend_param", RKLLMExtendParam), ] class RKLLMLoraAdapter(ctypes.Structure): _fields_ = [ ("lora_adapter_path", ctypes.c_char_p), ("lora_adapter_name", ctypes.c_char_p), ("scale", ctypes.c_float) ] class RKLLMEmbedInput(ctypes.Structure): _fields_ = [ ("embed", ctypes.POINTER(ctypes.c_float)), ("n_tokens", ctypes.c_size_t) ] class RKLLMTokenInput(ctypes.Structure): _fields_ = [ ("input_ids", ctypes.POINTER(ctypes.c_int32)), ("n_tokens", ctypes.c_size_t) ] class RKLLMMultiModelInput(ctypes.Structure): _fields_ = [ ("prompt", ctypes.c_char_p), ("image_embed", ctypes.POINTER(ctypes.c_float)), ("n_image_tokens", ctypes.c_size_t) ] class RKLLMInputUnion(ctypes.Union): _fields_ = [ ("prompt_input", ctypes.c_char_p), ("embed_input", RKLLMEmbedInput), ("token_input", RKLLMTokenInput), ("multimodal_input", RKLLMMultiModelInput) ] class RKLLMInput(ctypes.Structure): _fields_ = [ ("input_mode", ctypes.c_int), ("input_data", RKLLMInputUnion) ] class RKLLMLoraParam(ctypes.Structure): _fields_ = [ ("lora_adapter_name", ctypes.c_char_p) ] class RKLLMPromptCacheParam(ctypes.Structure): _fields_ = [ ("save_prompt_cache", ctypes.c_int), ("prompt_cache_path", ctypes.c_char_p) ] class RKLLMInferParam(ctypes.Structure): _fields_ = [ ("mode", RKLLMInferMode), ("lora_params", ctypes.POINTER(RKLLMLoraParam)), ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)) ] class RKLLMResultLastHiddenLayer(ctypes.Structure): _fields_ = [ ("hidden_states", ctypes.POINTER(ctypes.c_float)), ("embd_size", ctypes.c_int), ("num_tokens", ctypes.c_int) ] class RKLLMResult(ctypes.Structure): _fields_ = [ ("text", ctypes.c_char_p), ("size", ctypes.c_int), ("last_hidden_layer", RKLLMResultLastHiddenLayer) ] # Define global variables to store the callback function output for displaying in the Gradio interface global_text = [] global_state = -1 split_byte_data = bytes(b"") # Used to store the segmented byte data # Define the callback function def callback_impl(result, userdata, state): global global_text, global_state, split_byte_data if state == LLMCallState.RKLLM_RUN_FINISH: global_state = state print("\n") sys.stdout.flush() elif state == LLMCallState.RKLLM_RUN_ERROR: global_state = state print("run error") sys.stdout.flush() elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER: ''' If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size. With these three parameters, you can retrieve the data from last_hidden_layer. Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback. ''' if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0: data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float) print(f"data_size: {data_size}") global_text.append(f"data_size: {data_size}\n") output_path = os.getcwd() + "/last_hidden_layer.bin" with open(output_path, "wb") as outFile: data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float)) float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float)) float_array = float_array_type.from_address(ctypes.addressof(data.contents)) outFile.write(bytearray(float_array)) print(f"Data saved to {output_path} successfully!") global_text.append(f"Data saved to {output_path} successfully!") else: print("Invalid hidden layer data.") global_text.append("Invalid hidden layer data.") global_state = state time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result sys.stdout.flush() else: # Save the output token text and the RKLLM running state global_state = state # Monitor if the current byte data is complete; if incomplete, record it for later parsing try: global_text.append((split_byte_data + result.contents.text).decode('utf-8')) print((split_byte_data + result.contents.text).decode('utf-8'), end='') split_byte_data = bytes(b"") except: split_byte_data += result.contents.text sys.stdout.flush() # Connect the callback function between the Python side and the C++ side callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int) callback = callback_type(callback_impl) # Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library class RKLLM(object): def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None): rkllm_param = RKLLMParam() rkllm_param.model_path = bytes(model_path, 'utf-8') rkllm_param.max_context_len = 4096 rkllm_param.max_new_tokens = 256 rkllm_param.skip_special_token = True rkllm_param.top_k = 40 rkllm_param.top_p = 0.9 rkllm_param.temperature = 0.5 rkllm_param.repeat_penalty = 1.1 rkllm_param.frequency_penalty = 0.0 rkllm_param.presence_penalty = 0.0 rkllm_param.mirostat = 0 rkllm_param.mirostat_tau = 5.0 rkllm_param.mirostat_eta = 0.1 rkllm_param.is_async = True rkllm_param.img_start = "".encode('utf-8') rkllm_param.img_end = "".encode('utf-8') rkllm_param.img_content = "".encode('utf-8') rkllm_param.extend_param.base_domain_id = 0 self.handle = RKLLM_Handle_t() self.rkllm_init = rkllm_lib.rkllm_init self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type] self.rkllm_init.restype = ctypes.c_int self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback) self.rkllm_run = rkllm_lib.rkllm_run self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p] self.rkllm_run.restype = ctypes.c_int self.rkllm_destroy = rkllm_lib.rkllm_destroy self.rkllm_destroy.argtypes = [RKLLM_Handle_t] self.rkllm_destroy.restype = ctypes.c_int self.lora_adapter_path = None self.lora_model_name = None if lora_model_path: self.lora_adapter_path = lora_model_path self.lora_adapter_name = "test" lora_adapter = RKLLMLoraAdapter() ctypes.memset(ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter)) lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8')) lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8')) lora_adapter.scale = 1.0 rkllm_load_lora = rkllm_lib.rkllm_load_lora rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)] rkllm_load_lora.restype = ctypes.c_int rkllm_load_lora(self.handle, ctypes.byref(lora_adapter)) self.prompt_cache_path = None if prompt_cache_path: self.prompt_cache_path = prompt_cache_path rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p] rkllm_load_prompt_cache.restype = ctypes.c_int rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8'))) def run(self, prompt): rkllm_lora_params = None if self.lora_model_name: rkllm_lora_params = RKLLMLoraParam() rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8')) rkllm_infer_params = RKLLMInferParam() ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam)) rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE rkllm_infer_params.lora_params = ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None rkllm_input = RKLLMInput() rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8')) self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None) return def release(self): self.rkllm_destroy(self.handle) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;') parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;') parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;') parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;') args = parser.parse_args() if not os.path.exists(args.rkllm_model_path): print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.") sys.stdout.flush() exit() if not (args.target_platform in ["rk3588", "rk3576"]): print("Error: Please specify the correct target platform: rk3588/rk3576.") sys.stdout.flush() exit() if args.lora_model_path: if not os.path.exists(args.lora_model_path): print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.") sys.stdout.flush() exit() if args.prompt_cache_path: if not os.path.exists(args.prompt_cache_path): print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.") sys.stdout.flush() exit() # Fix frequency # command = "sudo bash fix_freq_{}.sh".format(args.target_platform) # subprocess.run(command, shell=True) # Set resource limit resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400)) # Initialize RKLLM model print("=========init....===========") sys.stdout.flush() model_path = args.rkllm_model_path rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path) print("RKLLM Model has been initialized successfully!") print("==============================") sys.stdout.flush() # Record the user's input prompt def get_user_input(user_message, history): history = history + [[user_message, None]] return "", history # Retrieve the output from the RKLLM model and print it in a streaming manner def get_RKLLM_output(history): # Link global variables to retrieve the output information from the callback function global global_text, global_state global_text = [] global_state = -1 # Create a thread for model inference model_thread = threading.Thread(target=rkllm_model.run, args=(history[-1][0],)) model_thread.start() # history[-1][1] represents the current dialogue history[-1][1] = "" # Wait for the model to finish running and periodically check the inference thread of the model model_thread_finished = False while not model_thread_finished: while len(global_text) > 0: history[-1][1] += global_text.pop(0) time.sleep(0.005) # Gradio automatically pushes the result returned by the yield statement when calling the then method yield history model_thread.join(timeout=0.005) model_thread_finished = not model_thread.is_alive() # Create a Gradio interface with gr.Blocks(title="Chat with RKLLM") as chatRKLLM: gr.Markdown("
Chat with RKLLM
") gr.Markdown("### Enter your question in the inputTextBox and press the Enter key to chat with the RKLLM model.") # Create a Chatbot component to display conversation history rkllmServer = gr.Chatbot(height=600) # Create a Textbox component for user message input msg = gr.Textbox(placeholder="Please input your question here...", label="inputTextBox") # Create a Button component to clear the chat history. clear = gr.Button("Clear") # Submit the user's input message to the get_user_input function and immediately update the chat history. # Then call the get_RKLLM_output function to further update the chat history. # The queue=False parameter ensures that these updates are not queued, but executed immediately. msg.submit(get_user_input, [msg, rkllmServer], [msg, rkllmServer], queue=False).then(get_RKLLM_output, rkllmServer, rkllmServer) # When the clear button is clicked, perform a no-operation (lambda: None) and immediately clear the chat history. clear.click(lambda: None, None, rkllmServer, queue=False) # Enable the event queue system. chatRKLLM.queue() # Start the Gradio application. chatRKLLM.launch() print("====================") print("RKLLM model inference completed, releasing RKLLM model resources...") rkllm_model.release() print("====================")