Bring your own rkllm_server for use on updated kernels.
parent c86afd3590
commit 29d061c89d
5 changed files with 870 additions and 2 deletions
@@ -7,8 +7,9 @@ services:
     devices:
       - /dev:/dev
     volumes:
-      - ./model:/rkllm_server/model:ro
+      - ./rkllm_server/:/rkllm_server/  # bring-your-own server
+      - ./model/:/rkllm_server/model/:ro
     ports:
       - "8080:8080"
     command: >
-      sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Qwen2.5-3B.rkllm"
+      sh -c "python3 gradio_server.py --target_platform rk3588 --rkllm_model_path /rkllm_server/model/Phi-3.5-mini-instruct-w8a8.rkllm"
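Note: the new bind mount means the server sources now live on the host. Per the files added in this commit, ./rkllm_server/ next to the compose file should contain at least:

    rkllm_server/flask_server.py
    rkllm_server/gradio_server.py
    rkllm_server/lib/librkllmrt.so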
454 rkllm_server/flask_server.py Normal file
@@ -0,0 +1,454 @@
import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response

app = Flask(__name__)

PROMPT_TEXT_PREFIX = "<|im_start|>system You are a helpful assistant. <|im_end|> <|im_start|>user"
PROMPT_TEXT_POSTFIX = "<|im_end|><|im_start|>assistant"
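# For illustration: with these constants, a user prompt "Hi" is wrapped as the
# ChatML-style string (matching Qwen-family chat templates):
#   <|im_start|>system You are a helpful assistant. <|im_end|> <|im_start|>userHi<|im_end|><|im_start|>assistant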
# Set the dynamic library path
rkllm_lib = ctypes.CDLL('/lib/librkllmrt.so')

# Define the structures from the library
RKLLM_Handle_t = ctypes.c_void_p
userdata = ctypes.c_void_p(None)

LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4

RKLLMInputMode = ctypes.c_int
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3

RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1

class RKLLMExtendParam(ctypes.Structure):
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("reserved", ctypes.c_uint8 * 112)
    ]

class RKLLMParam(ctypes.Structure):
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t)
    ]

class RKLLMInputUnion(ctypes.Union):
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    _fields_ = [
        ("input_mode", ctypes.c_int),
        ("input_data", RKLLMInputUnion)
    ]

class RKLLMLoraParam(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam))
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("size", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer)
    ]

# Create a lock to control multi-user access to the server.
lock = threading.Lock()

# Create a global variable to indicate whether the server is currently in a blocked state.
is_blocking = False

# Define global variables to store the callback function output for displaying in the Gradio interface
global_text = []
global_state = -1
split_byte_data = bytes(b"")  # Used to store the segmented byte data

# Define the callback function
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer,
        the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
        With these three parameters, you can retrieve the data from last_hidden_layer.
        Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
        '''
        # `result` is a POINTER(RKLLMResult), so fields must be read via .contents
        # (a plain `result.last_hidden_layer` would raise AttributeError).
        last_hidden_layer = result.contents.last_hidden_layer
        if last_hidden_layer.embd_size != 0 and last_hidden_layer.num_tokens != 0:
            data_size = last_hidden_layer.embd_size * last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                data = ctypes.cast(last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
            print(f"Data saved to {output_path} successfully!")
            global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)  # Delay for 0.05 seconds to wait for the output result
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state
        global_state = state
        # Monitor if the current byte data is complete; if incomplete, record it for later parsing
        try:
            global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
            print((split_byte_data + result.contents.text).decode('utf-8'), end='')
            split_byte_data = bytes(b"")
        except UnicodeDecodeError:
            # A multi-byte UTF-8 character was split across callbacks; buffer it
            split_byte_data += result.contents.text
        sys.stdout.flush()
# Connect the callback function between the Python side and the C++ side
callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(callback_impl)
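# Note: `callback` must stay referenced at module level for the lifetime of the
# handle; if the CFUNCTYPE wrapper is garbage-collected while the native library
# still holds the function pointer, the process can crash.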
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
class RKLLM(object):
    def __init__(self, model_path, lora_model_path=None, prompt_cache_path=None):
        rkllm_param = RKLLMParam()
        rkllm_param.model_path = bytes(model_path, 'utf-8')

        rkllm_param.max_context_len = 512
        rkllm_param.max_new_tokens = -1
        rkllm_param.skip_special_token = True

        rkllm_param.top_k = 1
        rkllm_param.top_p = 0.9
        rkllm_param.temperature = 0.8
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0

        rkllm_param.mirostat = 0
        rkllm_param.mirostat_tau = 5.0
        rkllm_param.mirostat_eta = 0.1

        rkllm_param.is_async = False

        rkllm_param.img_start = "".encode('utf-8')
        rkllm_param.img_end = "".encode('utf-8')
        rkllm_param.img_content = "".encode('utf-8')

        rkllm_param.extend_param.base_domain_id = 0

        self.handle = RKLLM_Handle_t()

        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int
        self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)

        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        self.lora_adapter_path = None
        self.lora_model_name = None
        if lora_model_path:
            self.lora_adapter_path = lora_model_path
            self.lora_adapter_name = "test"
            # run() checks self.lora_model_name, so keep it in sync with the adapter name
            self.lora_model_name = self.lora_adapter_name
            lora_adapter = RKLLMLoraAdapter()
            ctypes.memset(ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter))
            lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8'))
            lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8'))
            lora_adapter.scale = 1.0

            rkllm_load_lora = rkllm_lib.rkllm_load_lora
            rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)]
            rkllm_load_lora.restype = ctypes.c_int
            rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))

        self.prompt_cache_path = None
        if prompt_cache_path:
            self.prompt_cache_path = prompt_cache_path

            rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
            rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
            rkllm_load_prompt_cache.restype = ctypes.c_int
            rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8')))

    def run(self, prompt):
        rkllm_lora_params = None
        if self.lora_model_name:
            rkllm_lora_params = RKLLMLoraParam()
            rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8'))

        rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        # ctypes.byref() objects can only be passed as call arguments; assigning
        # to a POINTER structure field requires ctypes.pointer() instead
        rkllm_infer_params.lora_params = ctypes.pointer(rkllm_lora_params) if rkllm_lora_params else None
        rkllm_input = RKLLMInput()
        rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
        rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8'))
        self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None)
        return

    def release(self):
        self.rkllm_destroy(self.handle)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    args = parser.parse_args()

    if not os.path.exists(args.rkllm_model_path):
        print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
        sys.stdout.flush()
        exit()
    if args.target_platform not in ["rk3588", "rk3576"]:
        print("Error: Please specify the correct target platform: rk3588/rk3576.")
        sys.stdout.flush()
        exit()

    if args.lora_model_path:
        if not os.path.exists(args.lora_model_path):
            print("Error: Please provide the correct lora_model path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    if args.prompt_cache_path:
        if not os.path.exists(args.prompt_cache_path):
            print("Error: Please provide the correct prompt_cache_file path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()
    # Fix frequency
    command = "sudo bash fix_freq_{}.sh".format(args.target_platform)
    subprocess.run(command, shell=True)

    # Set resource limit
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize RKLLM model
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path)
    print("RKLLM Model has been initialized successfully!")
    print("==============================")
    sys.stdout.flush()

    # Create a function to receive data sent by the user using a request
    @app.route('/rkllm_chat', methods=['POST'])
    def receive_message():
        # Link global variables to retrieve the output information from the callback function
        global global_text, global_state
        global is_blocking

        # If the server is in a blocking state, return a specific response.
        if is_blocking or global_state == LLMCallState.RKLLM_RUN_NORMAL:
            return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503

        lock.acquire()
        try:
            # Set the server to a blocking state.
            is_blocking = True

            # Get JSON data from the POST request.
            data = request.json
            if data and 'messages' in data:
                # Reset global variables.
                global_text = []
                global_state = -1

                # Define the structure for the returned response.
                rkllm_responses = {
                    "id": "rkllm_chat",
                    "object": "rkllm_chat",
                    "created": None,
                    "choices": [],
                    "usage": {
                        "prompt_tokens": None,
                        "completion_tokens": None,
                        "total_tokens": None
                    }
                }
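                # The response dict above mirrors the OpenAI chat-completions
                # shape (choices / message / usage), so OpenAI-style clients can
                # parse it with little or no change.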
if not "stream" in data.keys() or data["stream"] == False:
|
||||||
                    # Process the received data here.
                    messages = data['messages']
                    print("Received messages:", messages)
                    for index, message in enumerate(messages):
                        input_prompt = message['content']
                        rkllm_output = ""

                        # Create a thread for model inference.
                        model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
                        model_thread.start()

                        # Wait for the model to finish running and periodically check the inference thread of the model.
                        model_thread_finished = False
                        while not model_thread_finished:
                            while len(global_text) > 0:
                                rkllm_output += global_text.pop(0)
                                time.sleep(0.005)

                            model_thread.join(timeout=0.005)
                            model_thread_finished = not model_thread.is_alive()

                        rkllm_responses["choices"].append(
                            {
                                "index": index,
                                "message": {
                                    "role": "assistant",
                                    "content": rkllm_output,
                                },
                                "logprobs": None,
                                "finish_reason": "stop"
                            }
                        )
                    return jsonify(rkllm_responses), 200
                else:
                    messages = data['messages']
                    print("Received messages:", messages)
                    for index, message in enumerate(messages):
                        input_prompt = message['content']
                        rkllm_output = ""

                        def generate():
                            model_thread = threading.Thread(target=rkllm_model.run, args=(input_prompt,))
                            model_thread.start()

                            model_thread_finished = False
                            while not model_thread_finished:
                                while len(global_text) > 0:
                                    rkllm_output = global_text.pop(0)

                                    rkllm_responses["choices"].append(
                                        {
                                            "index": index,
                                            "delta": {
                                                "role": "assistant",
                                                "content": rkllm_output,
                                            },
                                            "logprobs": None,
"finish_reason": "stop" if global_state == 1 else None,
|
||||||
                                        }
                                    )
                                    yield f"{json.dumps(rkllm_responses)}\n\n"

                                model_thread.join(timeout=0.005)
                                model_thread_finished = not model_thread.is_alive()

                    return Response(generate(), content_type='text/plain')
            else:
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
        finally:
            lock.release()
            is_blocking = False

    # Start the Flask application.
    # app.run(host='0.0.0.0', port=8080)
    app.run(host='127.0.0.1', port=8080, threaded=True, debug=False)

    print("====================")
    print("RKLLM model inference completed, releasing RKLLM model resources...")
    rkllm_model.release()
    print("====================")
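For reference, a minimal client sketch against this endpoint (a hypothetical smoke test, assuming the server is reachable at the 127.0.0.1:8080 address used in app.run above; the payload shape follows the 'messages' handling in receive_message):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/rkllm_chat",
    json={"messages": [{"role": "user", "content": "Hello!"}]},
)
print(resp.status_code)
print(resp.json()["choices"][0]["message"]["content"])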
413 rkllm_server/gradio_server.py Normal file
@@ -0,0 +1,413 @@
import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import gradio as gr
import argparse
PROMPT_TEXT_SYSTEM_COMMON = (
    "You are a helpful assistant who answers things succinctly and in steps. "
    "You listen to the user's question and answer accordingly, or point out factual "
    "errors in the user's claims."
)
# Re: phi3.5
PROMPT_TEXT_PREFIX = (
    "<|system|> {content} <|end|>\n<|user|>"
).format(content=PROMPT_TEXT_SYSTEM_COMMON)
PROMPT_TEXT_POSTFIX = "<|end|>\n<|assistant|>"

# Re: Qwen
# PROMPT_TEXT_PREFIX = (
#     "<|im_start|>system {content} <|im_end|>\n<|im_start|>user"
# ).format(content=PROMPT_TEXT_SYSTEM_COMMON)
# PROMPT_TEXT_POSTFIX = "<|im_end|>\n<|im_start|>assistant"

# Re: TinyLlama
# PROMPT_TEXT_PREFIX = (
#     "[INST] <<SYS>>{content}<</SYS>>\n"
# ).format(content=PROMPT_TEXT_SYSTEM_COMMON)
# PROMPT_TEXT_POSTFIX = " [/INST]\n"
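# For illustration: with the Phi-3.5 template active, a user prompt "Hi" is
# assembled as PROMPT_TEXT_PREFIX + "Hi" + PROMPT_TEXT_POSTFIX, i.e.:
#   <|system|> ...system text... <|end|>
#   <|user|>Hi<|end|>
#   <|assistant|>
# The prefix/postfix pair should match the chat template the .rkllm model was
# converted with, or generation quality will degrade.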
# Set environment variables
os.environ["GRADIO_SERVER_NAME"] = "0.0.0.0"
os.environ["GRADIO_SERVER_PORT"] = "8080"

# Set the dynamic library path
rkllm_lib = ctypes.CDLL('./lib/librkllmrt.so')

# Define the structures from the library
RKLLM_Handle_t = ctypes.c_void_p
userdata = ctypes.c_void_p(None)

LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4

RKLLMInputMode = ctypes.c_int
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3

RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1

class RKLLMExtendParam(ctypes.Structure):
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("reserved", ctypes.c_uint8 * 112)
    ]

class RKLLMParam(ctypes.Structure):
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModelInput(ctypes.Structure):
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t)
    ]

class RKLLMInputUnion(ctypes.Union):
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]

class RKLLMInput(ctypes.Structure):
    _fields_ = [
        ("input_mode", ctypes.c_int),
        ("input_data", RKLLMInputUnion)
    ]

class RKLLMLoraParam(ctypes.Structure):
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam))
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResult(ctypes.Structure):
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("size", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer)
    ]

# Define global variables to store the callback function output for displaying in the Gradio interface
global_text = []
global_state = -1
split_byte_data = bytes(b"")  # Used to store the segmented byte data

# Define the callback function
def callback_impl(result, userdata, state):
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
        '''
        If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer,
        the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
        With these three parameters, you can retrieve the data from last_hidden_layer.
        Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
        '''
        # `result` is a POINTER(RKLLMResult), so fields must be read via .contents
        # (a plain `result.last_hidden_layer` would raise AttributeError).
        last_hidden_layer = result.contents.last_hidden_layer
        if last_hidden_layer.embd_size != 0 and last_hidden_layer.num_tokens != 0:
            data_size = last_hidden_layer.embd_size * last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
            print(f"data_size: {data_size}")
            global_text.append(f"data_size: {data_size}\n")
            output_path = os.getcwd() + "/last_hidden_layer.bin"
            with open(output_path, "wb") as outFile:
                data = ctypes.cast(last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
                float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
                float_array = float_array_type.from_address(ctypes.addressof(data.contents))
                outFile.write(bytearray(float_array))
            print(f"Data saved to {output_path} successfully!")
            global_text.append(f"Data saved to {output_path} successfully!")
        else:
            print("Invalid hidden layer data.")
            global_text.append("Invalid hidden layer data.")
        global_state = state
        time.sleep(0.05)  # Delay for 0.05 seconds to wait for the output result
        sys.stdout.flush()
    else:
        # Save the output token text and the RKLLM running state
        global_state = state
        # Monitor if the current byte data is complete; if incomplete, record it for later parsing
        try:
            global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
            print((split_byte_data + result.contents.text).decode('utf-8'), end='')
            split_byte_data = bytes(b"")
        except UnicodeDecodeError:
            # A multi-byte UTF-8 character was split across callbacks; buffer it
            split_byte_data += result.contents.text
        sys.stdout.flush()
# Connect the callback function between the Python side and the C++ side
callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(callback_impl)
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
class RKLLM(object):
    def __init__(self, model_path, lora_model_path=None, prompt_cache_path=None):
        rkllm_param = RKLLMParam()
        rkllm_param.model_path = bytes(model_path, 'utf-8')

        rkllm_param.max_context_len = 4096
        rkllm_param.max_new_tokens = 256
        rkllm_param.skip_special_token = True

        rkllm_param.top_k = 40
        rkllm_param.top_p = 0.9
        rkllm_param.temperature = 0.5
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0

        rkllm_param.mirostat = 0
        rkllm_param.mirostat_tau = 5.0
        rkllm_param.mirostat_eta = 0.1

        rkllm_param.is_async = True

        rkllm_param.img_start = "".encode('utf-8')
        rkllm_param.img_end = "".encode('utf-8')
        rkllm_param.img_content = "".encode('utf-8')

        rkllm_param.extend_param.base_domain_id = 0

        self.handle = RKLLM_Handle_t()

        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int
        self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)

        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        self.lora_adapter_path = None
        self.lora_model_name = None
        if lora_model_path:
            self.lora_adapter_path = lora_model_path
            self.lora_adapter_name = "test"
            # run() checks self.lora_model_name, so keep it in sync with the adapter name
            self.lora_model_name = self.lora_adapter_name
            lora_adapter = RKLLMLoraAdapter()
            ctypes.memset(ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter))
            lora_adapter.lora_adapter_path = ctypes.c_char_p((self.lora_adapter_path).encode('utf-8'))
            lora_adapter.lora_adapter_name = ctypes.c_char_p((self.lora_adapter_name).encode('utf-8'))
            lora_adapter.scale = 1.0

            rkllm_load_lora = rkllm_lib.rkllm_load_lora
            rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)]
            rkllm_load_lora.restype = ctypes.c_int
            rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))

        self.prompt_cache_path = None
        if prompt_cache_path:
            self.prompt_cache_path = prompt_cache_path

            rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
            rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
            rkllm_load_prompt_cache.restype = ctypes.c_int
            rkllm_load_prompt_cache(self.handle, ctypes.c_char_p((prompt_cache_path).encode('utf-8')))

    def run(self, prompt):
        rkllm_lora_params = None
        if self.lora_model_name:
            rkllm_lora_params = RKLLMLoraParam()
            rkllm_lora_params.lora_adapter_name = ctypes.c_char_p((self.lora_model_name).encode('utf-8'))

        rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        # ctypes.byref() objects can only be passed as call arguments; assigning
        # to a POINTER structure field requires ctypes.pointer() instead
        rkllm_infer_params.lora_params = ctypes.pointer(rkllm_lora_params) if rkllm_lora_params else None
        rkllm_input = RKLLMInput()
        rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
        rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8'))
        self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None)
        return

    def release(self):
        self.rkllm_destroy(self.handle)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    args = parser.parse_args()

    if not os.path.exists(args.rkllm_model_path):
        print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
        sys.stdout.flush()
        exit()
    if args.target_platform not in ["rk3588", "rk3576"]:
        print("Error: Please specify the correct target platform: rk3588/rk3576.")
        sys.stdout.flush()
        exit()

    if args.lora_model_path:
        if not os.path.exists(args.lora_model_path):
            print("Error: Please provide the correct lora_model path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()

    if args.prompt_cache_path:
        if not os.path.exists(args.prompt_cache_path):
            print("Error: Please provide the correct prompt_cache_file path, and ensure it is the absolute path on the board.")
            sys.stdout.flush()
            exit()
    # Fix frequency
    # command = "sudo bash fix_freq_{}.sh".format(args.target_platform)
    # subprocess.run(command, shell=True)

    # Set resource limit
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize RKLLM model
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path)
    print("RKLLM Model has been initialized successfully!")
    print("==============================")
    sys.stdout.flush()

    # Record the user's input prompt
    def get_user_input(user_message, history):
        history = history + [[user_message, None]]
        return "", history

    # Retrieve the output from the RKLLM model and print it in a streaming manner
    def get_RKLLM_output(history):
        # Link global variables to retrieve the output information from the callback function
        global global_text, global_state
        global_text = []
        global_state = -1

        # Create a thread for model inference
        model_thread = threading.Thread(target=rkllm_model.run, args=(history[-1][0],))
        model_thread.start()

        # history[-1][1] represents the current dialogue
        history[-1][1] = ""

        # Wait for the model to finish running and periodically check the inference thread of the model
        model_thread_finished = False
        while not model_thread_finished:
            while len(global_text) > 0:
                history[-1][1] += global_text.pop(0)
                time.sleep(0.005)
                # Gradio automatically pushes the result returned by the yield statement when calling the then method
                yield history

            model_thread.join(timeout=0.005)
            model_thread_finished = not model_thread.is_alive()

    # Create a Gradio interface
    with gr.Blocks(title="Chat with RKLLM") as chatRKLLM:
        gr.Markdown("<div align='center'><font size='70'> Chat with RKLLM </font></div>")
        gr.Markdown("### Enter your question in the inputTextBox and press the Enter key to chat with the RKLLM model.")
        # Create a Chatbot component to display conversation history
        rkllmServer = gr.Chatbot(height=600)
        # Create a Textbox component for user message input
        msg = gr.Textbox(placeholder="Please input your question here...", label="inputTextBox")
        # Create a Button component to clear the chat history.
        clear = gr.Button("Clear")

        # Submit the user's input message to the get_user_input function and immediately update the chat history.
        # Then call the get_RKLLM_output function to further update the chat history.
        # The queue=False parameter ensures that these updates are not queued, but executed immediately.
        msg.submit(get_user_input, [msg, rkllmServer], [msg, rkllmServer], queue=False).then(get_RKLLM_output, rkllmServer, rkllmServer)
        # When the clear button is clicked, perform a no-operation (lambda: None) and immediately clear the chat history.
        clear.click(lambda: None, None, rkllmServer, queue=False)

    # Enable the event queue system.
    chatRKLLM.queue()
    # Start the Gradio application.
    chatRKLLM.launch()

    print("====================")
    print("RKLLM model inference completed, releasing RKLLM model resources...")
    rkllm_model.release()
    print("====================")
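Note that, unlike flask_server.py (which binds 127.0.0.1 in app.run), gradio_server.py sets GRADIO_SERVER_NAME=0.0.0.0 and GRADIO_SERVER_PORT=8080 via environment variables, so the compose port mapping "8080:8080" reaches it directly.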
0 rkllm_server/lib/.gitkeep Normal file
BIN rkllm_server/lib/librkllmrt.so Normal file
Binary file not shown.