compose-rkllm_chat/rkllm_server/flask_server.py

454 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response
app = Flask(__name__)
# Chat template that RKLLM.run() wraps around every prompt:
#   PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX
# NOTE(review): there is no newline/space between "<|im_start|>user" and the prompt
# (nor after the role tags); confirm this matches the template the converted model expects.
PROMPT_TEXT_PREFIX = "<|im_start|>system You are a helpful assistant. <|im_end|> <|im_start|>user"
PROMPT_TEXT_POSTFIX = "<|im_end|><|im_start|>assistant"
# Load the RKLLM runtime shared library at import time; ctypes raises OSError if it
# cannot be loaded, so importing this module fails without it.
rkllm_lib = ctypes.CDLL('/lib/librkllmrt.so')
# Type aliases and enum constants mirroring the runtime's C API.
# Opaque handle: filled in by rkllm_init and passed to every other runtime call.
RKLLM_Handle_t = ctypes.c_void_p
# NOTE(review): not referenced elsewhere in this file; the callback receives its own userdata argument.
userdata = ctypes.c_void_p(None)
# NOTE: LLMCallState, RKLLMInputMode and RKLLMInferMode are all aliases of the SAME
# class object, ctypes.c_int. The "enum" constants below are attached to that class as
# attributes, so every constant is reachable through every alias (and through
# ctypes.c_int itself). The constant names are distinct, so nothing collides.
# RKLLMInferMode must stay a ctypes type: RKLLMInferParam uses it as a field type.
# Callback states that callback_impl dispatches on.
LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3
LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER = 4
# Input kinds stored in RKLLMInput.input_mode.
RKLLMInputMode = ctypes.c_int
RKLLMInputMode.RKLLM_INPUT_PROMPT = 0
RKLLMInputMode.RKLLM_INPUT_TOKEN = 1
RKLLMInputMode.RKLLM_INPUT_EMBED = 2
RKLLMInputMode.RKLLM_INPUT_MULTIMODAL = 3
# Inference modes stored in RKLLMInferParam.mode.
RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
class RKLLMExtendParam(ctypes.Structure):
    """ctypes mirror of the runtime's extended-parameter struct (RKLLMParam.extend_param).

    The order and types in ``_fields_`` define the memory layout and must match
    the C header of the linked librkllmrt.so exactly.
    """
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),  # RKLLM.__init__ sets 0; meaning defined by the runtime (TODO confirm)
        ("reserved", ctypes.c_uint8 * 112)  # presumably padding reserved for future fields; verify against rkllm.h
    ]
class RKLLMParam(ctypes.Structure):
    """Model configuration passed by pointer to ``rkllm_init``.

    RKLLM.__init__ fills every field before calling ``rkllm_init``. Field order
    and types define the C memory layout and must match the runtime header
    exactly; do not reorder.
    """
    _fields_ = [
        ("model_path", ctypes.c_char_p),  # UTF-8 encoded path of the .rkllm model
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),  # RKLLM.__init__ passes -1 (meaning defined by the runtime; TODO confirm)
        ("top_k", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),  # RKLLM.__init__ sets False
        ("img_start", ctypes.c_char_p),  # presumably multimodal image markers; set to b"" here
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        ("extend_param", RKLLMExtendParam),
    ]
class RKLLMLoraAdapter(ctypes.Structure):
    """Describes a LoRA adapter passed by pointer to ``rkllm_load_lora``.

    RKLLM.__init__ fills the path, a registration name and ``scale`` (1.0).
    Layout must match the runtime header.
    """
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),  # name later used to select the adapter (RKLLMLoraParam)
        ("scale", ctypes.c_float)
    ]
class RKLLMEmbedInput(ctypes.Structure):
    """Embedding-input member of RKLLMInputUnion (not used elsewhere in this file).

    NOTE(review): ``embed`` presumably points to n_tokens * embedding_dim floats; confirm
    against the runtime header.
    """
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]
class RKLLMTokenInput(ctypes.Structure):
    """Token-id member of RKLLMInputUnion: ``n_tokens`` int32 ids at ``input_ids``.

    Not used elsewhere in this file; layout must match the runtime header.
    """
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]
class RKLLMMultiModelInput(ctypes.Structure):
    """Multimodal member of RKLLMInputUnion: a text prompt plus an image embedding.

    Not used elsewhere in this file. (The class keeps its original "MultiModel"
    spelling because other code may reference it.)
    NOTE(review): ``image_embed`` presumably holds n_image_tokens embedding rows; confirm.
    """
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t)
    ]
class RKLLMInputUnion(ctypes.Union):
    """Payload of RKLLMInput; all members share the same storage.

    Which member is meaningful is chosen by RKLLMInput.input_mode; RKLLM.run
    selects RKLLM_INPUT_PROMPT and fills ``prompt_input``.
    """
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModelInput)
    ]
class RKLLMInput(ctypes.Structure):
    """Input passed by pointer to ``rkllm_run``: a mode tag plus the union payload."""
    _fields_ = [
        ("input_mode", ctypes.c_int),  # one of the RKLLMInputMode constants
        ("input_data", RKLLMInputUnion)
    ]
class RKLLMLoraParam(ctypes.Structure):
    """Selects a loaded LoRA adapter by name for one ``rkllm_run`` call.

    Referenced through RKLLMInferParam.lora_params.
    """
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]
class RKLLMPromptCacheParam(ctypes.Structure):
    """Prompt-cache options referenced through RKLLMInferParam.prompt_cache_params.

    RKLLM.run leaves that pointer NULL, so this struct is not populated in this file.
    NOTE(review): ``save_prompt_cache`` presumably acts as a boolean flag; confirm.
    """
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p)
    ]
class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed by pointer to ``rkllm_run``.

    The pointer fields are optional and are NULL when unused (RKLLM.run zeroes
    the struct before filling it).
    """
    _fields_ = [
        ("mode", RKLLMInferMode),  # one of the RKLLMInferMode constants
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam))
    ]
class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """Hidden-state output carried in RKLLMResult for the GET_LAST_HIDDEN_LAYER state.

    callback_impl treats ``hidden_states`` as ``embd_size * num_tokens`` floats;
    the buffer is only valid during that callback invocation.
    """
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]
class RKLLMResult(ctypes.Structure):
    """Result struct the runtime passes, by pointer, to the registered callback.

    callback_impl decodes ``text`` as UTF-8 bytes for text states and reads
    ``last_hidden_layer`` for the hidden-layer state.
    NOTE(review): ``size`` is not read in this file.
    """
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("size", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer)
    ]
# Serialises access to the single RKLLM model across concurrent HTTP requests.
lock = threading.Lock()
# True while a request is being served; request handlers answer "busy" while it is set.
is_blocking = False
# State shared with callback_impl, which the runtime invokes while RKLLM.run executes:
#   global_text     -- decoded text chunks waiting to be forwarded to the HTTP client
#   global_state    -- last LLMCallState reported by the runtime (-1 = nothing reported yet)
#   split_byte_data -- trailing bytes of a UTF-8 character cut off at the end of a chunk
global_text, global_state = [], -1
split_byte_data = b""
# Define the callback function
def callback_impl(result, userdata, state):
global global_text, global_state, split_byte_data
if state == LLMCallState.RKLLM_RUN_FINISH:
global_state = state
print("\n")
sys.stdout.flush()
elif state == LLMCallState.RKLLM_RUN_ERROR:
global_state = state
print("run error")
sys.stdout.flush()
elif state == LLMCallState.RKLLM_RUN_GET_LAST_HIDDEN_LAYER:
'''
If using the GET_LAST_HIDDEN_LAYER function, the callback interface will return the memory pointer: last_hidden_layer, the number of tokens: num_tokens, and the size of the hidden layer: embd_size.
With these three parameters, you can retrieve the data from last_hidden_layer.
Note: The data needs to be retrieved during the current callback; if not obtained in time, the pointer will be released by the next callback.
'''
if result.last_hidden_layer.embd_size != 0 and result.last_hidden_layer.num_tokens != 0:
data_size = result.last_hidden_layer.embd_size * result.last_hidden_layer.num_tokens * ctypes.sizeof(ctypes.c_float)
print(f"data_size: {data_size}")
global_text.append(f"data_size: {data_size}\n")
output_path = os.getcwd() + "/last_hidden_layer.bin"
with open(output_path, "wb") as outFile:
data = ctypes.cast(result.last_hidden_layer.hidden_states, ctypes.POINTER(ctypes.c_float))
float_array_type = ctypes.c_float * (data_size // ctypes.sizeof(ctypes.c_float))
float_array = float_array_type.from_address(ctypes.addressof(data.contents))
outFile.write(bytearray(float_array))
print(f"Data saved to {output_path} successfully!")
global_text.append(f"Data saved to {output_path} successfully!")
else:
print("Invalid hidden layer data.")
global_text.append("Invalid hidden layer data.")
global_state = state
time.sleep(0.05) # Delay for 0.05 seconds to wait for the output result
sys.stdout.flush()
else:
# Save the output token text and the RKLLM running state
global_state = state
# Monitor if the current byte data is complete; if incomplete, record it for later parsing
try:
global_text.append((split_byte_data + result.contents.text).decode('utf-8'))
print((split_byte_data + result.contents.text).decode('utf-8'), end='')
split_byte_data = bytes(b"")
except:
split_byte_data += result.contents.text
sys.stdout.flush()
# Connect the callback function between the Python side and the C++ side.
# C signature: void (*)(RKLLMResult* result, void* userdata, int state).
callback_type = ctypes.CFUNCTYPE(None, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
# Keep this module-level reference for as long as the runtime may call it: ctypes does
# not keep callback objects alive on the C side, and calling a garbage-collected
# callback crashes the process.
callback = callback_type(callback_impl)
# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
class RKLLM(object):
    """ctypes wrapper around a single RKLLM runtime handle.

    ``__init__`` creates the handle (optionally loading a LoRA adapter and a
    prompt cache). ``run`` performs one generation whose output is delivered
    through the module-level ``callback``. ``release`` destroys the handle.
    """

    def __init__(self, model_path, lora_model_path=None, prompt_cache_path=None):
        """Initialise the model and the optional extras.

        Args:
            model_path: path of the converted RKLLM model.
            lora_model_path: optional LoRA adapter file. When given, it is loaded
                under the name ``"test"`` and applied on every ``run``.
            prompt_cache_path: optional prompt-cache file passed to
                ``rkllm_load_prompt_cache``.

        Raises:
            RuntimeError: if ``rkllm_init``, ``rkllm_load_lora`` or
                ``rkllm_load_prompt_cache`` returns a non-zero status.
        """
        rkllm_param = RKLLMParam()
        rkllm_param.model_path = bytes(model_path, 'utf-8')
        rkllm_param.max_context_len = 512
        rkllm_param.max_new_tokens = -1
        rkllm_param.skip_special_token = True
        rkllm_param.top_k = 1
        rkllm_param.top_p = 0.9
        rkllm_param.temperature = 0.8
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0
        rkllm_param.mirostat = 0
        rkllm_param.mirostat_tau = 5.0
        rkllm_param.mirostat_eta = 0.1
        rkllm_param.is_async = False
        rkllm_param.img_start = "".encode('utf-8')
        rkllm_param.img_end = "".encode('utf-8')
        rkllm_param.img_content = "".encode('utf-8')
        rkllm_param.extend_param.base_domain_id = 0

        self.handle = RKLLM_Handle_t()
        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int
        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int
        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        ret = self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
        if ret != 0:
            raise RuntimeError(f"rkllm_init failed with status {ret}")

        self.lora_adapter_path = None
        self.lora_adapter_name = None
        # run() selects the adapter by this name; None means "no LoRA adapter".
        self.lora_model_name = None
        if lora_model_path:
            self.lora_adapter_path = lora_model_path
            self.lora_adapter_name = "test"
            lora_adapter = RKLLMLoraAdapter()
            ctypes.memset(ctypes.byref(lora_adapter), 0, ctypes.sizeof(RKLLMLoraAdapter))
            lora_adapter.lora_adapter_path = ctypes.c_char_p(self.lora_adapter_path.encode('utf-8'))
            lora_adapter.lora_adapter_name = ctypes.c_char_p(self.lora_adapter_name.encode('utf-8'))
            lora_adapter.scale = 1.0
            rkllm_load_lora = rkllm_lib.rkllm_load_lora
            rkllm_load_lora.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMLoraAdapter)]
            rkllm_load_lora.restype = ctypes.c_int
            ret = rkllm_load_lora(self.handle, ctypes.byref(lora_adapter))
            if ret != 0:
                self._abort_init("rkllm_load_lora", ret)
            # run() checks lora_model_name, so it must be set for the adapter to be applied.
            self.lora_model_name = self.lora_adapter_name

        self.prompt_cache_path = None
        if prompt_cache_path:
            self.prompt_cache_path = prompt_cache_path
            rkllm_load_prompt_cache = rkllm_lib.rkllm_load_prompt_cache
            rkllm_load_prompt_cache.argtypes = [RKLLM_Handle_t, ctypes.c_char_p]
            rkllm_load_prompt_cache.restype = ctypes.c_int
            ret = rkllm_load_prompt_cache(self.handle, ctypes.c_char_p(prompt_cache_path.encode('utf-8')))
            if ret != 0:
                self._abort_init("rkllm_load_prompt_cache", ret)

    def _abort_init(self, what, status):
        """Destroy the partially initialised handle, then raise RuntimeError for *what*."""
        self.rkllm_destroy(self.handle)
        raise RuntimeError(f"{what} failed with status {status}")

    def run(self, prompt):
        """Generate a reply to *prompt*, wrapped in PROMPT_TEXT_PREFIX/POSTFIX.

        Nothing is returned: generated text is delivered through the module-level
        callback. A non-zero status from ``rkllm_run`` is logged to stdout.
        """
        rkllm_lora_params = None
        if self.lora_model_name:
            rkllm_lora_params = RKLLMLoraParam()
            rkllm_lora_params.lora_adapter_name = ctypes.c_char_p(self.lora_model_name.encode('utf-8'))
        rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        # A POINTER-typed struct field only accepts a real pointer object (or None for
        # NULL); ctypes.byref() returns a CArgObject that ctypes rejects here.
        rkllm_infer_params.lora_params = ctypes.pointer(rkllm_lora_params) if rkllm_lora_params is not None else None
        rkllm_input = RKLLMInput()
        rkllm_input.input_mode = RKLLMInputMode.RKLLM_INPUT_PROMPT
        rkllm_input.input_data.prompt_input = ctypes.c_char_p((PROMPT_TEXT_PREFIX + prompt + PROMPT_TEXT_POSTFIX).encode('utf-8'))
        ret = self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), None)
        if ret != 0:
            print(f"rkllm_run failed with status {ret}")
            sys.stdout.flush()
        return

    def release(self):
        """Destroy the runtime handle; the instance must not be used afterwards."""
        self.rkllm_destroy(self.handle)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, required=True, help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, required=True, help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    args = parser.parse_args()

    def _fail(message):
        """Print *message* and exit with a non-zero status."""
        print(message)
        sys.stdout.flush()
        sys.exit(1)

    if not os.path.exists(args.rkllm_model_path):
        _fail("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
    if args.target_platform not in ["rk3588", "rk3576"]:
        _fail("Error: Please specify the correct target platform: rk3588/rk3576.")
    if args.lora_model_path and not os.path.exists(args.lora_model_path):
        _fail("Error: Please provide the correct lora_model path, and ensure it is the absolute path on the board.")
    if args.prompt_cache_path and not os.path.exists(args.prompt_cache_path):
        _fail("Error: Please provide the correct prompt_cache_file path, and ensure it is the absolute path on the board.")

    # Fix frequency: run fix_freq_<platform>.sh from the current working directory.
    # target_platform is whitelisted above and an argument list is used, so no shell is involved.
    subprocess.run(["sudo", "bash", "fix_freq_{}.sh".format(args.target_platform)])
    # Raise the open-file-descriptor limit (soft and hard).
    resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))

    # Initialize RKLLM model
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    try:
        rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path)
    except RuntimeError as err:
        _fail(f"Error: RKLLM model initialization failed: {err}")
    print("RKLLM Model has been initialized successfully")
    print("==============================")
    sys.stdout.flush()

    def _release_server():
        """Mark the server idle, then release the request lock (in that order)."""
        global is_blocking
        is_blocking = False
        lock.release()

    def _pop_text_chunks():
        """Remove and return, in order, every text chunk the callback has queued."""
        count = len(global_text)
        chunks = global_text[:count]
        # The callback only appends, so deleting the first `count` items is safe.
        del global_text[:count]
        return chunks

    def _response_template():
        """Return a fresh top-level response dict with an empty ``choices`` list."""
        return {
            "id": "rkllm_chat",
            "object": "rkllm_chat",
            "created": None,
            "choices": [],
            "usage": {
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None
            }
        }

    def _valid_messages(messages):
        """Return True if *messages* is a list of dicts that each carry a string ``content``."""
        return isinstance(messages, list) and all(
            isinstance(message, dict) and isinstance(message.get('content'), str)
            for message in messages
        )

    def _stream_chunks(model_thread, index):
        """Yield one JSON line per generated text chunk, then a final empty chunk with finish_reason."""
        def chunk_line(content, finish_reason):
            # Each line carries only its own delta (not the whole history of chunks).
            response = _response_template()
            response["choices"].append({
                "index": index,
                "delta": {
                    "role": "assistant",
                    "content": content,
                },
                "logprobs": None,
                "finish_reason": finish_reason,
            })
            return f"{json.dumps(response)}\n\n"

        while model_thread.is_alive():
            for piece in _pop_text_chunks():
                yield chunk_line(piece, None)
            model_thread.join(timeout=0.005)
        # Forward chunks queued between the last poll and the thread exiting.
        for piece in _pop_text_chunks():
            yield chunk_line(piece, None)
        finish_reason = "stop" if global_state == LLMCallState.RKLLM_RUN_FINISH else "error"
        yield chunk_line("", finish_reason)

    # Create a function to receive data sent by the user using a request
    @app.route('/rkllm_chat', methods=['POST'])
    def receive_message():
        """Serve one chat request against the shared RKLLM model.

        Expects a JSON object with a ``messages`` list of ``{"content": str}``
        objects and an optional truthy ``stream`` flag.

        * Non-streaming: each message is run as an independent prompt; the reply
          holds one ``choices`` entry per message.
        * Streaming: only the first message is run. The body is a sequence of
          JSON objects separated by blank lines, each with a single ``delta``
          choice, ending with an empty-content chunk whose ``finish_reason`` is
          ``"stop"`` (or ``"error"`` if the runtime did not report completion).

        Returns 503 while another request holds the model, 400 for malformed input.
        """
        global global_text, global_state, split_byte_data, is_blocking

        # Non-blocking acquire: checking a flag first and acquiring afterwards would
        # let two concurrent requests both pass the check.
        if not lock.acquire(blocking=False):
            return jsonify({'status': 'error', 'message': 'RKLLM_Server is busy! Maybe you can try again later.'}), 503
        is_blocking = True
        lock_handed_off = False  # True once a streaming response owns releasing the lock
        try:
            # silent=True: a missing/invalid JSON body yields None instead of raising.
            data = request.get_json(silent=True)
            messages = data.get('messages') if isinstance(data, dict) else None
            if not _valid_messages(messages):
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400
            stream = bool(data.get("stream"))
            if stream and not messages:
                return jsonify({'status': 'error', 'message': 'Invalid JSON data!'}), 400

            # Reset the state shared with the callback before starting a new run.
            global_text = []
            global_state = -1
            split_byte_data = b""
            print("Received messages:", messages)

            if not stream:
                rkllm_responses = _response_template()
                for index, message in enumerate(messages):
                    # The runtime is configured with is_async=False, so run() is expected
                    # to return only after the final callback; drain everything afterwards.
                    rkllm_model.run(message['content'])
                    rkllm_responses["choices"].append({
                        "index": index,
                        "message": {
                            "role": "assistant",
                            "content": "".join(_pop_text_chunks()),
                        },
                        "logprobs": None,
                        "finish_reason": "stop"
                    })
                return jsonify(rkllm_responses), 200

            # NOTE(review): streaming answers only the first message, as before.
            model_thread = threading.Thread(target=rkllm_model.run, args=(messages[0]['content'],))
            model_thread.start()
            response = Response(_stream_chunks(model_thread, 0), content_type='text/plain')

            def _finish_stream():
                # Called when the server closes the response, even if the client
                # disconnected early: wait for the generation, then free the model.
                model_thread.join()
                _release_server()

            response.call_on_close(_finish_stream)
            lock_handed_off = True
            return response
        finally:
            if not lock_handed_off:
                _release_server()

    # Start the Flask application.
    # NOTE(review): bound to loopback only; use host='0.0.0.0' if clients connect from
    # outside this host/container network namespace.
    try:
        app.run(host='127.0.0.1', port=8080, threaded=True, debug=False)
    finally:
        print("====================")
        print("RKLLM model inference completed, releasing RKLLM model resources...")
        rkllm_model.release()
        print("====================")