diff --git a/docker-compose.yml b/docker-compose.yml index efff00d..49af199 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,8 +7,9 @@ services: devices: - /dev:/dev volumes: - - ./model:/rkllm_server/model:ro + - ./rkllm_server/:/rkllm_server/ # bring-your-own server + - ./model/:/rkllm_server/model/:ro ports: - "8080:8080" command: > - sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Qwen2.5-3B.rkllm" + sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Phi-3.5-mini-instruct-w8a8.rkllm" diff --git a/rkllm_server/flask_server.py b/rkllm_server/flask_server.py new file mode 100644 index 0000000..133e24a --- /dev/null +++ b/rkllm_server/flask_server.py @@ -0,0 +1,454 @@ +import ctypes +import sys +import os +import subprocess +import resource +import threading +import time +import argparse +import json +from flask import Flask, request, jsonify, Response + +app = Flask(__name__) + +PROMPT_TEXT_PREFIX = "<|im_start|>system You are a helpful assistant. <|im_end|> <|im_start|>user" +PROMPT_TEXT_POSTFIX = "<|im_end|><|im_start|>assistant" + +# Set the dynamic library path +rkllm_lib = ctypes.CDLL('/lib/librkllmrt.so') + +# Define the structures from the library +RKLLM_Handle_t = ctypes.c_void_p +userdata = ctypes.c_void_p(None) + +LLMCallState = ctypes.c_int +LLMCallState.RKLLM_RUN_NORMAL = 0 +LLMCallState.RKLLM_RUN_WAITING = 1 +LLMCallState.RKLLM_RUN_FINISH = 2 +LLMCallState.RKLLM_RUN_ERROR = 3 +LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4 + +RKLLMInputMode = ctypes.c_int +RKLLMInputMode.RKLLM_INPUT_PROMPT = 0 +RKLLMInputMode.RKLLM_INPUT_TOKEN = 1 +RKLLMInputMode.RKLLM_INPUT_EMBED = 2 +RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3 + +RKLLMInferMode = ctypes.c_int +RKLLMInferMode.RKLLM_INFER_GENERATE = 0 +RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1 + +class RKLLMExtendParam(ctypes.Structure): + _fields_ = [ + ("base_domain_id", ctypes.c_int32), + ("reserved", ctypes.c_uint8 * 112) + ] + +class RKLLMParam(ctypes.Structure): + _fields_ = [ + ("model_path", ctypes.c_char_p), + ("max_context_len", ctypes.c_int32), + ("max_new_tokens", ctypes.c_int32), + ("top_k", ctypes.c_int32), + ("top_p", ctypes.c_float), + ("temperature", ctypes.c_float), + ("repeat_penalty", ctypes.c_float), + ("frequency_penalty", ctypes.c_float), + ("presence_penalty", ctypes.c_float), + ("mirostat", ctypes.c_int32), + ("mirostat_tau", ctypes.c_float), + ("mirostat_eta", ctypes.c_float), + ("skip_special_token", ctypes.c_bool), + ("is_async", ctypes.c_bool), + ("img_start", ctypes.c_char_p), + ("img_end", ctypes.c_char_p), + ("img_content", ctypes.c_char_p), + ("extend_param", RKLLMExtendParam), + ] + +class RKLLMLoraAdapter(ctypes.Structure): + _fields_ = [ + ("lora_adapter_path", ctypes.c_char_p), + ("lora_adapter_name", ctypes.c_char_p), + ("scale", ctypes.c_float) + ] + +class RKLLMEmbedInput(ctypes.Structure): + _fields_ = [ + ("embed", ctypes.POINTER(ctypes.c_float)), + ("n_tokens", ctypes.c_size_t) + ] + +class RKLLMTokenInput(ctypes.Structure): + _fields_ = [ + ("input_ids", ctypes.POINTER(ctypes.c_int32)), + ("n_tokens", ctypes.c_size_t) + ] + +class RKLLMMultiModelInput(ctypes.Structure): + _fields_ = [ + ("prompt", ctypes.c_char_p), + ("image_embed", ctypes.POINTER(ctypes.c_float)), + ("n_image_tokens", ctypes.c_size_t) + ] + +class RKLLMInputUnion(ctypes.Union): + _fields_ = [ + ("prompt_input", ctypes.c_char_p), + ("embed_input", RKLLMEmbedInput), + ("token_input", RKLLMTokenInput), + 
("multimodal_input", RKLLMMultiModelInput) + ] + +class RKLLMInput(ctypes.Structure): + _fields_ = [ + ("input_mode", ctypes.c_int), + ("input_data", RKLLMInputUnion) + ] + +class RKLLMLoraParam(ctypes.Structure): + _fields_ = [ + ("lora_adapter_name", ctypes.c_char_p) + ] + +class RKLLMPromptCacheParam(ctypes.Structure): + _fields_ = [ + ("save_prompt_cache", ctypes.c_int), + ("prompt_cache_path", ctypes.c_char_p) + ] + +class RKLLMInferParam(ctypes.Structure): + _fields_ = [ + ("mode", RKLLMInferMode), + ("lora_params", ctypes.POINTER(RKLLMLoraParam)), + ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)) + ] + +class RKLLMResultLastHiddenLayer(ctypes.Structure): + _fields_ = [ + ("hidden_states", ctypes.POINTER(ctypes.c_float)), + ("embd_size", ctypes.c_int), + ("num_tokens", ctypes.c_int) + ] + +class RKLLMResult(ctypes.Structure): + _fields_ = [ + ("text", ctypes.c_char_p), + ("size", ctypes.c_int), + ("last_hidden_layer", RKLLMResultLastHiddenLayer) + ] + + +# Create a lock to control multi-user access to the server. +lock = threading.Lock() + +# Create a global variable to indicate whether the server is currently in a blocked state. +is_blocking = False + +# Define global variables to store the callback function output for displaying in the Gradio interface +global_text = [] +global_state = -1 +split_byte_data = bytes(b"") # Used to store the segmented byte data + +# Define the callback function +def callback_impl(result, userdata, state): + global global_text, global_state, split_byte_data + if state == LLMCallState.RKLLM_RUN_FINISH: + global_state = state + print("\n") + sys.stdout.flush() + elif state == LLMCallState.RKLLM_RUN_ERROR: + global_state = state + print("run error") + sys.stdout.flush() + elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER: + ''' + If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size. + With these three parameters, you can retrieve the data from last_hidden_layer. + Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback. 
+ ''' + if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0: + data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float) + print(f"data_size: {data_size}") + global_text.append(f"data_size: {data_size}\n") + output_path = os.getcwd() + "/last_hidden_layer.bin" + with open(output_path, "wb") as outFile: + data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float)) + float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float)) + float_array = float_array_type.from_address(ctypes.addressof(data.contents)) + outFile.write(bytearray(float_array)) + print(f"Data saved to {output_path} successfully!") + global_text.append(f"Data saved to {output_path} successfully!") + else: + print("Invalid hidden layer data.") + global_text.append("Invalid hidden layer data.") + global_state = state + time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result + sys.stdout.flush() + else: + # Save the output token text and the RKLLM running state + global_state = state + # Monitor if the current byte data is complete; if incomplete, record it for later parsing + try: + global_text.append((split_byte_data + result.contents.text).decode('utf-8')) + print((split_byte_data + result.contents.text).decode('utf-8'), end='') + split_byte_data = bytes(b"") + except: + split_byte_data += result.contents.text + sys.stdout.flush() + +# Connect the callback function between the Python side and the C++ side +callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int) +callback = callback_type(callback_impl) + +# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library +class RKLLM(object): + def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None): + rkllm_param = RKLLMParam() + rkllm_param.model_path = bytes(model_path, 'utf-8') + + rkllm_param.max_context_len = 512 + rkllm_param.max_new_tokens = -1 + rkllm_param.skip_special_token = True + + rkllm_param.top_k = 1 + rkllm_param.top_p = 0.9 + rkllm_param.temperature = 0.8 + rkllm_param.repeat_penalty = 1.1 + rkllm_param.frequency_penalty = 0.0 + rkllm_param.presence_penalty = 0.0 + + rkllm_param.mirostat = 0 + rkllm_param.mirostat_tau = 5.0 + rkllm_param.mirostat_eta = 0.1 + + rkllm_param.is_async = False + + rkllm_param.img_start = "".encode('utf-8') + rkllm_param.img_end = "".encode('utf-8') + rkllm_param.img_content = "".encode('utf-8') + + rkllm_param.extend_param.base_domain_id = 0 + + self.handle = RKLLM_Handle_t() + + self.rkllm_init = rkllm_lib.rkllm_init + self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type] + self.rkllm_init.restype = ctypes.c_int + self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback) + + self.rkllm_run = rkllm_lib.rkllm_run + self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p] + self.rkllm_run.restype = ctypes.c_int + + self.rkllm_destroy = rkllm_lib.rkllm_destroy + self.rkllm_destroy.argtypes = [RKLLM_Handle_t] + self.rkllm_destroy.restype = ctypes.c_int + + self.lora_adapter_path = None + self.lora_model_name = None + if lora_model_path: + self.lora_adapter_path = lora_model_path + self.lora_adapter_name = "test" + + lora_adapter = RKLLMLoraAdapter() + ctypes.memset(ctypes.byref(lora_adapter), 0, 
ctypes.sizeof(RKLLMLoraAdapter)) + lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8')) + lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8')) + lora_adapter.scale = 1.0 + + rkllm_load_lora = rkllm_lib.rkllm_load_lora + rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)] + rkllm_load_lora.restype = ctypes.c_int + rkllm_load_lora(self.handle, ctypes.byref(lora_adapter)) + + self.prompt_cache_path = None + if prompt_cache_path: + self.prompt_cache_path = prompt_cache_path + + rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache + rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p] + rkllm_load_prompt_cache.restype = ctypes.c_int + rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8'))) + + def run(self, prompt): + rkllm_lora_params = None + if self.lora_model_name: + rkllm_lora_params = RKLLMLoraParam() + rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8')) + + rkllm_infer_params = RKLLMInferParam() + ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam)) + rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE + rkllm_infer_params.lora_params = ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None + + rkllm_input = RKLLMInput() + rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT + rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8')) + self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None) + return + + def release(self): + self.rkllm_destroy(self.handle) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;') + parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;') + parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;') + parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;') + args = parser.parse_args() + + if not os.path.exists(args.rkllm_model_path): + print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.") + sys.stdout.flush() + exit() + + if not (args.target_platform in ["rk3588", "rk3576"]): + print("Error: Please specify the correct target platform: rk3588/rk3576.") + sys.stdout.flush() + exit() + + if args.lora_model_path: + if not os.path.exists(args.lora_model_path): + print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.") + sys.stdout.flush() + exit() + + if args.prompt_cache_path: + if not os.path.exists(args.prompt_cache_path): + print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.") + sys.stdout.flush() + exit() + + # Fix frequency + command = "sudo bash fix_freq_{}.sh".format(args.target_platform) + subprocess.run(command, shell=True) + + # Set resource limit + resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400)) + + # Initialize RKLLM model + print("=========init....===========") + sys.stdout.flush() + model_path = args.rkllm_model_path + rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path) + 
print("RKLLM Model has been initialized successfully!") + print("==============================") + sys.stdout.flush() + + # Create a function to receive data sent by the user using a request + @app.route('/rkllm_chat', methods=['POST']) + def receive_message(): + # Link global variables to retrieve the output information from the callback function + global global_text, global_state + global is_blocking + + # If the server is in a blocking state, return a specific response. + if is_blocking or global_state==0: + return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503 + + lock.acquire() + try: + # Set the server to a blocking state. + is_blocking = True + + # Get JSON data from the POST request. + data = request.json + if data and 'messages' in data: + # Reset global variables. + global_text = [] + global_state = -1 + + # Define the structure for the returned response. + rkllm_responses = { + "id": "rkllm_chat", + "object": "rkllm_chat", + "created": None, + "choices": [], + "usage": { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None + } + } + + if not "stream" in data.keys() or data["stream"] == False: + # Process the received data here. + messages = data['messages'] + print("Received messages:", messages) + for index, message in enumerate(messages): + input_prompt = message['content'] + rkllm_output = "" + + # Create a thread for model inference. + model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,)) + model_thread.start() + + # Wait for the model to finish running and periodically check the inference thread of the model. + model_thread_finished = False + while not model_thread_finished: + while len(global_text) > 0: + rkllm_output += global_text.pop(0) + time.sleep(0.005) + + model_thread.join(timeout=0.005) + model_thread_finished = not model_thread.is_alive() + + rkllm_responses["choices"].append( + {"index": index, + "message": { + "role": "assistant", + "content": rkllm_output, + }, + "logprobs": None, + "finish_reason": "stop" + } + ) + return jsonify(rkllm_responses), 200 + else: + messages = data['messages'] + print("Received messages:", messages) + for index, message in enumerate(messages): + input_prompt = message['content'] + rkllm_output = "" + + def generate(): + model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,)) + model_thread.start() + + model_thread_finished = False + while not model_thread_finished: + while len(global_text) > 0: + rkllm_output = global_text.pop(0) + + rkllm_responses["choices"].append( + {"index": index, + "delta": { + "role": "assistant", + "content": rkllm_output, + }, + "logprobs": None, + "finish_reason": "stop" if global_state == 1 else None, + } + ) + yield f"{json.dumps(rkllm_responses)}\n\n" + + model_thread.join(timeout=0.005) + model_thread_finished = not model_thread.is_alive() + + return Response(generate(), content_type='text/plain') + else: + return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400 + finally: + lock.release() + is_blocking = False + + # Start the Flask application. 
+ # app.run(host='0.0.0.0', port=8080) + app.run(host='127.0.0.1', port=8080, threaded=True, debug=False) + + print("====================") + print("RKLLM model inference completed, releasing RKLLM model resources...") + rkllm_model.release() + print("====================") diff --git a/rkllm_server/gradio_server.py b/rkllm_server/gradio_server.py new file mode 100644 index 0000000..0d7e700 --- /dev/null +++ b/rkllm_server/gradio_server.py @@ -0,0 +1,413 @@ +import ctypes +import sys +import os +import subprocess +import resource +import threading +import time +import gradio as gr +import argparse + + +PROMPT_TEXT_SYSTEM_COMMON = ( + "You are a helpful assistant who answers things succintly and in steps. " + "You listen to the user's question and answer accordingly, or point out factual " + "errors in the user's claims." +) + +# Re: phi3.5 +PROMPT_TEXT_PREFIX = ( + "<|system|> {content} <|end|>\n<|user|>" +).format(content = PROMPT_TEXT_SYSTEM_COMMON) +PROMPT_TEXT_POSTFIX = "<|end|>\n<|assistant|>" + +# Re: Qwen +# PROMPT_TEXT_PREFIX = ( +# "<|im_start|>system {content} <|im_end|>\n<|im_start|>user" +# ).format(content = PROMPT_TEXT_SYSTEM_COMMON) +# PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant" + +# Re: TinyLlama +# PROMPT_TEXT_PREFIX = ( +# "[INST] <>{content}<>\n" +# ).format(content = PROMPT_TEXT_SYSTEM_COMMON) +# PROMPT_TEXT_POSTFIX = " [/INST]\n" + +# Set environment variables +os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0" +os.environ["GRADIO_SERVER_PORT"] = "8080" + +# Set the dynamic library path +rkllm_lib = ctypes.CDLL('./lib/librkllmrt.so') + +# Define the structures from the library +RKLLM_Handle_t = ctypes.c_void_p +userdata = ctypes.c_void_p(None) + +LLMCallState = ctypes.c_int +LLMCallState.RKLLM_RUN_NORMAL = 0 +LLMCallState.RKLLM_RUN_WAITING = 1 +LLMCallState.RKLLM_RUN_FINISH = 2 +LLMCallState.RKLLM_RUN_ERROR = 3 +LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4 + +RKLLMInputMode = ctypes.c_int +RKLLMInputMode.RKLLM_INPUT_PROMPT = 0 +RKLLMInputMode.RKLLM_INPUT_TOKEN = 1 +RKLLMInputMode.RKLLM_INPUT_EMBED = 2 +RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3 + +RKLLMInferMode = ctypes.c_int +RKLLMInferMode.RKLLM_INFER_GENERATE = 0 +RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1 + +class RKLLMExtendParam(ctypes.Structure): + _fields_ = [ + ("base_domain_id", ctypes.c_int32), + ("reserved", ctypes.c_uint8 * 112) + ] + +class RKLLMParam(ctypes.Structure): + _fields_ = [ + ("model_path", ctypes.c_char_p), + ("max_context_len", ctypes.c_int32), + ("max_new_tokens", ctypes.c_int32), + ("top_k", ctypes.c_int32), + ("top_p", ctypes.c_float), + ("temperature", ctypes.c_float), + ("repeat_penalty", ctypes.c_float), + ("frequency_penalty", ctypes.c_float), + ("presence_penalty", ctypes.c_float), + ("mirostat", ctypes.c_int32), + ("mirostat_tau", ctypes.c_float), + ("mirostat_eta", ctypes.c_float), + ("skip_special_token", ctypes.c_bool), + ("is_async", ctypes.c_bool), + ("img_start", ctypes.c_char_p), + ("img_end", ctypes.c_char_p), + ("img_content", ctypes.c_char_p), + ("extend_param", RKLLMExtendParam), + ] + +class RKLLMLoraAdapter(ctypes.Structure): + _fields_ = [ + ("lora_adapter_path", ctypes.c_char_p), + ("lora_adapter_name", ctypes.c_char_p), + ("scale", ctypes.c_float) + ] + +class RKLLMEmbedInput(ctypes.Structure): + _fields_ = [ + ("embed", ctypes.POINTER(ctypes.c_float)), + ("n_tokens", ctypes.c_size_t) + ] + +class RKLLMTokenInput(ctypes.Structure): + _fields_ = [ + ("input_ids", ctypes.POINTER(ctypes.c_int32)), + ("n_tokens", ctypes.c_size_t) + ] + +class 
RKLLMMultiModelInput(ctypes.Structure): + _fields_ = [ + ("prompt", ctypes.c_char_p), + ("image_embed", ctypes.POINTER(ctypes.c_float)), + ("n_image_tokens", ctypes.c_size_t) + ] + +class RKLLMInputUnion(ctypes.Union): + _fields_ = [ + ("prompt_input", ctypes.c_char_p), + ("embed_input", RKLLMEmbedInput), + ("token_input", RKLLMTokenInput), + ("multimodal_input", RKLLMMultiModelInput) + ] + +class RKLLMInput(ctypes.Structure): + _fields_ = [ + ("input_mode", ctypes.c_int), + ("input_data", RKLLMInputUnion) + ] + +class RKLLMLoraParam(ctypes.Structure): + _fields_ = [ + ("lora_adapter_name", ctypes.c_char_p) + ] + +class RKLLMPromptCacheParam(ctypes.Structure): + _fields_ = [ + ("save_prompt_cache", ctypes.c_int), + ("prompt_cache_path", ctypes.c_char_p) + ] + +class RKLLMInferParam(ctypes.Structure): + _fields_ = [ + ("mode", RKLLMInferMode), + ("lora_params", ctypes.POINTER(RKLLMLoraParam)), + ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)) + ] + +class RKLLMResultLastHiddenLayer(ctypes.Structure): + _fields_ = [ + ("hidden_states", ctypes.POINTER(ctypes.c_float)), + ("embd_size", ctypes.c_int), + ("num_tokens", ctypes.c_int) + ] + +class RKLLMResult(ctypes.Structure): + _fields_ = [ + ("text", ctypes.c_char_p), + ("size", ctypes.c_int), + ("last_hidden_layer", RKLLMResultLastHiddenLayer) + ] + +# Define global variables to store the callback function output for displaying in the Gradio interface +global_text = [] +global_state = -1 +split_byte_data = bytes(b"") # Used to store the segmented byte data + +# Define the callback function +def callback_impl(result, userdata, state): + global global_text, global_state, split_byte_data + if state == LLMCallState.RKLLM_RUN_FINISH: + global_state = state + print("\n") + sys.stdout.flush() + elif state == LLMCallState.RKLLM_RUN_ERROR: + global_state = state + print("run error") + sys.stdout.flush() + elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER: + ''' + If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size. + With these three parameters, you can retrieve the data from last_hidden_layer. + Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback. 
+ ''' + if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0: + data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float) + print(f"data_size: {data_size}") + global_text.append(f"data_size: {data_size}\n") + output_path = os.getcwd() + "/last_hidden_layer.bin" + with open(output_path, "wb") as outFile: + data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float)) + float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float)) + float_array = float_array_type.from_address(ctypes.addressof(data.contents)) + outFile.write(bytearray(float_array)) + print(f"Data saved to {output_path} successfully!") + global_text.append(f"Data saved to {output_path} successfully!") + else: + print("Invalid hidden layer data.") + global_text.append("Invalid hidden layer data.") + global_state = state + time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result + sys.stdout.flush() + else: + # Save the output token text and the RKLLM running state + global_state = state + # Monitor if the current byte data is complete; if incomplete, record it for later parsing + try: + global_text.append((split_byte_data + result.contents.text).decode('utf-8')) + print((split_byte_data + result.contents.text).decode('utf-8'), end='') + split_byte_data = bytes(b"") + except: + split_byte_data += result.contents.text + sys.stdout.flush() + +# Connect the callback function between the Python side and the C++ side +callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int) +callback = callback_type(callback_impl) + +# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library +class RKLLM(object): + def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None): + rkllm_param = RKLLMParam() + rkllm_param.model_path = bytes(model_path, 'utf-8') + + rkllm_param.max_context_len = 4096 + rkllm_param.max_new_tokens = 256 + rkllm_param.skip_special_token = True + + rkllm_param.top_k = 40 + rkllm_param.top_p = 0.9 + rkllm_param.temperature = 0.5 + rkllm_param.repeat_penalty = 1.1 + rkllm_param.frequency_penalty = 0.0 + rkllm_param.presence_penalty = 0.0 + + rkllm_param.mirostat = 0 + rkllm_param.mirostat_tau = 5.0 + rkllm_param.mirostat_eta = 0.1 + + rkllm_param.is_async = True + + rkllm_param.img_start = "".encode('utf-8') + rkllm_param.img_end = "".encode('utf-8') + rkllm_param.img_content = "".encode('utf-8') + + rkllm_param.extend_param.base_domain_id = 0 + + self.handle = RKLLM_Handle_t() + + self.rkllm_init = rkllm_lib.rkllm_init + self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type] + self.rkllm_init.restype = ctypes.c_int + self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback) + + self.rkllm_run = rkllm_lib.rkllm_run + self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p] + self.rkllm_run.restype = ctypes.c_int + + self.rkllm_destroy = rkllm_lib.rkllm_destroy + self.rkllm_destroy.argtypes = [RKLLM_Handle_t] + self.rkllm_destroy.restype = ctypes.c_int + + self.lora_adapter_path = None + self.lora_model_name = None + if lora_model_path: + self.lora_adapter_path = lora_model_path + self.lora_adapter_name = "test" + + lora_adapter = RKLLMLoraAdapter() + ctypes.memset(ctypes.byref(lora_adapter), 0, 
ctypes.sizeof(RKLLMLoraAdapter)) + lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8')) + lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8')) + lora_adapter.scale = 1.0 + + rkllm_load_lora = rkllm_lib.rkllm_load_lora + rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)] + rkllm_load_lora.restype = ctypes.c_int + rkllm_load_lora(self.handle, ctypes.byref(lora_adapter)) + + self.prompt_cache_path = None + if prompt_cache_path: + self.prompt_cache_path = prompt_cache_path + + rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache + rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p] + rkllm_load_prompt_cache.restype = ctypes.c_int + rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8'))) + + def run(self, prompt): + rkllm_lora_params = None + if self.lora_model_name: + rkllm_lora_params = RKLLMLoraParam() + rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8')) + + rkllm_infer_params = RKLLMInferParam() + ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam)) + rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE + rkllm_infer_params.lora_params = ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None + + rkllm_input = RKLLMInput() + rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT + rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8')) + self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None) + return + + def release(self): + self.rkllm_destroy(self.handle) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;') + parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;') + parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;') + parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;') + args = parser.parse_args() + + if not os.path.exists(args.rkllm_model_path): + print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.") + sys.stdout.flush() + exit() + + if not (args.target_platform in ["rk3588", "rk3576"]): + print("Error: Please specify the correct target platform: rk3588/rk3576.") + sys.stdout.flush() + exit() + + if args.lora_model_path: + if not os.path.exists(args.lora_model_path): + print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.") + sys.stdout.flush() + exit() + + if args.prompt_cache_path: + if not os.path.exists(args.prompt_cache_path): + print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.") + sys.stdout.flush() + exit() + + # Fix frequency + # command = "sudo bash fix_freq_{}.sh".format(args.target_platform) + # subprocess.run(command, shell=True) + + # Set resource limit + resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400)) + + # Initialize RKLLM model + print("=========init....===========") + sys.stdout.flush() + model_path = args.rkllm_model_path + rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path) + 
print("RKLLM Model has been initialized successfully!") + print("==============================") + sys.stdout.flush() + + # Record the user's input prompt + def get_user_input(user_message, history): + history = history + [[user_message, None]] + return "", history + + # Retrieve the output from the RKLLM model and print it in a streaming manner + def get_RKLLM_output(history): + # Link global variables to retrieve the output information from the callback function + global global_text, global_state + global_text = [] + global_state = -1 + + # Create a thread for model inference + model_thread = threading.Thread(target=rkllm_model.run, args=(history[-1][0],)) + model_thread.start() + + # history[-1][1] represents the current dialogue + history[-1][1] = "" + + # Wait for the model to finish running and periodically check the inference thread of the model + model_thread_finished = False + while not model_thread_finished: + while len(global_text) > 0: + history[-1][1] += global_text.pop(0) + time.sleep(0.005) + # Gradio automatically pushes the result returned by the yield statement when calling the then method + yield history + + model_thread.join(timeout=0.005) + model_thread_finished = not model_thread.is_alive() + + # Create a Gradio interface + with gr.Blocks(title="Chat with RKLLM") as chatRKLLM: + gr.Markdown("
Chat with RKLLM
") + gr.Markdown("### Enter your question in the inputTextBox and press the Enter key to chat with the RKLLM model.") + # Create a Chatbot component to display conversation history + rkllmServer = gr.Chatbot(height=600) + # Create a Textbox component for user message input + msg = gr.Textbox(placeholder="Please input your question here...", label="inputTextBox") + # Create a Button component to clear the chat history. + clear = gr.Button("Clear") + + # Submit the user's input message to the get_user_input function and immediately update the chat history. + # Then call the get_RKLLM_output function to further update the chat history. + # The queue=False parameter ensures that these updates are not queued, but executed immediately. + msg.submit(get_user_input, [msg, rkllmServer], [msg, rkllmServer], queue=False).then(get_RKLLM_output, rkllmServer, rkllmServer) + # When the clear button is clicked, perform a no-operation (lambda: None) and immediately clear the chat history. + clear.click(lambda: None, None, rkllmServer, queue=False) + + # Enable the event queue system. + chatRKLLM.queue() + # Start the Gradio application. + chatRKLLM.launch() + + print("====================") + print("RKLLM model inference completed, releasing RKLLM model resources...") + rkllm_model.release() + print("====================") diff --git a/rkllm_server/lib/.gitkeep b/rkllm_server/lib/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/rkllm_server/lib/librkllmrt.so b/rkllm_server/lib/librkllmrt.so new file mode 100644 index 0000000..3bb8cb6 Binary files /dev/null and b/rkllm_server/lib/librkllmrt.so differ