/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.mcp.server.websocket.runtime;

import io.quarkiverse.mcp.server.InitialCheck;
import io.quarkiverse.mcp.server.InitialRequest;
import io.quarkiverse.mcp.server.runtime.ConnectionManager;
import io.quarkiverse.mcp.server.runtime.ContextSupport;
import io.quarkiverse.mcp.server.runtime.McpConnectionBase;
import io.quarkiverse.mcp.server.runtime.McpMessageHandler;
import io.quarkiverse.mcp.server.runtime.McpMetadata;
import io.quarkiverse.mcp.server.runtime.McpRequest;
import io.quarkiverse.mcp.server.runtime.McpRequestImpl;
import io.quarkiverse.mcp.server.runtime.NotificationManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResponseHandlers;
import io.quarkiverse.mcp.server.runtime.SecuritySupport;
import io.quarkiverse.mcp.server.runtime.Sender;
import io.quarkiverse.mcp.server.runtime.ToolManagerImpl;
import io.quarkiverse.mcp.server.runtime.config.McpServerRuntimeConfig;
import io.quarkiverse.mcp.server.runtime.config.McpServersRuntimeConfig;
import io.quarkiverse.mcp.server.websocket.runtime.WebSocketMcpConnection;
import io.quarkus.arc.All;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.UserData;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.vertx.UniHelper;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.json.Json;
import jakarta.enterprise.inject.Instance;
import java.util.List;
import org.jboss.logging.Logger;

public abstract class WebSocketMcpMessageHandler
extends McpMessageHandler<WebSocketMcpRequest> {
    private static final Logger LOG = Logger.getLogger(WebSocketMcpMessageHandler.class);
    private static final String MCP_CONNECTION_ID = "mcpConnectionId";
    CurrentIdentityAssociation currentIdentityAssociation;

    protected WebSocketMcpMessageHandler(McpServersRuntimeConfig config, ConnectionManager connectionManager, PromptManagerImpl promptManager, ToolManagerImpl toolManager, ResourceManagerImpl resourceManager, PromptCompletionManagerImpl promptCompleteManager, ResourceTemplateManagerImpl resourceTemplateManager, ResourceTemplateCompletionManagerImpl resourceTemplateCompleteManager, NotificationManagerImpl initManager, ResponseHandlers responseHandlers, McpMetadata metadata, Vertx vertx, @All List<InitialCheck> initialChecks, Instance<CurrentIdentityAssociation> currentIdentityAssociation) {
        super(config, connectionManager, promptManager, toolManager, resourceManager, promptCompleteManager, resourceTemplateManager, resourceTemplateCompleteManager, initManager, responseHandlers, metadata, vertx, initialChecks);
        this.currentIdentityAssociation = currentIdentityAssociation.isResolvable() ? (CurrentIdentityAssociation)currentIdentityAssociation.get() : null;
    }

    protected abstract String serverName();

    @OnOpen
    void openConnection(WebSocketConnection connection) {
        String id = ConnectionManager.connectionId();
        WebSocketMcpConnection mcpConnection = new WebSocketMcpConnection(id, (McpServerRuntimeConfig)this.config.servers().get(this.serverName()), connection);
        this.connectionManager.add((McpConnectionBase)mcpConnection);
        LOG.debugf("WebSocket connection initialized [%s]", (Object)id);
        connection.userData().put(UserData.TypedKey.forString((String)MCP_CONNECTION_ID), (Object)id);
    }

    @OnTextMessage
    Uni<Void> consumeMessage(WebSocketConnection connection, String message) {
        SecuritySupport securitySupport;
        String connectionId = (String)connection.userData().get(UserData.TypedKey.forString((String)MCP_CONNECTION_ID));
        WebSocketMcpConnection mcpConnection = (WebSocketMcpConnection)this.connectionManager.get(connectionId);
        Object json = Json.decodeValue((String)message);
        if (this.currentIdentityAssociation != null) {
            final SecurityIdentity securityIdentity = this.currentIdentityAssociation.getIdentity();
            securitySupport = new SecuritySupport(){

                public void setCurrentIdentity(CurrentIdentityAssociation currentIdentityAssociation) {
                    currentIdentityAssociation.setIdentity(securityIdentity);
                }
            };
        } else {
            securitySupport = null;
        }
        WebSocketMcpRequest mcpRequest = new WebSocketMcpRequest(this.serverName(), json, mcpConnection, securitySupport, null, this.currentIdentityAssociation);
        return UniHelper.toUni((Future)this.handle((McpRequest)mcpRequest));
    }

    @OnClose
    void closeConnection(WebSocketConnection connection) {
        String id = (String)connection.userData().get(UserData.TypedKey.forString((String)MCP_CONNECTION_ID));
        this.connectionManager.remove(id);
        LOG.debugf("WebSocket connection closed [%s]", (Object)id);
    }

    protected InitialRequest.Transport transport() {
        return InitialRequest.Transport.WEBSOCKET;
    }

    static class WebSocketMcpRequest
    extends McpRequestImpl<WebSocketMcpConnection> {
        WebSocketMcpRequest(String serverName, Object json, WebSocketMcpConnection connection, SecuritySupport securitySupport, ContextSupport requestContextSupport, CurrentIdentityAssociation currentIdentityAssociation) {
            super(serverName, json, (McpConnectionBase)connection, (Sender)connection, securitySupport, requestContextSupport, currentIdentityAssociation);
        }
    }
}

