package uk.ac.warwick.sso.client;

import net.logstash.logback.argument.StructuredArguments;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.warwick.sso.client.util.cookies.ServerCookieEncoder;
import uk.ac.warwick.userlookup.User;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.*;

import static org.springframework.web.servlet.HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE;

/**
 * CSRF protection for requests made by logged-in users.
 *
 * <p>A random token is kept in the {@value #CSRF_COOKIE_NAME} cookie and published for the rest of
 * the request as the {@value #CSRF_TOKEN_PROPERTY_NAME} request attribute. POST requests must echo
 * that token back, either as the {@value #CSRF_TOKEN_PROPERTY_NAME} request parameter or in the
 * {@value #CSRF_HTTP_HEADER} header. Otherwise the request is rejected with 400 Bad Request, or, in
 * report-only mode, the failure is only logged.
 */
class CSRFImpl {
    /** HTTP header that may carry the CSRF token when it is not sent as a request parameter. */
    static final String CSRF_HTTP_HEADER = "X-CSRF-Token";

    private static final Logger LOGGER = LoggerFactory.getLogger("uk.ac.warwick.SECURITY_REPORTS");

    // DO NOT remove this cookie prefix: browsers only accept "__Host-" cookies that are Secure,
    // have Path=/ and carry no Domain attribute, so the cookie cannot be planted from another host.
    static final String CSRF_COOKIE_NAME = "__Host-SSO-CSRF";

    /** Request parameter name (and request attribute name) under which the CSRF token is exchanged. */
    static final String CSRF_TOKEN_PROPERTY_NAME = "urn:websignon:csrf";

    /** When this request attribute is non-null, a fresh token is issued even if the cookie exists. */
    static final String CSRF_FORCE_INVALIDATE = "urn:websignon:csrf:invalidate";

    /** Request attribute set to one of the {@code CSRF_ERROR_TOKEN_*} values when a request is rejected. */
    static final String CSRF_ERROR = "urn:websignon:csrf:error";

    /** {@link #CSRF_ERROR} value: the POST carried no token. */
    static final String CSRF_ERROR_TOKEN_ABSENT = "urn:websignon:csrf:error:absent";

    /** {@link #CSRF_ERROR} value: the POST carried a token that differs from the cookie's token. */
    static final String CSRF_ERROR_TOKEN_MISMATCH = "urn:websignon:csrf:error:mismatch";

    private final ServerCookieEncoder encoder = new ServerCookieEncoder(false);

    /** When true, CSRF failures are logged (and flagged via X-Error) but the request is allowed through. */
    private boolean reportOnlyMode;

    /**
     * Enables or disables report-only mode.
     *
     * @param reportOnlyMode {@code true} to log CSRF failures without rejecting the request
     */
    void setReportOnlyMode(boolean reportOnlyMode) {
        this.reportOnlyMode = reportOnlyMode;
    }

    /**
     * Applies CSRF protection to the current request.
     *
     * <p>Only requests from a found, logged-in user are checked. For those, a new token cookie is issued
     * when none is present (or when {@link #CSRF_FORCE_INVALIDATE} is set). The token is then stored in
     * the {@link #CSRF_TOKEN_PROPERTY_NAME} request attribute, and POST requests must supply a matching
     * token.
     *
     * @param request  the incoming request
     * @param response the response; may receive a Set-Cookie header, diagnostic headers or an error
     * @return {@code true} if processing should continue; {@code false} if the request was rejected and
     *         a 400 error has already been sent
     * @throws IOException if sending the error response fails
     */
    boolean handle(HttpServletRequest request, HttpServletResponse response) throws IOException {
        User user = SSOClientFilter.getUserFromRequest(request);
        if (!user.isFoundUser() || !user.isLoggedIn()) {
            return true;
        }

        Optional<String> existingToken = this.findCsrfToken(request);
        boolean tokenIssued = false;
        String csrfToken;
        if (!existingToken.isPresent() || request.getAttribute(CSRF_FORCE_INVALIDATE) != null) {
            if (existingToken.isPresent()) {
                LOGGER.debug("Forcing invalidation of token due to CSRF_FORCE_INVALIDATE");
            } else {
                LOGGER.debug("Couldn't find cookie with name {}", CSRF_COOKIE_NAME);
            }
            csrfToken = UUID.randomUUID().toString();
            this.addCsrfCookie(request, response, csrfToken);
            tokenIssued = true;
        } else {
            csrfToken = existingToken.get();
        }

        request.setAttribute(CSRF_TOKEN_PROPERTY_NAME, csrfToken);

        // Only POST requests are validated; other methods just receive/keep the token.
        if (!"POST".equalsIgnoreCase(request.getMethod())) {
            return true;
        }

        if (tokenIssued) {
            // The token was generated during this request, so the client cannot have submitted it.
            LOGGER.warn("User didn't have a CSRF token known to the system, and they immediately POST'd.");
        }

        String term = this.reportOnlyMode ? "reporting" : "rejecting";
        String providedToken = this.getProvidedToken(request);

        if (providedToken == null || providedToken.isEmpty()) {
            this.logWithRequest(request, response, String.format("No CSRF token was provided in the POST; %s POST request to %s", term, request.getRequestURL()), csrfToken, providedToken, user);

            response.setHeader("X-Error", "No CSRF token");
            if (!this.reportOnlyMode) {
                request.setAttribute(CSRF_ERROR, CSRF_ERROR_TOKEN_ABSENT);
                response.sendError(HttpServletResponse.SC_BAD_REQUEST);
                return false;
            }
        } else if (!MessageDigest.isEqual(providedToken.getBytes(StandardCharsets.UTF_8), csrfToken.getBytes(StandardCharsets.UTF_8))) {
            // MessageDigest.isEqual compares in constant time, avoiding a timing side channel.
            this.logWithRequest(request, response, String.format("Provided CSRF token does not match stored CSRF token; %s POST request to %s", term, request.getRequestURL()), csrfToken, providedToken, user);

            response.setHeader("X-Error", "Wrong CSRF token");
            if (!this.reportOnlyMode) {
                request.setAttribute(CSRF_ERROR, CSRF_ERROR_TOKEN_MISMATCH);
                response.sendError(HttpServletResponse.SC_BAD_REQUEST);
                return false;
            }
        } else {
            LOGGER.debug("Allowing CSRF request through as token matches");
        }

        return true;
    }

    /**
     * Returns the token held in the CSRF cookie, if the request carries one with a non-empty value.
     *
     * <p>An empty or null value is treated as absent so that a fresh token gets issued. Otherwise an
     * empty token would be published to the request and every POST would be rejected as "no token",
     * with no way to recover while the cookie remained.
     */
    private Optional<String> findCsrfToken(HttpServletRequest request) {
        return Arrays.stream(this.getRequestCookiesSafe(request))
                .filter(c -> CSRF_COOKIE_NAME.equals(c.getName()))
                .map(Cookie::getValue)
                .filter(value -> value != null && !value.isEmpty())
                .findFirst();
    }

    /**
     * Returns the token submitted with the request. The request parameter takes precedence whenever it
     * is present (even if empty); otherwise the {@link #CSRF_HTTP_HEADER} header is used. May be null.
     */
    private String getProvidedToken(HttpServletRequest request) {
        if (request.getParameterMap().containsKey(CSRF_TOKEN_PROPERTY_NAME)) {
            return request.getParameter(CSRF_TOKEN_PROPERTY_NAME);
        }
        return request.getHeader(CSRF_HTTP_HEADER);
    }

    /** Returns the request's cookies, or an empty array when the container reports none (null). */
    private Cookie[] getRequestCookiesSafe(HttpServletRequest request) {
        Cookie[] cookies = request.getCookies();
        return cookies == null ? new Cookie[0] : cookies;
    }

    /** Adds a Set-Cookie header for a session-lifetime, Secure, HttpOnly CSRF cookie scoped to "/". */
    private void addCsrfCookie(HttpServletRequest request, HttpServletResponse response, String csrfToken) {
        uk.ac.warwick.sso.client.core.Cookie cookie = new uk.ac.warwick.sso.client.core.Cookie(CSRF_COOKIE_NAME, csrfToken);
        cookie.setHttpOnly(true); // doesn't really matter, it's in the DOM
        // SameSite=Lax is intentionally not set - see SSO-2456.
        cookie.setMaxAge(-1);
        cookie.setPath("/");
        cookie.setSecure(true);
        response.addHeader("Set-Cookie", this.encoder.encode(cookie));
    }

    /**
     * Logs a structured CSRF failure report at WARN level on the security-reports logger.
     * The report's random correlation id is also sent back as the X-Correlation-ID response header.
     */
    private void logWithRequest(HttpServletRequest request, HttpServletResponse response, String message, String expected, String actual, User user) {
        String correlationId = UUID.randomUUID().toString();

        Map<String, Object> requestHeaders = new HashMap<>();
        requestHeaders.put("user-agent", headerOrDash(request, "User-Agent"));
        requestHeaders.put("x-requested-with", headerOrDash(request, "X-Requested-With"));
        requestHeaders.put("referer", headerOrDash(request, "Referer"));

        Map<String, Object> report = new HashMap<>();
        report.put("document-uri", request.getRequestURL().toString());
        Object bestMatchingPattern = request.getAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE);
        if (bestMatchingPattern != null) {
            report.put("best-matching-pattern", bestMatchingPattern);
        }
        report.put("method", request.getMethod());
        // NOTE(review): this logs the user's currently valid CSRF token; confirm that exposing it to
        // log readers is acceptable.
        report.put("expected-token", expected);
        report.put("actual-token", actual == null ? "-" : actual);
        report.put("error", message);
        report.put("correlation-id", correlationId);

        Map<String, Object> json = new HashMap<>();
        json.put("request_headers", requestHeaders);
        json.put("csrf-report", report);
        json.put("username", user.getUserId());

        LOGGER.warn("{}", StructuredArguments.entries(json));
        response.setHeader("X-Correlation-ID", correlationId);
    }

    /** Returns the named request header, or "-" when it is absent. */
    private static String headerOrDash(HttpServletRequest request, String name) {
        String value = request.getHeader(name);
        return value == null ? "-" : value;
    }
}
