package uk.ac.warwick.util.ais.auth.token;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.http.HttpResponse;
import org.apache.http.client.entity.EntityBuilder;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import uk.ac.warwick.util.ais.auth.exception.TokenFetchException;
import uk.ac.warwick.util.ais.auth.model.OAuth2TokenFetchParameters;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

/**
 * Fetches an OAuth 2.0 access token from an Authorization Server.
 * <p>
 * Each call to {@link #apply(OAuth2TokenFetchParameters)} POSTs the grant type, client id, client secret and scope
 * as an {@code application/x-www-form-urlencoded} body to {@link OAuth2TokenFetchParameters#getAuthority()}, waits
 * (30 seconds by default) for the response, and deserialises a {@code 200 OK} body into an {@link AccessToken}.
 * Every failure is reported as a {@link TokenFetchException}.
 * <p>
 * The class holds no mutable state of its own; whether an instance can be shared between threads depends on the
 * supplied {@link ObjectMapper} and request executor.
 */
public final class DefaultOAuth2TokenFetcher implements Function<OAuth2TokenFetchParameters, AccessToken> {

    /** Status code reported when no HTTP status is available (transport error, timeout, unreadable body, ...). */
    private static final int NON_STANDARD_HTTP_STATUS_CODE = -1;

    /** Default maximum time, in seconds, to wait for the Authorization Server to respond. */
    private static final long DEFAULT_TIMEOUT_SECONDS = 30;

    private final ObjectMapper objectMapper;
    private final Function<HttpUriRequest, CompletableFuture<HttpResponse>> requestExecutor;
    private final long timeout;
    private final TimeUnit timeoutUnit;

    /**
     * Creates a fetcher that waits at most 30 seconds for each token response.
     *
     * @param objectMapper    deserialises the token response body into an {@link AccessToken}
     * @param requestExecutor sends an HTTP request and returns a future of its response
     * @throws NullPointerException if either argument is {@code null}
     */
    public DefaultOAuth2TokenFetcher(ObjectMapper objectMapper,
                                     Function<HttpUriRequest, CompletableFuture<HttpResponse>> requestExecutor) {
        this(objectMapper, requestExecutor, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
    }

    /**
     * Creates a fetcher with a custom response timeout.
     *
     * @param objectMapper    deserialises the token response body into an {@link AccessToken}
     * @param requestExecutor sends an HTTP request and returns a future of its response
     * @param timeout         maximum time to wait for a response; must be positive
     * @param timeoutUnit     unit of {@code timeout}
     * @throws NullPointerException     if any reference argument is {@code null}
     * @throws IllegalArgumentException if {@code timeout} is not positive
     */
    public DefaultOAuth2TokenFetcher(ObjectMapper objectMapper,
                                     Function<HttpUriRequest, CompletableFuture<HttpResponse>> requestExecutor,
                                     long timeout,
                                     TimeUnit timeoutUnit) {
        if (timeout <= 0) {
            throw new IllegalArgumentException("timeout must be positive: " + timeout);
        }
        this.objectMapper = Objects.requireNonNull(objectMapper, "objectMapper");
        this.requestExecutor = Objects.requireNonNull(requestExecutor, "requestExecutor");
        this.timeout = timeout;
        this.timeoutUnit = Objects.requireNonNull(timeoutUnit, "timeoutUnit");
    }

    /**
     * Requests a new access token from the Authorization Server described by {@code parameters}.
     *
     * @param parameters the token endpoint and client credentials to send
     * @return the access token parsed from a {@code 200 OK} response
     * @throws TokenFetchException if the server answers with a non-200 status (carrying that status and the
     *                             response body) or a 200 without a body (carrying status 200), or if the request
     *                             fails, times out, is interrupted, or its body cannot be read or parsed
     *                             (carrying status {@code -1})
     */
    @Override
    public AccessToken apply(OAuth2TokenFetchParameters parameters) {
        HttpUriRequest accessTokenRequest = createAccessTokenRequest(parameters);
        return sendRequest(accessTokenRequest);
    }

    /** Executes the request, waits for the response and converts every failure into a TokenFetchException. */
    private AccessToken sendRequest(HttpUriRequest request) {
        CompletableFuture<HttpResponse> pending = null;
        try {
            // Make HTTP request to authority to acquire the new access token
            pending = requestExecutor.apply(request);
            HttpResponse response = pending.get(timeout, timeoutUnit);

            return handleResponse(response);
        } catch (TokenFetchException ex) {
            // Already describes the failure, including the real HTTP status and body; re-wrapping it
            // would replace the status with NON_STANDARD_HTTP_STATUS_CODE and drop the body.
            throw ex;
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt(); // Restore the interrupted status
            cancelIfPresent(pending);
            throw fetchFailure(ex);
        } catch (TimeoutException ex) {
            cancelIfPresent(pending);
            // TimeoutException from Future.get may carry no message, so describe the timeout explicitly
            throw fetchFailure(String.format("no response within %d %s", timeout, timeoutUnit), ex);
        } catch (ExecutionException ex) {
            // Unwrap so the caller sees the executor's real failure (fall back to the wrapper if it has none)
            Throwable cause = ex.getCause() != null ? ex.getCause() : ex;
            throw fetchFailure(cause);
        } catch (Exception ex) {
            throw fetchFailure(ex);
        }
    }

    /**
     * Best-effort cancellation of an abandoned request future. Whether this also aborts the underlying
     * HTTP exchange depends on how the request executor builds its future.
     */
    private static void cancelIfPresent(CompletableFuture<?> future) {
        if (future != null) {
            future.cancel(true);
        }
    }

    /** Wraps a transport-level failure, using the cause's own message as the detail. */
    private static TokenFetchException fetchFailure(Throwable cause) {
        return fetchFailure(cause.getMessage(), cause);
    }

    /** Wraps a transport-level failure for which no HTTP status or response body is available. */
    private static TokenFetchException fetchFailure(String detail, Throwable cause) {
        return new TokenFetchException(
                String.format("An error occurred while fetching access token from the Authorization Server: %s.", detail),
                NON_STANDARD_HTTP_STATUS_CODE,
                null,
                cause);
    }

    /**
     * Reads the response body and either parses it as an {@link AccessToken} (status 200) or reports the
     * error status together with the body.
     */
    private AccessToken handleResponse(HttpResponse httpResponse) throws IOException {

        // A successful token request returns 200 OK with the token in the body (RFC 6749 section 5.1).
        // Errors are typically 4xx (400 Bad Request, 401 Unauthorized, etc.) for invalid credentials or
        // missing/invalid parameters, or 5xx if the server itself fails.
        int statusCode = httpResponse.getStatusLine().getStatusCode();
        String reasonPhrase = httpResponse.getStatusLine().getReasonPhrase();
        // EntityUtils.toString rejects a null entity, which would hide the real status, so read only when present.
        // JSON is UTF-8 by default; use that when the server omits a charset instead of HttpCore's ISO-8859-1.
        String responseBody = httpResponse.getEntity() == null
                ? null
                : EntityUtils.toString(httpResponse.getEntity(), StandardCharsets.UTF_8);

        // Cases where the response is not 200 OK, i.e. 3xx, 4xx or 5xx
        if (statusCode != 200) {
            throw new TokenFetchException(
                    String.format("Failed to acquire access token from the Authorization Server: %d - %s.", statusCode, reasonPhrase),
                    statusCode,
                    responseBody);
        }

        if (responseBody == null) {
            throw new TokenFetchException(
                    "Failed to acquire access token from the Authorization Server: 200 OK response had no body.",
                    statusCode,
                    responseBody);
        }

        return objectMapper.readValue(responseBody, AccessToken.class);
    }

    /** Builds the form-encoded POST carrying the grant type, client credentials and scope. */
    private HttpUriRequest createAccessTokenRequest(OAuth2TokenFetchParameters parameters) {
        return RequestBuilder.post(parameters.getAuthority())
                .setEntity(EntityBuilder.create()
                        .setContentType(ContentType.APPLICATION_FORM_URLENCODED)
                        .setParameters(
                                new BasicNameValuePair("grant_type", parameters.getGrantType()),
                                new BasicNameValuePair("client_id", parameters.getClientId()),
                                new BasicNameValuePair("client_secret", parameters.getClientSecret()),
                                new BasicNameValuePair("scope", parameters.getScope())
                        ).build())
                .build();
    }
}
