/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.mcp.server.sse.runtime;

import io.quarkiverse.mcp.server.CompletionManager;
import io.quarkiverse.mcp.server.CompletionResponse;
import io.quarkiverse.mcp.server.InitialCheck;
import io.quarkiverse.mcp.server.InitialRequest;
import io.quarkiverse.mcp.server.PromptManager;
import io.quarkiverse.mcp.server.ResourceManager;
import io.quarkiverse.mcp.server.ResourceTemplateManager;
import io.quarkiverse.mcp.server.ToolManager;
import io.quarkiverse.mcp.server.runtime.ConnectionManager;
import io.quarkiverse.mcp.server.runtime.ContextSupport;
import io.quarkiverse.mcp.server.runtime.FeatureArgument;
import io.quarkiverse.mcp.server.runtime.FeatureMetadata;
import io.quarkiverse.mcp.server.runtime.McpConnectionBase;
import io.quarkiverse.mcp.server.runtime.McpMessageHandler;
import io.quarkiverse.mcp.server.runtime.McpMetadata;
import io.quarkiverse.mcp.server.runtime.McpRequest;
import io.quarkiverse.mcp.server.runtime.McpRequestImpl;
import io.quarkiverse.mcp.server.runtime.Messages;
import io.quarkiverse.mcp.server.runtime.NotificationManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResponseHandlers;
import io.quarkiverse.mcp.server.runtime.SecuritySupport;
import io.quarkiverse.mcp.server.runtime.Sender;
import io.quarkiverse.mcp.server.runtime.ToolManagerImpl;
import io.quarkiverse.mcp.server.runtime.config.McpServerRuntimeConfig;
import io.quarkiverse.mcp.server.runtime.config.McpServersRuntimeConfig;
import io.quarkiverse.mcp.server.sse.runtime.StreamableHttpMcpConnection;
import io.quarkus.arc.All;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.vertx.http.runtime.CurrentVertxRequest;
import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Singleton;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jboss.logging.Logger;

@Singleton
public class StreamableHttpMcpMessageHandler
extends McpMessageHandler<HttpMcpRequest>
implements Handler<RoutingContext> {
    /** Logger used by this handler and by the nested {@code HttpMcpRequest}. */
    private static final Logger LOG = Logger.getLogger(StreamableHttpMcpMessageHandler.class);
    /** Name of the HTTP header that carries the id of an existing MCP session. */
    public static final String MCP_SESSION_ID_HEADER = "Mcp-Session-Id";
    /** Build-time metadata used to look up declared tools, prompts, resources and completions. */
    private final McpMetadata metadata;
    /** Receives the current routing context when the request context is activated. */
    private final CurrentVertxRequest currentVertxRequest;
    /** The resolved identity association, or {@code null} when the injected instance is not resolvable. */
    private final CurrentIdentityAssociation currentIdentityAssociation;
    /** JSON-RPC methods whose requests are inspected to decide whether SSE must be initiated up front. */
    private static final Set<String> FORCE_SSE_REQUESTS = Set.of("tools/call", "prompts/get", "resources/read", "completion/complete");
    /** Argument providers whose presence among a feature's arguments forces SSE for that request. */
    private static final Set<FeatureArgument.Provider> FORCE_SSE_PROVIDERS = Set.of(FeatureArgument.Provider.PROGRESS, FeatureArgument.Provider.MCP_LOG, FeatureArgument.Provider.SAMPLING, FeatureArgument.Provider.ROOTS, FeatureArgument.Provider.ELICITATION);

    /**
     * Creates the handler, passing the shared managers to the base message handler and
     * keeping the collaborators this transport needs itself.
     *
     * @param currentIdentityAssociation injected lazily; when it is not resolvable the
     *        handler stores {@code null} instead of an association
     */
    StreamableHttpMcpMessageHandler(McpServersRuntimeConfig config, ConnectionManager connectionManager, PromptManagerImpl promptManager, ToolManagerImpl toolManager, ResourceManagerImpl resourceManager, PromptCompletionManagerImpl promptCompleteManager, ResourceTemplateManagerImpl resourceTemplateManager, ResourceTemplateCompletionManagerImpl resourceTemplateCompleteManager, NotificationManagerImpl notificationManager, ResponseHandlers responseHandlers, @All List<InitialCheck> initialChecks, CurrentVertxRequest currentVertxRequest, Instance<CurrentIdentityAssociation> currentIdentityAssociation, McpMetadata metadata, Vertx vertx) {
        super(config, connectionManager, promptManager, toolManager, resourceManager, promptCompleteManager, resourceTemplateManager, resourceTemplateCompleteManager, notificationManager, responseHandlers, metadata, vertx, initialChecks);
        this.currentVertxRequest = currentVertxRequest;
        this.metadata = metadata;
        if (currentIdentityAssociation.isResolvable()) {
            this.currentIdentityAssociation = currentIdentityAssociation.get();
        } else {
            this.currentIdentityAssociation = null;
        }
    }

    /**
     * Handles one HTTP request carrying a JSON-RPC message (or batch) for the streamable transport.
     *
     * <p>Processing steps:
     * <ol>
     * <li>The {@code Accept} header must list both {@code application/json} and
     * {@code text/event-stream}; otherwise the request fails with 400.</li>
     * <li>Without an {@code Mcp-Session-Id} header a new connection is created and registered;
     * with one, the existing connection is looked up and the request fails with 404 if it is unknown.</li>
     * <li>The body is decoded as JSON. On failure a JSON-RPC parse error ({@code -32700}) is returned
     * and a connection that was registered for this very request is removed again so it does not leak.</li>
     * <li>The message is scanned; SSE is initiated up front when the scan requires it.</li>
     * <li>The message is passed to the base handler and the HTTP response is completed once the
     * returned future completes.</li>
     * </ol>
     *
     * @param ctx the routing context of the incoming request
     * @throws IllegalStateException if the server name is not present in the routing context or the
     *         session id refers to a connection that is not a {@code StreamableHttpMcpConnection}
     */
    public void handle(final RoutingContext ctx) {
        String serverName = ctx.get("mcp.sse.server-name");
        if (serverName == null) {
            throw new IllegalStateException("Server name not defined");
        }
        HttpServerRequest request = ctx.request();
        List<String> accepts = request.headers().getAll(HttpHeaders.ACCEPT);
        if (!this.accepts(accepts, "application/json") || !this.accepts(accepts, "text/event-stream")) {
            LOG.errorf("Invalid Accept header: %s", accepts);
            ctx.fail(400);
            return;
        }

        String mcpSessionId = request.getHeader(MCP_SESSION_ID_HEADER);
        final boolean newSession = mcpSessionId == null;
        StreamableHttpMcpConnection connection;
        if (newSession) {
            String id = ConnectionManager.connectionId();
            LOG.debugf("Streamable connection initialized [%s]", id);
            McpServerRuntimeConfig serverConfig = (McpServerRuntimeConfig) this.config.servers().get(serverName);
            connection = new StreamableHttpMcpConnection(id, serverConfig);
            this.connectionManager.add(connection);
        } else {
            McpConnectionBase conn = this.connectionManager.get(mcpSessionId);
            if (conn == null) {
                LOG.errorf("Mcp session not found: %s", mcpSessionId);
                ctx.fail(404);
                return;
            }
            if (!(conn instanceof StreamableHttpMcpConnection streamable)) {
                throw new IllegalStateException("Invalid connection type: " + conn.getClass().getName());
            }
            connection = streamable;
        }

        Object json;
        try {
            json = Json.decodeValue(ctx.body().buffer());
        } catch (Exception e) {
            String msg = "Unable to parse the JSON message";
            LOG.errorf(e, msg);
            if (newSession) {
                // The connection was registered a few lines above for this request only; no session id
                // was ever returned to the client, so nobody could use it - drop it instead of leaking it.
                this.connectionManager.remove(connection.id());
            }
            ctx.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
            ctx.end(Messages.newError(null, -32700, msg).toBuffer());
            return;
        }

        final QuarkusHttpUser user = (QuarkusHttpUser) ctx.user();
        SecuritySupport securitySupport = new SecuritySupport() {

            public void setCurrentIdentity(CurrentIdentityAssociation association) {
                if (user != null) {
                    SecurityIdentity identity = user.getSecurityIdentity();
                    association.setIdentity(identity);
                } else {
                    // No authenticated user on the context: fall back to the identity resolved from the request
                    association.setIdentity(QuarkusHttpUser.getSecurityIdentity(ctx, null));
                }
            }
        };
        ContextSupport contextSupport = new ContextSupport() {

            public void requestContextActivated() {
                StreamableHttpMcpMessageHandler.this.currentVertxRequest.setCurrent(ctx);
            }
        };

        HttpMcpRequest mcpRequest = new HttpMcpRequest(serverName, json, connection, securitySupport, ctx.response(), newSession, contextSupport, this.currentIdentityAssociation);
        ScanResult result = this.scan(mcpRequest);
        if (result.forceSseInit()) {
            mcpRequest.initiateSse();
        }
        this.handle(mcpRequest).onComplete(ar -> {
            HttpServerResponse response = ctx.response();
            if (ar.succeeded()) {
                if (mcpRequest.sse.get()) {
                    // SSE stream: all messages were written as events, just close the stream
                    response.end();
                } else if (!response.ended()) {
                    if (!result.containsRequest()) {
                        // Only notifications/responses were received - nothing to send back
                        response.setStatusCode(202).end();
                    } else {
                        ctx.end();
                    }
                }
            } else if (!response.ended()) {
                if (response.headWritten()) {
                    // Events were already streamed, so the status line is gone; changing the status
                    // code now would throw. Just terminate the stream.
                    response.end();
                } else {
                    response.setStatusCode(500).end();
                }
            }
        });
    }

    /**
     * Terminates the MCP session identified by the {@code Mcp-Session-Id} request header.
     *
     * <p>The request fails with 404 when the header is missing or when no connection is
     * registered under the given id. Otherwise the connection is removed from the connection
     * manager and the response is ended.
     *
     * @param ctx the routing context of the termination request
     */
    public void terminateSession(RoutingContext ctx) {
        String sessionId = ctx.request().getHeader(MCP_SESSION_ID_HEADER);
        if (sessionId == null) {
            LOG.errorf("Mcp session id header is missing: %s", ctx.normalizedPath());
            ctx.fail(404);
            return;
        }

        McpConnectionBase session = this.connectionManager.get(sessionId);
        if (session == null) {
            LOG.errorf("Mcp session not found: %s", sessionId);
            ctx.fail(404);
            return;
        }

        boolean removed = this.connectionManager.remove(session.id());
        if (removed) {
            LOG.infof("Mcp session terminated: %s", session.id());
        }
        ctx.end();
    }

    /**
     * Adds the {@code Mcp-Session-Id} header, set to the id of the request's connection,
     * to the HTTP response of the given request.
     *
     * @param mcpRequest the request whose response receives the session id header
     */
    protected void afterInitialize(HttpMcpRequest mcpRequest) {
        String sessionId = mcpRequest.connection().id();
        mcpRequest.response.headers().add(MCP_SESSION_ID_HEADER, sessionId);
    }

    /**
     * Removes the request's connection from the connection manager.
     *
     * @param mcpRequest the request whose connection is removed
     */
    protected void initializeFailed(HttpMcpRequest mcpRequest) {
        String connectionId = mcpRequest.connection().id();
        this.connectionManager.remove(connectionId);
    }

    /**
     * Removes the request's connection from the connection manager, but only when the
     * connection was created for this request (a new session). Connections of existing
     * sessions are left untouched.
     *
     * @param mcpRequest the request that failed JSON-RPC validation
     */
    protected void jsonrpcValidationFailed(HttpMcpRequest mcpRequest) {
        if (!mcpRequest.newSession) {
            return;
        }
        String connectionId = mcpRequest.connection().id();
        this.connectionManager.remove(connectionId);
    }

    /**
     * @return the transport implemented by this handler, always
     *         {@code InitialRequest.Transport.STREAMABLE_HTTP}
     */
    protected InitialRequest.Transport transport() {
        return InitialRequest.Transport.STREAMABLE_HTTP;
    }

    /**
     * Tests whether any of the given {@code Accept} header values mentions the content type.
     *
     * <p>The comparison ignores case because media type names are case-insensitive, so a
     * header such as {@code Application/JSON} is accepted as well. The match is a substring
     * match, so parameters (e.g. {@code ;q=0.9}) and comma-separated lists within one header
     * value are tolerated.
     *
     * @param accepts the raw {@code Accept} header values; may be empty
     * @param contentType the media type to look for, e.g. {@code application/json}
     * @return {@code true} if at least one header value contains the content type
     */
    private boolean accepts(List<String> accepts, String contentType) {
        // Locale.ROOT keeps the lowering locale-independent (no Turkish-i surprises)
        String expected = contentType.toLowerCase(java.util.Locale.ROOT);
        for (String accept : accepts) {
            if (accept != null && accept.toLowerCase(java.util.Locale.ROOT).contains(expected)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Inspects the decoded JSON payload before it is processed and determines whether SSE
     * has to be initiated up front and whether the payload contains at least one request.
     *
     * <ul>
     * <li>Single message: SSE is forced according to {@code forceSse}.</li>
     * <li>Batch whose first element is a message other than a response: SSE is forced when the
     * batch has more than one element or when {@code forceSse} says so for the first element.</li>
     * <li>Anything else (a batch of responses, an empty batch, a batch whose first element is
     * not a JSON object, or a payload that is neither object nor array): both flags are
     * {@code false}. Malformed payloads are deliberately not rejected here so that the regular
     * message handling can validate them instead of this method throwing.</li>
     * </ul>
     *
     * @param mcpRequest the request holding the decoded JSON payload
     * @return the scan outcome
     */
    private ScanResult scan(HttpMcpRequest mcpRequest) {
        Object json = mcpRequest.json();
        if (json instanceof JsonObject message) {
            boolean forceSseInit = this.forceSse(mcpRequest, message);
            return new ScanResult(forceSseInit, Messages.isRequest(message));
        }
        // Guard against "[]" - reading element 0 of an empty batch would throw
        if (!(json instanceof JsonArray batch) || batch.isEmpty()) {
            return new ScanResult(false, false);
        }
        // getValue() does not cast, so a non-object first element (e.g. "[1]") cannot blow up here
        if (!(batch.getValue(0) instanceof JsonObject first) || Messages.isResponse(first)) {
            return new ScanResult(false, false);
        }
        boolean forceSseInit = batch.size() > 1 || this.forceSse(mcpRequest, first);
        boolean containsRequest = false;
        for (Object element : batch) {
            if (element instanceof JsonObject message && Messages.isRequest(message)) {
                containsRequest = true;
                break;
            }
        }
        return new ScanResult(forceSseInit, containsRequest);
    }

    /**
     * Decides whether SSE must be initiated before the given message is processed.
     *
     * <p>Only requests whose method is one of {@code FORCE_SSE_REQUESTS} and whose
     * {@code params} member is a JSON object are considered; the decision is then delegated
     * to the method-specific check. A missing {@code params} member, or one that is not a
     * JSON object (e.g. a by-position array), yields {@code false} instead of an exception.
     *
     * @param mcpRequest the enclosing request (currently not consulted)
     * @param message the JSON-RPC message to inspect
     * @return {@code true} if SSE must be initiated for this message
     */
    private boolean forceSse(HttpMcpRequest mcpRequest, JsonObject message) {
        String method = message.getString("method");
        if (method == null || !Messages.isRequest(message) || !FORCE_SSE_REQUESTS.contains(method)) {
            return false;
        }
        // getJsonObject("params") would throw ClassCastException for non-object params
        if (!(message.getValue("params") instanceof JsonObject params)) {
            return false;
        }
        return switch (method) {
            case "tools/call" -> this.forceSseTool(params);
            case "prompts/get" -> this.forceSsePrompt(params);
            case "resources/read" -> this.forceSseResource(params);
            case "completion/complete" -> this.forceSseCompletion(params);
            default -> throw new IllegalArgumentException("Unexpected value: " + method);
        };
    }

    /**
     * Decides whether a {@code tools/call} request forces SSE.
     *
     * <p>If the tool named by {@code params.name} is found in the metadata, SSE is forced when
     * any of its arguments uses a provider from {@code FORCE_SSE_PROVIDERS}. Otherwise the tool
     * manager is consulted and SSE is forced when the tool exists and is not backed by a method.
     *
     * @param params the {@code params} object of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseTool(JsonObject params) {
        String toolName = params.getString("name");
        FeatureMetadata<?> declared = McpMetadata.findFeatureByName(this.metadata.tools(), toolName);
        if (declared == null) {
            ToolManager.ToolInfo registered = this.toolManager.getTool(toolName);
            return registered != null && !registered.isMethod();
        }
        return declared.info().arguments().stream()
                .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
    }

    /**
     * Decides whether a {@code prompts/get} request forces SSE.
     *
     * <p>If the prompt named by {@code params.name} is found in the metadata, SSE is forced when
     * any of its arguments uses a provider from {@code FORCE_SSE_PROVIDERS}. Otherwise the prompt
     * manager is consulted and SSE is forced when the prompt exists and is not backed by a method.
     *
     * @param params the {@code params} object of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSsePrompt(JsonObject params) {
        String promptName = params.getString("name");
        FeatureMetadata<?> declared = McpMetadata.findFeatureByName(this.metadata.prompts(), promptName);
        if (declared == null) {
            PromptManager.PromptInfo registered = this.promptManager.getPrompt(promptName);
            return registered != null && !registered.isMethod();
        }
        return declared.info().arguments().stream()
                .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
    }

    /**
     * Decides whether a {@code resources/read} request forces SSE.
     *
     * <p>If a resource with the requested URI is found in the metadata, SSE is forced when any
     * of its arguments uses a provider from {@code FORCE_SSE_PROVIDERS}. Otherwise the resource
     * manager is asked for the URI; if it knows the resource, SSE is forced when that resource
     * is not backed by a method. Only if the resource manager does not know the URI either is
     * the resource template manager asked for a matching template, with the same rule.
     *
     * @param params the {@code params} object of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseResource(JsonObject params) {
        String uri = params.getString("uri");

        FeatureMetadata<?> declared = null;
        for (FeatureMetadata<?> candidate : this.metadata.resources()) {
            if (candidate.info().uri().equals(uri)) {
                declared = candidate;
                break;
            }
        }
        if (declared != null) {
            return declared.info().arguments().stream()
                    .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
        }

        ResourceManager.ResourceInfo resource = this.resourceManager.getResource(uri);
        if (resource != null) {
            return !resource.isMethod();
        }
        ResourceTemplateManager.ResourceTemplateInfo template = this.resourceTemplateManager.findMatching(uri);
        return template != null && !template.isMethod();
    }

    /**
     * Decides whether a {@code completion/complete} request forces SSE.
     *
     * <p>The reference ({@code params.ref}) and the argument ({@code params.argument}) must both
     * be JSON objects carrying a {@code name}; the reference type then selects the prompt
     * completions ({@code ref/prompt}) or the resource template completions
     * ({@code ref/resource}). Missing members, members of an unexpected JSON type and unknown
     * reference types yield {@code false} rather than an exception.
     *
     * @param params the {@code params} object of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseCompletion(JsonObject params) {
        // getJsonObject(...) would throw ClassCastException if the client sent e.g. a string here
        if (!(params.getValue("ref") instanceof JsonObject ref)
                || !(params.getValue("argument") instanceof JsonObject argument)) {
            return false;
        }
        String referenceType = ref.getString("type");
        String referenceName = ref.getString("name");
        String argumentName = argument.getString("name");
        if (referenceName == null || argumentName == null) {
            return false;
        }
        if ("ref/prompt".equals(referenceType)) {
            return this.forceSseCompletion(referenceName, argumentName, this.metadata.promptCompletions(), this.promptCompletionManager);
        }
        if ("ref/resource".equals(referenceType)) {
            return this.forceSseCompletion(referenceName, argumentName, this.metadata.resourceTemplateCompletions(), this.resourceTemplateCompletionManager);
        }
        return false;
    }

    /**
     * Decides whether completing the given argument of the given reference forces SSE.
     *
     * <p>The first metadata entry whose name equals {@code referenceName} and whose first
     * parameter argument is named {@code argumentName} is used: SSE is forced when any of its
     * arguments uses a provider from {@code FORCE_SSE_PROVIDERS}. If no entry matches, the
     * completion manager is consulted and SSE is forced when the completion exists and is not
     * backed by a method.
     *
     * @param referenceName name of the referenced prompt or resource template
     * @param argumentName name of the argument being completed
     * @param completions metadata of the declared completions to search
     * @param completionManager manager consulted when the metadata has no matching entry
     * @return {@code true} if SSE must be initiated
     * @throws java.util.NoSuchElementException if a metadata entry with a matching name has no
     *         parameter argument
     */
    private boolean forceSseCompletion(String referenceName, String argumentName, List<FeatureMetadata<CompletionResponse>> completions, CompletionManager completionManager) {
        for (FeatureMetadata<CompletionResponse> candidate : completions) {
            if (!candidate.info().name().equals(referenceName)) {
                continue;
            }
            String completedArgument = candidate.info().arguments().stream()
                    .filter(FeatureArgument::isParam)
                    .findFirst()
                    .orElseThrow()
                    .name();
            if (argumentName.equals(completedArgument)) {
                // First matching entry decides; later entries are not looked at
                return candidate.info().arguments().stream()
                        .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
            }
        }
        CompletionManager.CompletionInfo registered = completionManager.getCompletion(referenceName, argumentName);
        return registered != null && !registered.isMethod();
    }

    /**
     * An MCP request received over the streamable HTTP transport. The request acts as its own
     * {@code Sender}: messages are written to the HTTP response of the originating request,
     * either as SSE events or as a single JSON body.
     */
    static class HttpMcpRequest extends McpRequestImpl<StreamableHttpMcpConnection> implements Sender {
        /** {@code true} if the connection was created for this request. */
        final boolean newSession;
        /** Flips to {@code true} once the response has been switched to SSE. */
        final AtomicBoolean sse = new AtomicBoolean(false);
        /** The HTTP response of the originating request. */
        final HttpServerResponse response;

        public HttpMcpRequest(String serverName, Object json, StreamableHttpMcpConnection connection, SecuritySupport securitySupport, HttpServerResponse response, boolean newSession, ContextSupport contextSupport, CurrentIdentityAssociation currentIdentityAssociation) {
            super(serverName, json, connection, null, securitySupport, contextSupport, currentIdentityAssociation);
            this.response = response;
            this.newSession = newSession;
        }

        /**
         * @return this request, which sends messages itself
         */
        public Sender sender() {
            return this;
        }

        /**
         * Switches the response to SSE: chunked encoding plus a {@code text/event-stream}
         * content type. Only the first call has an effect.
         *
         * @return {@code true} if this call performed the switch, {@code false} if SSE was already initiated
         */
        boolean initiateSse() {
            boolean switched = this.sse.compareAndSet(false, true);
            if (switched) {
                this.response.setChunked(true);
                this.response.headers().add(HttpHeaders.CONTENT_TYPE, "text/event-stream");
            }
            return switched;
        }

        /**
         * Sends a message to the client.
         *
         * <ul>
         * <li>{@code null} messages are ignored.</li>
         * <li>With SSE initiated, the message is written as a {@code message} event.</li>
         * <li>Otherwise, if the response is still open, the message becomes the JSON body and
         * the response is ended.</li>
         * <li>If the response has already ended, the message is handed to the connection instead.</li>
         * </ul>
         *
         * @param message the message to send, may be {@code null}
         * @return a future for the underlying write, or a succeeded future for a {@code null} message
         */
        public Future<Void> send(JsonObject message) {
            if (message == null) {
                return Future.succeededFuture();
            }
            this.messageSent(message);
            if (this.sse.get()) {
                String event = "event: message\ndata: " + message.encode() + "\n\n";
                return this.response.write(event);
            }
            if (!this.response.ended()) {
                this.response.putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
                return this.response.end(message.toBuffer());
            }
            LOG.debugf("HTTP response ended, try to use a subsidiary SSE channel instead");
            return ((StreamableHttpMcpConnection) this.connection()).send(message);
        }
    }

    /**
     * Outcome of scanning an incoming payload before it is processed.
     *
     * @param forceSseInit {@code true} if SSE has to be initiated on the response before the
     *        payload is handled
     * @param containsRequest {@code true} if the payload contains at least one JSON-RPC request;
     *        when {@code false} and nothing was sent, the response is completed with status 202
     */
    record ScanResult(boolean forceSseInit, boolean containsRequest) {
    }
}

