package uk.ac.warwick.util.ais.auth.token;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion;
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.BasicHttpEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.util.EntityUtils;
import org.junit.Test;
import uk.ac.warwick.util.ais.auth.credentials.OAuth2ClientCredentials;
import uk.ac.warwick.util.ais.auth.exception.TokenFetchException;
import uk.ac.warwick.util.ais.auth.model.OAuth2TokenFetchParameters;
import uk.ac.warwick.util.ais.core.httpclient.HttpRequestExecutor;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.SocketTimeoutException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.Assert.*;

/**
 * Unit tests for {@link DefaultOAuth2TokenFetcher}.
 *
 * <p>Each test replaces the HTTP layer with a stub {@link HttpRequestExecutor}
 * lambda that returns a pre-built {@link CompletableFuture}, so no network
 * access takes place.
 */
public class DefaultOAuth2TokenFetcherTest {

    /** Client credentials shared by every test; the first argument is used as the token endpoint URI. */
    private static final OAuth2ClientCredentials CREDENTIALS =
            new OAuth2ClientCredentials("authority", "clientId", "clientSecret");

    /** Scope value requested in every test. */
    private static final String SCOPE = "scope";

    /** JSON body returned by the stubbed successful token response. */
    private static final String TOKEN_JSON =
            "{\"access_token\":\"token\",\"token_type\":\"Bearer\",\"expires_in\":3600}";

    private final ObjectMapper objectMapper = new ObjectMapper();

    /**
     * A 200 response carrying a JSON token must be mapped onto an {@link AccessToken},
     * and the outgoing request must be a form-style POST to the authority containing
     * the client id, client secret, scope and the client-credentials grant type.
     *
     * @throws Exception if the captured request entity cannot be read; propagated so
     *                   that JUnit reports the real cause instead of a generic failure
     */
    @Test
    public void apply_fetchTokenSuccess() throws Exception {
        // Captures the request handed to the executor so it can be inspected afterwards.
        AtomicReference<HttpUriRequest> capturedRequest = new AtomicReference<>();
        HttpRequestExecutor<HttpUriRequest, HttpResponse> requestExecutor = req -> {
            capturedRequest.set(req);
            return CompletableFuture.completedFuture(createHttpResponse());
        };

        AccessToken accessToken = newFetcher(requestExecutor).apply(newParameters());

        assertNotNull(accessToken);
        assertEquals("token", accessToken.getTokenValue());
        assertEquals("Bearer", accessToken.getTokenType());
        assertEquals(3600L, (long) accessToken.getExpiresIn());

        // Check the request
        HttpUriRequest request = capturedRequest.get();
        assertNotNull("request executor was never invoked", request);
        assertEquals("POST", request.getMethod());
        assertEquals("authority", request.getURI().toString());
        assertTrue(request instanceof HttpEntityEnclosingRequestBase);

        HttpEntityEnclosingRequestBase entityRequest = (HttpEntityEnclosingRequestBase) request;
        String entity = EntityUtils.toString(entityRequest.getEntity());
        assertTrue(entity.contains("client_id=clientId"));
        assertTrue(entity.contains("client_secret=clientSecret"));
        assertTrue(entity.contains("scope=scope"));
        assertTrue(entity.contains("grant_type=client_credentials"));
    }

    /**
     * A 400 response without an entity must surface as a {@link TokenFetchException}.
     */
    @Test(expected = TokenFetchException.class)
    public void apply_fetchTokenFailed_throwTokenFetchException() {
        HttpResponse errorResponse = new BasicHttpResponse(
                new ProtocolVersion("HTTP", 1, 1),
                400,
                "Bad Request"
        );
        errorResponse.setEntity(null);
        HttpRequestExecutor<HttpUriRequest, HttpResponse> requestExecutor =
                req -> CompletableFuture.completedFuture(errorResponse);

        newFetcher(requestExecutor).apply(newParameters());
    }

    /**
     * A future completed exceptionally with a {@link SocketTimeoutException} must
     * surface as a {@link TokenFetchException}.
     */
    @Test(expected = TokenFetchException.class)
    public void apply_requestTimeout_throwTokenFetchException() {
        HttpRequestExecutor<HttpUriRequest, HttpResponse> requestExecutor = req -> {
            CompletableFuture<HttpResponse> future = new CompletableFuture<>();
            future.completeExceptionally(new SocketTimeoutException("Request timeout"));
            return future;
        };

        newFetcher(requestExecutor).apply(newParameters());
    }

    /** Builds the fetcher under test on top of the given stub executor. */
    private DefaultOAuth2TokenFetcher newFetcher(HttpRequestExecutor<HttpUriRequest, HttpResponse> requestExecutor) {
        return new DefaultOAuth2TokenFetcher(objectMapper, requestExecutor::execute);
    }

    /** Builds the fetch parameters shared by all tests. */
    private OAuth2TokenFetchParameters newParameters() {
        return new OAuth2TokenFetchParameters(CREDENTIALS, SCOPE);
    }

    /**
     * Creates an HTTP 200 response whose entity is {@link #TOKEN_JSON} encoded as UTF-8.
     *
     * @return a fresh response; the entity stream is single-use, so a new response
     *         is built on every call
     */
    private HttpResponse createHttpResponse() {
        // Encode explicitly: the Content-Type below declares UTF-8, so the bytes must not
        // depend on the platform default charset.
        byte[] payload = TOKEN_JSON.getBytes(StandardCharsets.UTF_8);

        InputStream inputStream = new ByteArrayInputStream(payload);
        BasicHttpEntity httpEntity = new BasicHttpEntity();
        httpEntity.setContent(inputStream);
        // Content length is a byte count, not a character count.
        httpEntity.setContentLength(payload.length);
        httpEntity.setContentType(ContentType.APPLICATION_JSON.toString());
        // NOTE(review): Content-Encoding normally names a content coding such as gzip, not a
        // charset. Kept unchanged so the stubbed response stays identical for the fetcher.
        httpEntity.setContentEncoding(new BasicHeader("Content-Encoding", "UTF-8"));

        HttpResponse successResponse = new BasicHttpResponse(
                new ProtocolVersion("HTTP", 1, 1),
                200,
                "OK"
        );
        successResponse.setEntity(httpEntity);

        return successResponse;
    }
}
