This article takes a look at using langchain4j together with ONNX for score-based reranking.

Steps

pom.xml

<dependency>
    <groupId>dev.langchain4j</groupId>
    <artifactId>langchain4j-onnx-scoring</artifactId>
    <version>1.0.0-beta1</version>
</dependency>

Download the model

wget -O ~/model_quantized.onnx "https://hf-mirror.com/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model_quantized.onnx?download=true"
wget -O ~/tokenizer.json "https://hf-mirror.com/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json?download=true"

example

import ai.onnxruntime.OrtSession;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.onnx.OnnxScoringModel;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class ONNXTest {

    /**
     * wget -O ~/model_quantized.onnx "https://hf-mirror.com/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model_quantized.onnx?download=true"
     * wget -O ~/tokenizer.json "https://hf-mirror.com/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json?download=true"
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        // To check the modelMaxLength parameter, refer to the model configuration file at  https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer_config.json
        String pathToModel = System.getProperty("user.home") + "/model_quantized.onnx";
        String pathToTokenizer = System.getProperty("user.home") + "/tokenizer.json";
        OnnxScoringModel model = new OnnxScoringModel(pathToModel, new OrtSession.SessionOptions(), pathToTokenizer, 512, false);
        List<TextSegment> segments = new ArrayList<>();
        segments.add(TextSegment.from("Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."));
        segments.add(TextSegment.from("New York City is famous for the Metropolitan Museum of Art."));

        String query = "How many people live in Berlin?";

        // when
        Response<List<Double>> response = model.scoreAll(segments, query);

        // then
        List<Double> scores = response.content();
        System.out.println("score1:" + scores.get(0));
        System.out.println("score2:" + scores.get(1));
        System.out.println("token count:" + response.tokenUsage().totalTokenCount());
        System.out.println("finish reason:" + response.finishReason());
    }
}
The output is as follows:
score1:8.663132667541504
score2:-11.245542526245117
token count:50
finish reason:null
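
The scores are the raw relevance logits of the cross-encoder: they are not bounded to a fixed range, negative values simply indicate low relevance, and the i-th score corresponds to the i-th input segment. A minimal sketch of a hypothetical helper (RerankUtil is not part of langchain4j) that sorts the candidate segments by score and keeps the top k might look like this:

import dev.langchain4j.data.segment.TextSegment;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class RerankUtil {

    // Pairs a segment with its relevance score so the two can be sorted together.
    record ScoredSegment(TextSegment segment, double score) {}

    // Returns the topK segments ordered by descending score;
    // assumes scores.get(i) is the score of segments.get(i), which is the order scoreAll returns them in.
    public static List<TextSegment> topK(List<TextSegment> segments, List<Double> scores, int topK) {
        List<ScoredSegment> scored = new ArrayList<>();
        for (int i = 0; i < segments.size(); i++) {
            scored.add(new ScoredSegment(segments.get(i), scores.get(i)));
        }
        scored.sort(Comparator.comparingDouble(ScoredSegment::score).reversed());
        return scored.stream()
                .limit(topK)
                .map(ScoredSegment::segment)
                .toList();
    }
}

With the example above, RerankUtil.topK(segments, scores, 1) would keep only the Berlin segment, since 8.66 is far higher than -11.25.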

Summary

langchain4j provides the langchain4j-onnx-scoring module for running a scoring (reranking) model locally on the ONNX Runtime. The scoreAll method of OnnxScoringModel returns a relevance score for each document.
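
In a RAG pipeline the scoring model is usually not called directly; instead it is plugged into a content aggregator that re-ranks the retrieved contents before they are passed to the LLM. Below is a rough sketch using langchain4j's ReRankingContentAggregator; it assumes the core langchain4j dependency is also on the classpath, the minScore threshold of 0 is an arbitrary choice for this sketch, and the builder options may differ slightly between versions:

import ai.onnxruntime.OrtSession;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.model.scoring.onnx.OnnxScoringModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ReRankingContentAggregator;
import dev.langchain4j.rag.query.Query;

import java.util.Collection;
import java.util.List;
import java.util.Map;

public class RerankingAggregatorExample {

    public static void main(String[] args) {
        String home = System.getProperty("user.home");
        ScoringModel scoringModel = new OnnxScoringModel(
                home + "/model_quantized.onnx",
                new OrtSession.SessionOptions(),
                home + "/tokenizer.json",
                512,
                false);

        // Re-ranks retrieved contents with the cross-encoder and drops low-scoring ones.
        // minScore = 0 is an arbitrary threshold chosen for this sketch.
        ReRankingContentAggregator aggregator = ReRankingContentAggregator.builder()
                .scoringModel(scoringModel)
                .minScore(0.0)
                .build();

        Query query = Query.from("How many people live in Berlin?");
        List<Content> retrieved = List.of(
                Content.from("Berlin has a population of 3,520,031 registered inhabitants."),
                Content.from("New York City is famous for the Metropolitan Museum of Art."));

        // The aggregator takes contents grouped per query (here a single query from a single retriever).
        Map<Query, Collection<List<Content>>> queryToContents = Map.of(query, List.of(retrieved));
        List<Content> reRanked = aggregator.aggregate(queryToContents);
        reRanked.forEach(content -> System.out.println(content.textSegment().text()));
    }
}

The resulting aggregator can then be handed to DefaultRetrievalAugmentor.builder().contentAggregator(...) so that re-ranking happens automatically after retrieval.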
