First Steps: Calling Triton Inference Server from Java


Background

Earlier posts walked through the NVIDIA Triton demo, as well as the wetts demo running on NVIDIA Triton.

The wetts Triton demo used a Python client to call GRPCInferenceService.

In this post, we call GRPCInferenceService from a Java client instead.

First Steps

Generate the gRPC client stubs

$ git clone https://github.com/triton-inference-server/common.git common-repo


$ cd library
$ cp ../common-repo/protobuf/*.proto src/main/proto/

After copying, the .proto files sit under src/main/proto/.

Running the Maven build then generates the corresponding Java classes from these .proto files.
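The code generation is typically wired up with the protobuf-maven-plugin plus the os-maven-plugin extension (which resolves the platform-specific protoc binary). A minimal pom.xml sketch; the artifact versions here are illustrative, not taken from the original project:

```xml
<build>
  <extensions>
    <!-- resolves ${os.detected.classifier} for the protoc binary -->
    <extension>
      <groupId>kr.motd.maven</groupId>
      <artifactId>os-maven-plugin</artifactId>
      <version>1.7.0</version>
    </extension>
  </extensions>
  <plugins>
    <plugin>
      <groupId>org.xolstice.maven.plugins</groupId>
      <artifactId>protobuf-maven-plugin</artifactId>
      <version>0.6.1</version>
      <configuration>
        <protocArtifact>com.google.protobuf:protoc:3.21.7:exe:${os.detected.classifier}</protocArtifact>
        <pluginId>grpc-java</pluginId>
        <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.50.0:exe:${os.detected.classifier}</pluginArtifact>
      </configuration>
      <executions>
        <execution>
          <goals>
            <goal>compile</goal>        <!-- message classes -->
            <goal>compile-custom</goal> <!-- gRPC service stubs -->
          </goals>
        </execution>
      </executions>
    </plugin>
  </plugins>
</build>
```

With this in place, `mvn compile` generates both the protobuf message classes and the GRPCInferenceServiceGrpc stub under target/generated-sources.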

Running the examples

The official example

mvn exec:java -Dexec.mainClass=clients.SimpleJavaClient -Dexec.args="<host> <port>"

This throws an exception, because no model named simple is deployed in our model repository.

Exception in thread "main" io.grpc.StatusRuntimeException: NOT_FOUND: Request for unknown model: 'simple' is not found
	at io.grpc.stub.ClientCalls.toStatusRuntimeException(ClientCalls.java:262)
	at io.grpc.stub.ClientCalls.getUnchecked(ClientCalls.java:243)
	at io.grpc.stub.ClientCalls.blockingUnaryCall(ClientCalls.java:156)
	at inference.GRPCInferenceServiceGrpc$GRPCInferenceServiceBlockingStub.modelInfer(GRPCInferenceServiceGrpc.java:1563)
	at com.oppo.bot.speech.atom.ttsg.service.test.MainTritonClientTest.main(MainTritonClientTest.java:82)
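The model name in the request must match a directory in the server's model repository, each of which carries a config.pbtxt. For reference, a hypothetical config.pbtxt matching the wetts client code below; the backend, dims, and output data type are assumptions, not the actual wetts configuration:

```
name: "tts"
backend: "python"            # assumption: served via Triton's Python backend
max_batch_size: 1
input [
  {
    name: "text"
    data_type: TYPE_STRING   # appears as datatype "BYTES" on the client side
    dims: [ 1 ]
  }
]
output [
  {
    name: "wav"
    data_type: TYPE_FP32     # assumption: raw audio samples
    dims: [ -1 ]
  }
]
```

The NOT_FOUND error simply means no such directory/config exists for 'simple'; pointing the client at a deployed model name (here, "tts") resolves it.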

The wetts example

import com.google.protobuf.ByteString;
import inference.GRPCInferenceServiceGrpc;
import inference.GRPCInferenceServiceGrpc.GRPCInferenceServiceBlockingStub;
import inference.GrpcService.*;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.apache.commons.io.FileUtils;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;

public class MainTritonClientTest {

    public static void main(String[] args) throws IOException {

        String host = "10.5.153.28";
        int port = 8001;

        String model_name = "tts";
        String model_version = "";  // empty string = let the server pick the latest version

        // Create gRPC stub for communicating with the server
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        GRPCInferenceServiceBlockingStub grpc_stub = GRPCInferenceServiceGrpc.newBlockingStub(channel);

        // Check that the server is live
        ServerLiveRequest serverLiveRequest = ServerLiveRequest.getDefaultInstance();
        ServerLiveResponse r = grpc_stub.serverLive(serverLiveRequest);
        System.out.println(r);

        // Generate the request
        ModelInferRequest.Builder request = ModelInferRequest.newBuilder();
        request.setModelName(model_name);
        request.setModelVersion(model_version);

        // Input data
        InferTensorContents.Builder input0_data = InferTensorContents.newBuilder();
        input0_data.addBytesContents(ByteString.copyFrom("今天天气不错", Charset.forName("UTF-8")));

        // Populate the inputs in inference request
        ModelInferRequest.InferInputTensor.Builder input0 = ModelInferRequest.InferInputTensor
                .newBuilder();
        input0.setName("text");
        input0.setDatatype("BYTES");
        input0.addShape(1);
        input0.addShape(1);
        input0.setContents(input0_data);

        request.addInputs(0, input0);

        // Populate the outputs in the inference request
        ModelInferRequest.InferRequestedOutputTensor.Builder output0 = ModelInferRequest.InferRequestedOutputTensor
                .newBuilder();
        output0.setName("wav");

        request.addOutputs(0, output0);

        ModelInferResponse response = grpc_stub.modelInfer(request.build());

        System.out.println("raw output count: " + response.getRawOutputContentsCount());
        ByteBuffer byteBuffer = response.getRawOutputContents(0).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
        byte[] bytes = new byte[byteBuffer.remaining()];
        byteBuffer.get(bytes);

        String filename = "/tmp/grpc_" + System.currentTimeMillis() + ".pcm";
        System.out.println("save to file: " + filename);
        FileUtils.writeByteArrayToFile(new File(filename), bytes);

        channel.shutdownNow();
    }
}
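The saved .pcm file is headerless raw audio, so most players cannot open it directly. If the model emits 16-bit PCM, the bytes can be wrapped in a standard 44-byte WAV header with plain java.nio; a minimal sketch, assuming 16 kHz mono 16-bit output (wetts models may use a different sample rate, and a float32 output would need conversion first — the method name here is hypothetical):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class PcmToWav {

    // Prepend a minimal RIFF/WAVE header to raw PCM bytes.
    static byte[] pcmToWav(byte[] pcm, int sampleRate, int channels, int bitsPerSample) {
        int byteRate = sampleRate * channels * bitsPerSample / 8;
        int blockAlign = channels * bitsPerSample / 8;
        ByteBuffer b = ByteBuffer.allocate(44 + pcm.length).order(ByteOrder.LITTLE_ENDIAN);
        b.put("RIFF".getBytes());
        b.putInt(36 + pcm.length);        // RIFF chunk size
        b.put("WAVE".getBytes());
        b.put("fmt ".getBytes());
        b.putInt(16);                     // fmt sub-chunk size
        b.putShort((short) 1);            // audio format 1 = PCM
        b.putShort((short) channels);
        b.putInt(sampleRate);
        b.putInt(byteRate);
        b.putShort((short) blockAlign);
        b.putShort((short) bitsPerSample);
        b.put("data".getBytes());
        b.putInt(pcm.length);             // data sub-chunk size
        b.put(pcm);
        return b.array();
    }

    public static void main(String[] args) {
        byte[] pcm = new byte[16000 * 2]; // 1 second of 16-bit mono silence
        byte[] wav = pcmToWav(pcm, 16000, 1, 16);
        System.out.println("wav size: " + wav.length);
    }
}
```

In the client above, the `bytes` array pulled from `getRawOutputContents(0)` could be passed through this helper before `FileUtils.writeByteArrayToFile`, producing a directly playable .wav instead of a raw .pcm.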


References

  1. Example Java and Scala client Using Generated GRPC API