diff --git a/docker-compose.yml b/docker-compose.yml
index efff00d..49af199 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -7,8 +7,9 @@ services:
devices:
- /dev:/dev
volumes:
- - ./model:/rkllm_server/model:ro
+ - ./rkllm_server/:/rkllm_server/ # bring-your-own server
+ - ./model/:/rkllm_server/model/:ro
ports:
- "8080:8080"
command: >
- sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Qwen2.5-3B.rkllm"
+ sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Phi-3.5-mini-instruct-w8a8.rkllm"
diff --git a/rkllm_server/flask_server.py b/rkllm_server/flask_server.py
new file mode 100644
index 0000000..133e24a
--- /dev/null
+++ b/rkllm_server/flask_server.py
@@ -0,0 +1,454 @@
+import ctypes
+import sys
+import os
+import subprocess
+import resource
+import threading
+import time
+import argparse
+import json
+from flask import Flask, request, jsonify, Response
+
+app = Flask(__name__)
+
+PROMPT_TEXT_PREFIX = "<|im_start|>system You are a helpful assistant. <|im_end|> <|im_start|>user"
+PROMPT_TEXT_POSTFIX = "<|im_end|><|im_start|>assistant"
+
+# Set the dynamic library path
+rkllm_lib = ctypes.CDLL('/lib/librkllmrt.so')
+
+# Define the structures from the library
+RKLLM_Handle_t = ctypes.c_void_p
+userdata = ctypes.c_void_p(None)
+
+LLMCallState = ctypes.c_int
+LLMCallState.RKLLM_RUN_NORMAL = 0
+LLMCallState.RKLLM_RUN_WAITING = 1
+LLMCallState.RKLLM_RUN_FINISH = 2
+LLMCallState.RKLLM_RUN_ERROR = 3
+LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4
+
+RKLLMInputMode = ctypes.c_int
+RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
+RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
+RKLLMInputMode.RKLLM_INPUT_EMBED = 2
+RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3
+
+RKLLMInferMode = ctypes.c_int
+RKLLMInferMode.RKLLM_INFER_GENERATE = 0
+RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
+
+class RKLLMExtendParam(ctypes.Structure):
+ _fields_ = [
+ ("base_domain_id", ctypes.c_int32),
+ ("reserved", ctypes.c_uint8 * 112)
+ ]
+
+class RKLLMParam(ctypes.Structure):
+ _fields_ = [
+ ("model_path", ctypes.c_char_p),
+ ("max_context_len", ctypes.c_int32),
+ ("max_new_tokens", ctypes.c_int32),
+ ("top_k", ctypes.c_int32),
+ ("top_p", ctypes.c_float),
+ ("temperature", ctypes.c_float),
+ ("repeat_penalty", ctypes.c_float),
+ ("frequency_penalty", ctypes.c_float),
+ ("presence_penalty", ctypes.c_float),
+ ("mirostat", ctypes.c_int32),
+ ("mirostat_tau", ctypes.c_float),
+ ("mirostat_eta", ctypes.c_float),
+ ("skip_special_token", ctypes.c_bool),
+ ("is_async", ctypes.c_bool),
+ ("img_start", ctypes.c_char_p),
+ ("img_end", ctypes.c_char_p),
+ ("img_content", ctypes.c_char_p),
+ ("extend_param", RKLLMExtendParam),
+ ]
+
+class RKLLMLoraAdapter(ctypes.Structure):
+ _fields_ = [
+ ("lora_adapter_path", ctypes.c_char_p),
+ ("lora_adapter_name", ctypes.c_char_p),
+ ("scale", ctypes.c_float)
+ ]
+
+class RKLLMEmbedInput(ctypes.Structure):
+ _fields_ = [
+ ("embed", ctypes.POINTER(ctypes.c_float)),
+ ("n_tokens", ctypes.c_size_t)
+ ]
+
+class RKLLMTokenInput(ctypes.Structure):
+ _fields_ = [
+ ("input_ids", ctypes.POINTER(ctypes.c_int32)),
+ ("n_tokens", ctypes.c_size_t)
+ ]
+
+class RKLLMMultiModelInput(ctypes.Structure):
+ _fields_ = [
+ ("prompt", ctypes.c_char_p),
+ ("image_embed", ctypes.POINTER(ctypes.c_float)),
+ ("n_image_tokens", ctypes.c_size_t)
+ ]
+
+class RKLLMInputUnion(ctypes.Union):
+ _fields_ = [
+ ("prompt_input", ctypes.c_char_p),
+ ("embed_input", RKLLMEmbedInput),
+ ("token_input", RKLLMTokenInput),
+ ("multimodal_input", RKLLMMultiModelInput)
+ ]
+
+class RKLLMInput(ctypes.Structure):
+ _fields_ = [
+ ("input_mode", ctypes.c_int),
+ ("input_data", RKLLMInputUnion)
+ ]
+
+class RKLLMLoraParam(ctypes.Structure):
+ _fields_ = [
+ ("lora_adapter_name", ctypes.c_char_p)
+ ]
+
+class RKLLMPromptCacheParam(ctypes.Structure):
+ _fields_ = [
+ ("save_prompt_cache", ctypes.c_int),
+ ("prompt_cache_path", ctypes.c_char_p)
+ ]
+
+class RKLLMInferParam(ctypes.Structure):
+ _fields_ = [
+ ("mode", RKLLMInferMode),
+ ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
+ ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam))
+ ]
+
+class RKLLMResultLastHiddenLayer(ctypes.Structure):
+ _fields_ = [
+ ("hidden_states", ctypes.POINTER(ctypes.c_float)),
+ ("embd_size", ctypes.c_int),
+ ("num_tokens", ctypes.c_int)
+ ]
+
+class RKLLMResult(ctypes.Structure):
+ _fields_ = [
+ ("text", ctypes.c_char_p),
+ ("size", ctypes.c_int),
+ ("last_hidden_layer", RKLLMResultLastHiddenLayer)
+ ]
+
+
+# Create a lock to control multi-user access to the server.
+lock = threading.Lock()
+
+# Create a global variable to indicate whether the server is currently in a blocked state.
+is_blocking = False
+
+# Define global variables to store the callback function output for displaying in the Gradio interface
+global_text = []
+global_state = -1
+split_byte_data = bytes(b"") # Used to store the segmented byte data
+
+# Define the callback function
+def callback_impl(result, userdata, state):
+ global global_text, global_state, split_byte_data
+ if state == LLMCallState.RKLLM_RUN_FINISH:
+ global_state = state
+ print("\n")
+ sys.stdout.flush()
+ elif state == LLMCallState.RKLLM_RUN_ERROR:
+ global_state = state
+ print("run error")
+ sys.stdout.flush()
+ elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
+ '''
+ If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
+ With these three parameters, you can retrieve the data from last_hidden_layer.
+ Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
+ '''
+ if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
+ data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
+ print(f"data_size: {data_size}")
+ global_text.append(f"data_size: {data_size}\n")
+ output_path = os.getcwd() + "/last_hidden_layer.bin"
+ with open(output_path, "wb") as outFile:
+ data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
+ float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
+ float_array = float_array_type.from_address(ctypes.addressof(data.contents))
+ outFile.write(bytearray(float_array))
+ print(f"Data saved to {output_path} successfully!")
+ global_text.append(f"Data saved to {output_path} successfully!")
+ else:
+ print("Invalid hidden layer data.")
+ global_text.append("Invalid hidden layer data.")
+ global_state = state
+ time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result
+ sys.stdout.flush()
+ else:
+ # Save the output token text and the RKLLM running state
+ global_state = state
+ # Monitor if the current byte data is complete; if incomplete, record it for later parsing
+ try:
+ global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
+ print((split_byte_data + result.contents.text).decode('utf-8'), end='')
+ split_byte_data = bytes(b"")
+ except:
+ split_byte_data += result.contents.text
+ sys.stdout.flush()
+
+# Connect the callback function between the Python side and the C++ side
+callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
+callback = callback_type(callback_impl)
+
+# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
+class RKLLM(object):
+ def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None):
+ rkllm_param = RKLLMParam()
+ rkllm_param.model_path = bytes(model_path, 'utf-8')
+
+ rkllm_param.max_context_len = 512
+ rkllm_param.max_new_tokens = -1
+ rkllm_param.skip_special_token = True
+
+ rkllm_param.top_k = 1
+ rkllm_param.top_p = 0.9
+ rkllm_param.temperature = 0.8
+ rkllm_param.repeat_penalty = 1.1
+ rkllm_param.frequency_penalty = 0.0
+ rkllm_param.presence_penalty = 0.0
+
+ rkllm_param.mirostat = 0
+ rkllm_param.mirostat_tau = 5.0
+ rkllm_param.mirostat_eta = 0.1
+
+ rkllm_param.is_async = False
+
+ rkllm_param.img_start = "".encode('utf-8')
+ rkllm_param.img_end = "".encode('utf-8')
+ rkllm_param.img_content = "".encode('utf-8')
+
+ rkllm_param.extend_param.base_domain_id = 0
+
+ self.handle = RKLLM_Handle_t()
+
+ self.rkllm_init = rkllm_lib.rkllm_init
+ self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
+ self.rkllm_init.restype = ctypes.c_int
+ self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
+
+ self.rkllm_run = rkllm_lib.rkllm_run
+ self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
+ self.rkllm_run.restype = ctypes.c_int
+
+ self.rkllm_destroy = rkllm_lib.rkllm_destroy
+ self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
+ self.rkllm_destroy.restype = ctypes.c_int
+
+ self.lora_adapter_path = None
+ self.lora_model_name = None
+ if lora_model_path:
+ self.lora_adapter_path = lora_model_path
+ self.lora_adapter_name = "test"
+
+ lora_adapter = RKLLMLoraAdapter()
+ ctypes.memset(ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter))
+ lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8'))
+ lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8'))
+ lora_adapter.scale = 1.0
+
+ rkllm_load_lora = rkllm_lib.rkllm_load_lora
+ rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)]
+ rkllm_load_lora.restype = ctypes.c_int
+ rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))
+
+ self.prompt_cache_path = None
+ if prompt_cache_path:
+ self.prompt_cache_path = prompt_cache_path
+
+ rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
+ rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
+ rkllm_load_prompt_cache.restype = ctypes.c_int
+ rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8')))
+
+ def run(self, prompt):
+ rkllm_lora_params = None
+ if self.lora_model_name:
+ rkllm_lora_params = RKLLMLoraParam()
+ rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8'))
+
+ rkllm_infer_params = RKLLMInferParam()
+ ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
+ rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
+ rkllm_infer_params.lora_params = ctypes.byref(rkllm_lora_params) if rkllm_lora_params else None
+
+ rkllm_input = RKLLMInput()
+ rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
+ rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8'))
+ self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None)
+ return
+
+ def release(self):
+ self.rkllm_destroy(self.handle)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
+ parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
+ parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
+ parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
+ args = parser.parse_args()
+
+ if not os.path.exists(args.rkllm_model_path):
+ print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
+ sys.stdout.flush()
+ exit()
+
+ if not (args.target_platform in ["rk3588", "rk3576"]):
+ print("Error: Please specify the correct target platform: rk3588/rk3576.")
+ sys.stdout.flush()
+ exit()
+
+ if args.lora_model_path:
+ if not os.path.exists(args.lora_model_path):
+ print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.")
+ sys.stdout.flush()
+ exit()
+
+ if args.prompt_cache_path:
+ if not os.path.exists(args.prompt_cache_path):
+ print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.")
+ sys.stdout.flush()
+ exit()
+
+ # Fix frequency
+ command = "sudo bash fix_freq_{}.sh".format(args.target_platform)
+ subprocess.run(command, shell=True)
+
+ # Set resource limit
+ resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))
+
+ # Initialize RKLLM model
+ print("=========init....===========")
+ sys.stdout.flush()
+ model_path = args.rkllm_model_path
+ rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path)
+ print("RKLLM Model has been initialized successfully!")
+ print("==============================")
+ sys.stdout.flush()
+
+ # Create a function to receive data sent by the user using a request
+ @app.route('/rkllm_chat', methods=['POST'])
+ def receive_message():
+ # Link global variables to retrieve the output information from the callback function
+ global global_text, global_state
+ global is_blocking
+
+ # If the server is in a blocking state, return a specific response.
+ if is_blocking or global_state==0:
+ return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503
+
+ lock.acquire()
+ try:
+ # Set the server to a blocking state.
+ is_blocking = True
+
+ # Get JSON data from the POST request.
+ data = request.json
+ if data and 'messages' in data:
+ # Reset global variables.
+ global_text = []
+ global_state = -1
+
+ # Define the structure for the returned response.
+ rkllm_responses = {
+ "id": "rkllm_chat",
+ "object": "rkllm_chat",
+ "created": None,
+ "choices": [],
+ "usage": {
+ "prompt_tokens": None,
+ "completion_tokens": None,
+ "total_tokens": None
+ }
+ }
+
+ if not "stream" in data.keys() or data["stream"] == False:
+ # Process the received data here.
+ messages = data['messages']
+ print("Received messages:", messages)
+ for index, message in enumerate(messages):
+ input_prompt = message['content']
+ rkllm_output = ""
+
+ # Create a thread for model inference.
+ model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
+ model_thread.start()
+
+ # Wait for the model to finish running and periodically check the inference thread of the model.
+ model_thread_finished = False
+ while not model_thread_finished:
+ while len(global_text) > 0:
+ rkllm_output += global_text.pop(0)
+ time.sleep(0.005)
+
+ model_thread.join(timeout=0.005)
+ model_thread_finished = not model_thread.is_alive()
+
+ rkllm_responses["choices"].append(
+ {"index": index,
+ "message": {
+ "role": "assistant",
+ "content": rkllm_output,
+ },
+ "logprobs": None,
+ "finish_reason": "stop"
+ }
+ )
+ return jsonify(rkllm_responses), 200
+ else:
+ messages = data['messages']
+ print("Received messages:", messages)
+ for index, message in enumerate(messages):
+ input_prompt = message['content']
+ rkllm_output = ""
+
+ def generate():
+ model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
+ model_thread.start()
+
+ model_thread_finished = False
+ while not model_thread_finished:
+ while len(global_text) > 0:
+ rkllm_output = global_text.pop(0)
+
+ rkllm_responses["choices"].append(
+ {"index": index,
+ "delta": {
+ "role": "assistant",
+ "content": rkllm_output,
+ },
+ "logprobs": None,
+ "finish_reason": "stop" if global_state == 1 else None,
+ }
+ )
+ yield f"{json.dumps(rkllm_responses)}\n\n"
+
+ model_thread.join(timeout=0.005)
+ model_thread_finished = not model_thread.is_alive()
+
+ return Response(generate(), content_type='text/plain')
+ else:
+ return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
+ finally:
+ lock.release()
+ is_blocking = False
+
+ # Start the Flask application.
+ # app.run(host='0.0.0.0', port=8080)
+ app.run(host='127.0.0.1', port=8080, threaded=True, debug=False)
+
+ print("====================")
+ print("RKLLM model inference completed, releasing RKLLM model resources...")
+ rkllm_model.release()
+ print("====================")
diff --git a/rkllm_server/gradio_server.py b/rkllm_server/gradio_server.py
new file mode 100644
index 0000000..0d7e700
--- /dev/null
+++ b/rkllm_server/gradio_server.py
@@ -0,0 +1,413 @@
+import ctypes
+import sys
+import os
+import subprocess
+import resource
+import threading
+import time
+import gradio as gr
+import argparse
+
+
+PROMPT_TEXT_SYSTEM_COMMON = (
+ "You are a helpful assistant who answers things succintly and in steps. "
+ "You listen to the user's question and answer accordingly, or point out factual "
+ "errors in the user's claims."
+)
+
+# Re: phi3.5
+PROMPT_TEXT_PREFIX = (
+ "<|system|> {content} <|end|>\n<|user|>"
+).format(content = PROMPT_TEXT_SYSTEM_COMMON)
+PROMPT_TEXT_POSTFIX = "<|end|>\n<|assistant|>"
+
+# Re: Qwen
+# PROMPT_TEXT_PREFIX = (
+# "<|im_start|>system {content} <|im_end|>\n<|im_start|>user"
+# ).format(content = PROMPT_TEXT_SYSTEM_COMMON)
+# PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant"
+
+# Re: TinyLlama
+# PROMPT_TEXT_PREFIX = (
+# "[INST] <