package uk.ac.warwick.util.ais.auth.token;

import org.junit.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import static org.junit.Assert.assertEquals;

/**
 * Unit tests for {@link TokenCache#getAccessToken}: first fetch, refresh of a soon-to-expire
 * token, reuse of a still-valid token, and single-fetch behaviour under concurrent callers.
 */
public class TokenCacheTest {

    /** With an empty cache, the token produced by the supplier is returned. */
    @Test
    public void getAccessToken_noCachedToken_returnsSuppliedToken() {
        TokenCache tokenCache = new TokenCache();
        AccessToken token = new AccessToken("token", "Bearer", 3600L);

        AccessToken result = tokenCache.getAccessToken("scope", () -> token);

        assertEquals(token, result);
    }

    /**
     * A cached token close to expiry is replaced by a freshly supplied one.
     * NOTE(review): assumes TokenCache treats a token with 60s of remaining lifetime as due for
     * refresh — confirm against TokenCache's expiry threshold.
     */
    @Test
    public void getAccessToken_expiringToken_fetchesNewToken() {
        TokenCache tokenCache = new TokenCache();
        AccessToken expiringToken = new AccessToken("expiringToken", "Bearer", 60L);
        AccessToken newToken = new AccessToken("newToken", "Bearer", 3600L);

        tokenCache.getAccessToken("scope", () -> expiringToken); // set a token that will expire soon
        AccessToken result = tokenCache.getAccessToken("scope", () -> newToken);

        assertEquals(newToken, result);
    }

    /** A cached token with plenty of lifetime left is returned instead of calling the supplier. */
    @Test
    public void getAccessToken_nonExpiringToken_returnsCachedToken() {
        TokenCache tokenCache = new TokenCache();
        AccessToken cachedToken = new AccessToken("cachedToken", "Bearer", 3600L);
        AccessToken newToken = new AccessToken("token", "Bearer", 3600L);

        tokenCache.getAccessToken("scope", () -> cachedToken); // set a token that is still valid for a long time
        AccessToken result = tokenCache.getAccessToken("scope", () -> newToken);

        assertEquals(cachedToken, result);
    }

    /**
     * Several threads requesting the same scope at the same moment must trigger only one fetch,
     * and every caller must receive that fetched token.
     */
    @Test
    public void getAccessToken_whenConcurrentRequest_fetchOnlyOnce() throws Exception {
        final int threadCount = 3;
        TokenCache tokenCache = new TokenCache();
        AccessToken newToken = new AccessToken("token", "Bearer", 3600L);
        AtomicInteger count = new AtomicInteger();
        Supplier<AccessToken> tokenSupplier = () -> {
            count.getAndIncrement();
            // Keep the "fetch" in flight briefly: without proper locking in TokenCache the other
            // callers would see an empty cache during this window and fetch too.
            try {
                Thread.sleep(50L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            return newToken;
        };

        // Start gate: every worker blocks here so all calls hit the cache together, rather than
        // whenever each thread happens to be scheduled (which could serialize them).
        CountDownLatch startGate = new CountDownLatch(1);
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        try {
            List<Future<AccessToken>> results = new ArrayList<>();
            for (int i = 0; i < threadCount; i++) {
                results.add(executor.submit(() -> {
                    startGate.await();
                    return tokenCache.getAccessToken("scope", tokenSupplier);
                }));
            }
            startGate.countDown();

            // Future.get rethrows worker failures (wrapped in ExecutionException), so an error
            // inside the cache fails the test instead of vanishing on a background thread.
            for (Future<AccessToken> result : results) {
                assertEquals(newToken, result.get(5, TimeUnit.SECONDS));
            }
        } finally {
            executor.shutdownNow();
        }

        assertEquals(1, count.get());
    }
}
