package com.logicbig.example;

import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.ollama.OllamaStreamingChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.UserMessage;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Demonstrates three ways of consuming a streamed chat response from a local Ollama model
 * through a LangChain4j {@link TokenStream}: callback-driven with a latch, bridged into a
 * {@link CompletableFuture}, and cancelled early from inside the partial-response callback.
 */
public class StreamingExample {

    /** Number of partial responses printed before the cancellation demo stops the stream. */
    private static final int CANCEL_AFTER_TOKENS = 10;

    /** AI service contract; the implementation is generated by {@link AiServices#create}. */
    interface StreamingAssistant {
        @SystemMessage("Always give short answers.")
        @UserMessage("{{it}}")
        TokenStream chatStream(String message);
    }

    /**
     * Builds a streaming Ollama chat model and runs each streaming demo in turn.
     *
     * @param args ignored
     * @throws Exception if the future-based demo fails (propagated from
     *     {@link #streamingWithFuture(StreamingAssistant)})
     */
    public static void main(String[] args) throws Exception {
        // Create streaming model
        StreamingChatModel model =
                OllamaStreamingChatModel.builder()
                                        .baseUrl("http://localhost:11434")
                                        .modelName("phi3:mini-128k")
                                        .temperature(0.7)
                                        .build();

        StreamingAssistant assistant =
                AiServices.create(StreamingAssistant.class, model);

        System.out.println("=== Basic Streaming ===");
        basicStreaming(assistant);

        System.out.println("\n\n=== Streaming with Completion Future ===");
        streamingWithFuture(assistant);

        System.out.println("\n\n=== Simulated Cancellation ===");
        streamingWithCancellation(assistant);
    }

    /**
     * Streams a response, printing each partial chunk as it arrives, and blocks until the
     * stream completes or reports an error.
     *
     * <p>The "tokens" figure printed at the end is the number of partial-response callbacks
     * received; it is not necessarily identical to the model's token count.
     *
     * @param assistant the AI service to query
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     */
    private static void basicStreaming(StreamingAssistant assistant) {
        CountDownLatch latch = new CountDownLatch(1);

        String message = "Tell me about artificial intelligence";
        System.out.println("message: " + message);
        System.out.println("Streaming response:");
        TokenStream stream = assistant.chatStream(message);
        AtomicInteger tokenCount = new AtomicInteger(0);
        stream.onPartialResponse(partial -> {
                    System.out.print(partial);
                    tokenCount.incrementAndGet();
                })
                .onCompleteResponse(response -> {
                    System.out.println("\n\nTotal tokens: " + tokenCount.get());
                    latch.countDown();
                })
                .onError(error -> {
                    System.err.println("\nError: " + error.getMessage());
                    latch.countDown();
                })
                .start();

        awaitLatch(latch);
    }

    /**
     * Streams a response while accumulating it, bridges completion into a
     * {@link CompletableFuture}, and blocks on that future for the full text.
     *
     * @param assistant the AI service to query
     * @throws Exception if the stream reports an error (wrapped by
     *     {@link CompletableFuture#get()} in an {@code ExecutionException}) or the wait is
     *     interrupted
     */
    private static void streamingWithFuture(StreamingAssistant assistant) throws Exception {
        CompletableFuture<String> completeResponse = new CompletableFuture<>();
        StringBuilder fullResponse = new StringBuilder();

        String message = "Explain machine learning in simple terms";
        System.out.println("message: " + message);
        System.out.println("streaming response:");
        TokenStream stream = assistant.chatStream(message);

        stream
                .onPartialResponse(partial -> {
                    System.out.print(partial);
                    fullResponse.append(partial);
                })
                .onCompleteResponse(response -> completeResponse.complete(fullResponse.toString()))
                .onError(completeResponse::completeExceptionally)
                .start();

        // Wait for completion; the future hands the accumulated text back to this thread.
        String finalResponse = completeResponse.get();
        System.out.println("\n\nFinal response length: " + finalResponse.length() + " characters");
    }

    /**
     * Streams a response and cancels it from inside the partial-response callback once
     * {@link #CANCEL_AFTER_TOKENS} chunks have been printed.
     *
     * <p>Cancellation is triggered exactly once. Any chunks delivered after that point are
     * ignored.
     *
     * @param assistant the AI service to query
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     */
    private static void streamingWithCancellation(StreamingAssistant assistant) {
        CountDownLatch latch = new CountDownLatch(1);

        String message = "Write a long story about a dragon";
        System.out.println("message: " + message);
        TokenStream stream = assistant.chatStream(message);
        System.out.println("Streaming response:");

        AtomicInteger tokenCount = new AtomicInteger(0);
        stream.onPartialResponseWithContext((partial, context) -> {
                    int count = tokenCount.incrementAndGet();
                    if (count > CANCEL_AFTER_TOKENS) {
                        // Chunks that were already in flight when cancel() was requested may
                        // still arrive; drop them so the output stops at the cancellation point.
                        return;
                    }
                    System.out.print(partial.text());
                    // Simulate cancellation after CANCEL_AFTER_TOKENS chunks (fires only once).
                    if (count == CANCEL_AFTER_TOKENS) {
                        System.out.println("\n\n[CANCELLED after " + count + " tokens]");
                        try {
                            // Cancel before releasing the latch so main does not race ahead.
                            context.streamingHandle().cancel();
                        } finally {
                            // Always release the waiter, even if cancel() throws.
                            latch.countDown();
                        }
                    }
                })
                .onCompleteResponse(response -> {
                    System.out.println("\n\nCompleted: " + tokenCount.get() + " tokens");
                    latch.countDown();
                })
                .onError(error -> {
                    System.err.println("\nError: " + error.getMessage());
                    latch.countDown();
                })
                .start();

        awaitLatch(latch);
    }

    /**
     * Blocks until {@code latch} reaches zero.
     *
     * <p>If the waiting thread is interrupted, the interrupt flag is restored before the
     * exception is rethrown unchecked, so code further up the stack can still observe it.
     *
     * @param latch the latch to wait on
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     */
    private static void awaitLatch(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for the stream to finish", e);
        }
    }
}