ONNX Runtime generate() API von 0.5.2 auf 0.6.0 migrieren

Erfahren Sie, wie Sie von der ONNX Runtime generate() Version 0.5.2 auf Version 0.6.0 migrieren.

Version 0.6.0 fügt Unterstützung für den „Chat-Modus“ hinzu, auch bekannt als Fortsetzung, kontinuierliche Dekodierung und interaktive Dekodierung. Mit der Einführung des Chat-Modus wurde eine API-Änderung vorgenommen, die zu Kompatibilitätsproblemen führen kann.

Zusammenfassend lässt sich sagen, dass die neue API eine Methode namens AppendTokens zum Generator hinzufügt, die Mehrfach-Konversationen ermöglicht. Zuvor wurden Eingaben in GeneratorParams gesetzt, bevor der Generator erstellt wurde.

Das Aufrufen von AppendTokens außerhalb der Konversationsschleife kann zur Implementierung von System-Prompt-Caching verwendet werden.

Hinweis: Chat-Modus und System-Prompt-Caching werden nur für Batch-Größe 1 unterstützt. Außerdem werden sie derzeit auf CPUs, NVIDIA-GPUs mit dem CUDA EP und allen GPUs mit dem Web GPU native EP unterstützt. Sie werden auf NPUs oder GPUs, die mit dem DirecML EP laufen, nicht unterstützt. Für den Frage- & Antwortmodus (Q&A) sind die unten beschriebenen Migrationen immer noch erforderlich.

Python

Python-Code für Fragen und Antworten (einzelner Durchgang) zu 0.6.0 migrieren

  1. Ersetzen Sie Aufrufe von params.input_ids = input_tokens durch generator.append_tokens(input_tokens), nachdem das Generatorobjekt erstellt wurde.
  2. Entfernen Sie Aufrufe von generator.compute_logits().
  3. Wenn die Anwendung eine Q&A-Schleife hat, löschen Sie den generator zwischen den append_token-Aufrufen, um den Zustand des Modells zurückzusetzen.

System-Prompt-Caching zu Python-Anwendungen hinzufügen

  1. Erstellen und tokenisieren Sie den System-Prompt und rufen Sie generator.append_tokens(system_tokens) auf. Dieser Aufruf kann erfolgen, bevor der Benutzer nach seinem Prompt gefragt wird.

    system_tokens = tokenizer.encode(system_prompt)
    generator.append_tokens(system_tokens)
    

Chat-Modus zu Python-Anwendungen hinzufügen

  1. Erstellen Sie eine Schleife in Ihrer Anwendung und rufen Sie generator.append_tokens(prompt) jedes Mal auf, wenn der Benutzer eine neue Eingabe liefert.

    while True:
        user_input = input("Input: ")
        input_tokens = tokenizer.encode(user_input)
        generator.append_tokens(input_tokens)
    
        while not generator.is_done():
            generator.generate_next_token()
    
            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end='', flush=True)
         except KeyboardInterrupt:
         print()
    

C++

C++-Code für Fragen und Antworten (einzelner Durchgang) zu 0.6.0 migrieren

  1. Ersetzen Sie Aufrufe von params->SetInputSequences(*sequences) durch generator->AppendTokenSequences(*sequences).
  2. Entfernen Sie Aufrufe von generator->ComputeLogits().

System-Prompt-Caching zu C++-Anwendungen hinzufügen

  1. Erstellen und tokenisieren Sie den System-Prompt und rufen Sie generator->AppendTokenSequences(*sequences) auf. Dieser Aufruf kann erfolgen, bevor der Benutzer nach seinem Prompt gefragt wird.

    auto sequences = OgaSequences::Create();
    tokenizer->Encode(system_prompt.c_str(), *sequences);
    generator->AppendTokenSequences(*sequences);
    generator.append_tokens(system_tokens)
    

Chat-Modus zu Ihrer C++-Anwendung hinzufügen

  1. Fügen Sie eine Chat-Schleife zu Ihrer Anwendung hinzu.
    std::cout << "Generating response..." << std::endl;
    auto params = OgaGeneratorParams::Create(*model);
    params->SetSearchOption("max_length", 1024);
    
    auto generator = OgaGenerator::Create(*model, *params);
    
    while (true) {
      std::string text;
      std::cout << "Prompt: "  << std::endl;
      std::getline(std::cin, prompt);
    
      auto sequences = OgaSequences::Create();
      tokenizer->Encode(prompt.c_str(), *sequences);
    
      generator->AppendTokenSequences(*sequences);
    
      while (!generator->IsDone()) {
        generator->GenerateNextToken();
    
        const auto num_tokens = generator->GetSequenceCount(0);
        const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
       std::cout << tokenizer_stream->Decode(new_token) << std::flush;
       }
    }
    

C#

C#-Code für Fragen und Antworten (einzelner Durchgang) zu 0.6.0 migrieren

  1. Ersetzen Sie Aufrufe von generatorParams.SetInputSequences(sequences) durch generator.AppendTokenSequences(sequences)`.
  2. Entfernen Sie Aufrufe von generator.ComputeLogits().

System-Prompt-Caching zu Ihrer C#-Anwendung hinzufügen

  1. Erstellen und tokenisieren Sie den System-Prompt und rufen Sie generator->AppendTokenSequences() auf. Dieser Aufruf kann erfolgen, bevor der Benutzer nach seinem Prompt gefragt wird.

    var systemPrompt = "..."
    auto sequences = OgaSequences::Create();
    tokenizer->Encode(systemPrompt, *sequences);
    generator->AppendTokenSequences(*sequences);
    

Chat-Modus zu Ihrer C#-Anwendung hinzufügen

  1. Fügen Sie eine Chat-Schleife zu Ihrer Anwendung hinzu.
    using var tokenizerStream = tokenizer.CreateStream();
    using var generator = new Generator(model, generatorParams);
    Console.WriteLine("Prompt:");
    prompt = Console.ReadLine();
    
    // Example Phi-3 template
    var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>");
    
    do
    {
       generator.AppendTokenSequences(sequences);
       var watch = System.Diagnostics.Stopwatch.StartNew();
       while (!generator.IsDone())
       {
          generator.GenerateNextToken();
          Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1]));
       }
       Console.WriteLine();
       watch.Stop();
       var runTimeInSeconds = watch.Elapsed.TotalSeconds;
       var outputSequence = generator.GetSequence(0);
       var totalTokens = outputSequence.Length;
       Console.WriteLine($"Streaming Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}");
       Console.WriteLine("Next prompt:");
       var nextPrompt = Console.ReadLine();
       sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>");
    } while (prompt != null);
    
    

Java

Kommt bald