TIS教程04-客户端
簡介
在之前的文章中,我們主要關注服務端的配置和部署,這無可厚非,因為Triton Inference Server本就是服務端框架。但是,作為一個完善的生態,Triton也對客戶端請求做了諸多封裝以方便開發者的使用,這樣我們就不需要過分關注協議通訊等諸多細節。為了提供這種封裝,簡化與Triton的通信,Triton團隊提供了幾個客戶端庫并給出了使用示例,本文將圍繞Python SDK進行講解,感興趣的也可以查看C++庫和其他語言的支持(grpc_generated)。
安裝
Python客戶端庫的最簡單方式就是使用pip進行安裝,當然也支持源碼安裝和docker鏡像的使用,但是這里為了開發方便我個人比較建議客戶端sdk就在虛擬環境中安裝而不是使用docker。
在前面的文章中,我們已經介紹了Triton Inference Server主要支持兩種協議,即HTTP和GRPC,因此他提供單獨某種協議的Python包安裝或者兩種協議均支持的Python包安裝,命令如下,需要支持指定協議只需要將下面的all更改為http或者grpc即可。使用all表示同時安裝HTTP/REST和GRPC客戶端庫。
pip install nvidia-pyindex pip install tritonclient[all]需要注意的是,pip安裝目前僅支持Linux,且系統必須包含perf_analyzer,這個包在Ubuntu 20.04上默認是有的,之前的版本可以通過下面的命令補上。
sudo apt update sudo apt install libb64-dev使用
使用該庫必須要有GRPC和HTTP的基礎知識并且去閱讀官方文檔和源碼,我們下面看幾個典型的應用。
Bytes/String 數據類型
一些框架支持張量,其中張量中的每個元素都是可變長度的二進制數據。 每個元素可以保存一個字符串或任意字節序列。 在客戶端,此數據類型為 BYTES(具體參考)。
Python客戶端使用numpy來表示輸入輸出張量。對于BYTES張量,numpy中對應的數據類型應該為np.object_, 為了與以前版本的客戶端庫向后兼容,np.bytes_ 也可以用于 BYTES 張量。但是,不建議使用np.bytes_,因為使用此 dtype 會導致 numpy 從每個數組元素中刪除所有尾隨零。 因此,以零結尾的二進制序列將無法正確表示。
關于BYTES/STRING張量的Python示例代碼可以訪問官方示例,其中構建輸入并發起請求的核心代碼如下,我已經進行了詳細注釋。
# 構建輸入輸出 inputs = [] outputs = [] # 設定發起請求的請求體數據格式 inputs.append(grpcclient.InferInput('INPUT0', [1, 16], "BYTES")) inputs.append(grpcclient.InferInput('INPUT1', [1, 16], "BYTES"))# 模擬輸入數據 in0 = np.arange(start=0, stop=16, dtype=np.int32) in0 = np.expand_dims(in0, axis=0) in1 = np.ones(shape=(1, 16), dtype=np.int32) expected_sum = np.add(in0, in1) expected_diff = np.subtract(in0, in1)# 這里的演示模型期待兩個BYTES張量,每個張量內部元素是UTF8的字符串表示的整數 in0n = np.array([str(x).encode('utf-8') for x in in0.reshape(in0.size)], dtype=np.object_) input0_data = in0n.reshape(in0.shape) in1n = np.array([str(x).encode('utf-8') for x in in1.reshape(in1.size)], dtype=np.object_) input1_data = in1n.reshape(in1.shape)# 初始化數據 inputs[0].set_data_from_numpy(input0_data) inputs[1].set_data_from_numpy(input1_data)outputs.append(grpcclient.InferRequestedOutput('OUTPUT0')) outputs.append(grpcclient.InferRequestedOutput('OUTPUT1'))# 請求服務端進行推理 results = triton_client.infer(model_name=model_name,inputs=inputs,outputs=outputs)# 獲得推理respose中對應鍵的數據 output0_data = results.as_numpy('OUTPUT0') output1_data = results.as_numpy('OUTPUT1')系統共享內存
在某些情況下,使用系統共享內存在客戶端庫和 Triton服務端 之間通信張量可以顯著提高性能。
關于使用系統共享內存的示例代碼可以訪問官方示例。
不過,由于Python 沒有分配和訪問共享內存的標準方法,因此作為示例,官方提供了一個簡單的系統共享內存模塊,可與 Python 客戶端庫一起使用以創建、設置和銷毀系統共享內存。
這部分的代碼可以自己去閱讀,理解起來并不難。
CUDA共享內存
在某些情況下,使用 CUDA 共享內存在客戶端庫和 Triton服務端 之間進行張量通信可以顯著提高性能。
關于使用CUDA共享內存的示例可以訪問官方示例,它的整體流程就如代碼中所示。
同樣,由于Python 沒有分配和訪問共享內存的標準方法,因此作為示例,提供了一個簡單的 CUDA 共享內存模塊,可與 Python 客戶端庫一起使用以創建、設置和銷毀 CUDA 共享內存。
補充
此外還有很多操作,比如有狀態模型的推理序列的控制等等,這些后續用到了我會補充。
圖像分類示例
下面我們以官方的圖像分類源碼為示例來理解客戶端工作的整個流程。
服務端配置
首先,這個客戶端要想成功運行首先需要部署一個分類模型并開啟服務端訪問,具體可以參考官方的quickstart。我們按照這個quickstart教程和官方的模型倉庫示例文件組織模型倉庫的文件如下。
model_repository/ ├── densenet_onnx├── 1│ └── model.onnx├── config.pbtxt└── densenet_labels.txt接著,我們使用之前文章下載的docker鏡像運行Triton服務器,命令如下(這里選項里面設置了使用1號GPU卡并顯式服務器日志)。
docker run --gpus '"device=1"' --rm -p8000:8000 -p8001:8001 -p8002:8002 -v/home/zhouchen/model_repository:/models nvcr.io/nvidia/tritonserver:20.10-py3 tritonserver --model-repository=/models --log-verbose 1此時我們應該可以看到如下的日志界面,這代表服務器成功部署并運行著,然后我們就可以著手客戶端的配置了。
客戶端配置
虛擬環境的創建和相關庫的安裝本文第二節已經提到了,這里不多贅述,我們直接看官方的源碼,這個代碼比較長長,我這里刪去了模型解析函數、預處理和后處理函數來進行解析,這樣會清晰直觀一點,因此這個代碼是無法跑通的,可以跑通的源碼去這個鏈接查看。
處理后的帶有我詳細注釋的代碼如下所示(從main開始看),整體流程為“構建想要使用的協議的client對象”—>“預處理數據”—>“然后創建請求體”—>“向服務端發起模型推理請求”—>“得到服務端反饋結果”—>“后處理結果數據”—>“顯示渲染后的結果”。
import argparse from functools import partial import os import sysfrom PIL import Image import numpy as np from attrdict import AttrDictimport tritonclient.grpc as grpcclient import tritonclient.grpc.model_config_pb2 as mc import tritonclient.http as httpclient from tritonclient.utils import InferenceServerException from tritonclient.utils import triton_to_np_dtypeif sys.version_info >= (3, 0):import queue else:import Queue as queueclass UserData:def __init__(self):self._completed_requests = queue.Queue()# Callback function used for async_stream_infer() def completion_callback(user_data, result, error):# passing error raise and handling outuser_data._completed_requests.put((result, error))FLAGS = Nonedef requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):protocol = FLAGS.protocol.lower()if protocol == "grpc":client = grpcclientelse:client = httpclient# Set the input datainputs = [client.InferInput(input_name, batched_image_data.shape, dtype)]inputs[0].set_data_from_numpy(batched_image_data)outputs = [client.InferRequestedOutput(output_name, class_count=FLAGS.classes)]yield inputs, outputs, FLAGS.model_name, FLAGS.model_versiondef convert_http_metadata_config(_metadata, _config):_model_metadata = AttrDict(_metadata)_model_config = AttrDict(_config)return _model_metadata, _model_configif __name__ == '__main__':‘# 命令行參數parser = argparse.ArgumentParser()parser.add_argument('-v', '--verbose', action="store_true", required=False, default=False,help='Enable verbose output')parser.add_argument('-a', '--async', dest="async_set", action="store_true", required=False, default=False,help='Use asynchronous inference API')parser.add_argument('--streaming', action="store_true", required=False, default=False,help='Use streaming inference API. ' + 'The flag is only available with gRPC protocol.')parser.add_argument('-m', '--model-name', type=str, required=True, help='Name of model')parser.add_argument('-x', '--model-version', type=str, required=False, default="",help='Version of model. Default is to use latest version.')parser.add_argument('-b', '--batch-size', type=int, required=False, default=1, help='Batch size. Default is 1.')parser.add_argument('-c', '--classes', type=int, required=False, default=1,help='Number of class results to report. Default is 1.')parser.add_argument('-s', '--scaling', type=str, choices=['NONE', 'INCEPTION', 'VGG'], required=False,default='NONE', help='Type of scaling to apply to image pixels. Default is NONE.')parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8000',help='Inference server URL. Default is localhost:8000.')parser.add_argument('-i', '--protocol', type=str, required=False, default='HTTP',help='Protocol (HTTP/gRPC) used to communicate with ' + 'the inference service. Default is HTTP.')parser.add_argument('image_filename', type=str, nargs='?', default=None, help='Input image / Input folder.')# 從命令行解析參數FLAGS = parser.parse_args()if FLAGS.streaming and FLAGS.protocol.lower() != "grpc":raise Exception("Streaming is only allowed with gRPC protocol")try:# 根據協議命令行參數給的協議類型創建client對象if FLAGS.protocol.lower() == "grpc":# Create gRPC client for communicating with the servertriton_client = grpcclient.InferenceServerClient(url=FLAGS.url, verbose=FLAGS.verbose)else:# Specify large enough concurrency to handle the# the number of requests.concurrency = 20 if FLAGS.async_set else 1triton_client = httpclient.InferenceServerClient(url=FLAGS.url, verbose=FLAGS.verbose, concurrency=concurrency)except Exception as e:print("client creation failed: " + str(e))sys.exit(1)# 確認模型和需求相符并且得到預處理需要的模型配置參數try:model_metadata = triton_client.get_model_metadata(model_name=FLAGS.model_name, model_version=FLAGS.model_version)except InferenceServerException as e:print("failed to retrieve the metadata: " + str(e))sys.exit(1)try:model_config = triton_client.get_model_config(model_name=FLAGS.model_name, model_version=FLAGS.model_version)except InferenceServerException as e:print("failed to retrieve the config: " + str(e))sys.exit(1)if FLAGS.protocol.lower() == "grpc":model_config = model_config.configelse:model_metadata, model_config = convert_http_metadata_config(model_metadata, model_config)# 解析模型配置max_batch_size, input_name, output_name, c, h, w, format, dtype = parse_model(model_metadata, model_config)# 得到輸入圖像文件名列表filenames = []if os.path.isdir(FLAGS.image_filename):filenames = [os.path.join(FLAGS.image_filename, f)for f in os.listdir(FLAGS.image_filename)if os.path.isfile(os.path.join(FLAGS.image_filename, f))]else:filenames = [FLAGS.image_filename,]filenames.sort()# 輸入數據預處理為符合模型輸入的格式# requirementsimage_data = []for filename in filenames:img = Image.open(filename)image_data.append(preprocess(img, format, dtype, c, h, w, FLAGS.scaling,FLAGS.protocol.lower()))# 按照batchsize發起請求requests = []responses = []result_filenames = []request_ids = []image_idx = 0last_request = Falseuser_data = UserData()# Holds the handles to the ongoing HTTP async requests.async_requests = []sent_count = 0if FLAGS.streaming:triton_client.start_stream(partial(completion_callback, user_data))while not last_request:input_filenames = []repeated_image_data = []for idx in range(FLAGS.batch_size):input_filenames.append(filenames[image_idx])repeated_image_data.append(image_data[image_idx])image_idx = (image_idx + 1) % len(image_data)if image_idx == 0:last_request = Trueif max_batch_size > 0:batched_image_data = np.stack(repeated_image_data, axis=0)else:batched_image_data = repeated_image_data[0]# 發送請求try:for inputs, outputs, model_name, model_version in requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):sent_count += 1if FLAGS.streaming:triton_client.async_stream_infer(FLAGS.model_name,inputs,request_id=str(sent_count),model_version=FLAGS.model_version,outputs=outputs)elif FLAGS.async_set:if FLAGS.protocol.lower() == "grpc":triton_client.async_infer(FLAGS.model_name,inputs,partial(completion_callback, user_data),request_id=str(sent_count),model_version=FLAGS.model_version,outputs=outputs)else:async_requests.append(triton_client.async_infer(FLAGS.model_name,inputs,request_id=str(sent_count),model_version=FLAGS.model_version,outputs=outputs))else:responses.append(triton_client.infer(FLAGS.model_name,inputs,request_id=str(sent_count),model_version=FLAGS.model_version,outputs=outputs))except InferenceServerException as e:print("inference failed: " + str(e))if FLAGS.streaming:triton_client.stop_stream()sys.exit(1)if FLAGS.streaming:triton_client.stop_stream()if FLAGS.protocol.lower() == "grpc":if FLAGS.streaming or FLAGS.async_set:processed_count = 0while processed_count < sent_count:(results, error) = user_data._completed_requests.get()processed_count += 1if error is not None:print("inference failed: " + str(error))sys.exit(1)responses.append(results)else:if FLAGS.async_set:# Collect results from the ongoing async requests# for HTTP Async requests.for async_request in async_requests:responses.append(async_request.get_result())# 得到響應,并解析響應得到需要的結果,后處理結果for response in responses:if FLAGS.protocol.lower() == "grpc":this_id = response.get_response().idelse:this_id = response.get_response()["id"]print("Request {}, batch size {}".format(this_id, FLAGS.batch_size))postprocess(response, output_name, FLAGS.batch_size, max_batch_size > 0)print("PASS")我們將客戶端代碼文件也放在服務端的機器上進行測試,因此地址填寫localhost即可,使用下面的命令對任意一個圖片進行推理,得到下圖的反饋,這表示請求正常反饋了。
python image_client.py -m densenet_onnx --verbose --protocol grpc -u "localhost:8001" -s NONE ./data/car.jpg同時,可以在服務端日志上看到如下的處理流程日志。
至此,我們就完成了使用客戶端sdk進行圖像分類的演示。
總結
本文主要介紹了TIS客戶端SDK的使用,并使用圖像分類任務進行了演示,具體的可以查看官方文檔。
總結
以上是生活随笔為你收集整理的TIS教程04-客户端的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 0007-Reverse Integer
- 下一篇: PyCharm编写shell脚本无法运行