Java调用triton inference server初体验
背景
前面实现了NVIDIA Triton的demo,以及wetts在NVIDIA Triton中的demo。
在wetts triton demo中,使用了python客户端,来调用GRPCInferenceService。
本文,我们使用Java客户端,来调用GRPCInferenceService。
初体验
生成grpc client stub
$ git clone https://github.com/triton-inference-server/common.git common-repo
$ cd library
$ cp ../common-repo/protobuf/*.proto src/main/proto/
如下:
maven会生成对应的java文件
运行示例
官方示例
mvn exec:java -Dexec.mainClass=clients.SimpleJavaClient -Dexec.args="<host> <port>"
结果会报异常,因为服务端并没有部署名为simple的模型。
Exception in thread "main" io.grpc.StatusRuntimeException: NOT_FOUND: Request for unknown model: 'simple' is not found
at io.grpc.stub.ClientCalls.toStatusRuntimeException(ClientCalls.java:262)
at io.grpc.stub.ClientCalls.getUnchecked(ClientCalls.java:243)
at io.grpc.stub.ClientCalls.blockingUnaryCall(ClientCalls.java:156)
at inference.GRPCInferenceServiceGrpc$GRPCInferenceServiceBlockingStub.modelInfer(GRPCInferenceServiceGrpc.java:1563)
at com.oppo.bot.speech.atom.ttsg.service.test.MainTritonClientTest.main(MainTritonClientTest.java:82)
wetts示例
import com.google.protobuf.ByteString;

import inference.GRPCInferenceServiceGrpc;
import inference.GRPCInferenceServiceGrpc.GRPCInferenceServiceBlockingStub;
import inference.GrpcService.*;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

import org.apache.commons.io.FileUtils;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
public class MainTritonClientTest {

    /**
     * Minimal Triton gRPC client demo: sends one UTF-8 text input to the "tts"
     * model and writes the returned raw little-endian PCM audio to a file
     * under /tmp.
     *
     * @param args optional overrides: args[0] = server host, args[1] = gRPC
     *             port (defaults: 10.5.153.28 and 8001, as in the original demo)
     * @throws IOException if the PCM file cannot be written
     */
    public static void main(String[] args) throws IOException {
        String host = args.length > 0 ? args[0] : "10.5.153.28";
        int port = args.length > 1 ? Integer.parseInt(args[1]) : 8001;
        String modelName = "tts";
        // Empty version string lets the server choose the model version.
        String modelVersion = "";

        // Create gRPC stub for communicating with the server.
        // usePlaintext(): Triton's gRPC endpoint here is served without TLS.
        ManagedChannel channel =
            ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        try {
            GRPCInferenceServiceBlockingStub stub =
                GRPCInferenceServiceGrpc.newBlockingStub(channel);

            // Sanity check: verify the server is live before inferring.
            ServerLiveResponse live = stub.serverLive(ServerLiveRequest.getDefaultInstance());
            System.out.println(live);

            // Build the inference request.
            ModelInferRequest.Builder request = ModelInferRequest.newBuilder();
            request.setModelName(modelName);
            request.setModelVersion(modelVersion);

            // Input tensor "text": one UTF-8 string, datatype BYTES, shape [1, 1].
            InferTensorContents.Builder textContents = InferTensorContents.newBuilder();
            textContents.addBytesContents(
                ByteString.copyFrom("今天天气不错", StandardCharsets.UTF_8));
            ModelInferRequest.InferInputTensor.Builder input0 =
                ModelInferRequest.InferInputTensor.newBuilder();
            input0.setName("text");
            input0.setDatatype("BYTES");
            input0.addShape(1);
            input0.addShape(1);
            input0.setContents(textContents);
            request.addInputs(0, input0);

            // Requested output tensor "wav" (raw audio samples).
            ModelInferRequest.InferRequestedOutputTensor.Builder output0 =
                ModelInferRequest.InferRequestedOutputTensor.newBuilder();
            output0.setName("wav");
            request.addOutputs(0, output0);

            ModelInferResponse response = stub.modelInfer(request.build());
            System.out.println("raw output count: " + response.getRawOutputContentsCount());

            // Copy the raw output bytes out of the (read-only) response buffer.
            ByteBuffer buffer = response.getRawOutputContents(0)
                .asReadOnlyByteBuffer()
                .order(ByteOrder.LITTLE_ENDIAN);
            byte[] pcm = new byte[buffer.remaining()];
            buffer.get(pcm);

            String filename = "/tmp/grpc_" + System.currentTimeMillis() + ".pcm";
            System.out.println("save to file: " + filename);
            // Stdlib replacement for commons-io FileUtils.writeByteArrayToFile.
            Files.write(Paths.get(filename), pcm);
        } finally {
            // Always release the channel, even if the RPC or file write fails
            // (the original leaked the channel on any exception).
            channel.shutdownNow();
        }
    }
}