package uk.ac.warwick.util.password;

import org.apache.http.HttpStatus;
import org.apache.http.util.EntityUtils;
import org.springframework.stereotype.Service;
import uk.ac.warwick.util.cache.*;
import uk.ac.warwick.util.collections.Pair;
import uk.ac.warwick.util.httpclient.httpclient4.HttpMethodExecutor;
import uk.ac.warwick.util.httpclient.httpclient4.SimpleHttpMethodExecutor;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.stream.Collectors;

/**
 * Looks up how many times a password hash appears in the 'haveibeenpwned' range-search API.
 *
 * <p>Only the first {@value #PASSWORD_PREFIX_LENGTH} characters of the hash are sent to the API.
 * The API response (lines of {@code SUFFIX:COUNT}) is cached per prefix. The remaining suffix is
 * then matched locally against the cached list.
 */
@Service
public class PwnedPasswordServiceImpl
        extends SingularCacheEntryFactory<String, ArrayList<Pair<String, Integer>>>
        implements PwnedPasswordService, CacheEntryFactory<String, ArrayList<Pair<String, Integer>>> {

    /** Number of leading hash characters sent to the range API; also the cache key length. */
    private static final int PASSWORD_PREFIX_LENGTH = 5;

    /** Upper bound on the number of distinct prefixes held in the cache. */
    private static final int MAXIMUM_CACHE_SIZE = 10000;

    private final String apiBaseUrl;
    private final Cache<String, ArrayList<Pair<String, Integer>>> cache;

    /**
     * Creates the service and its prefix cache.
     *
     * @param apiBaseUrl    base URL of the range-search endpoint; the hash prefix is appended as
     *                      {@code apiBaseUrl + "/" + prefix}
     * @param cacheStrategy name of a {@code Caches.CacheStrategy} constant, resolved via
     *                      {@code valueOf} (so an unknown name fails at construction time)
     */
    public PwnedPasswordServiceImpl(final String apiBaseUrl, final String cacheStrategy) {
        this.apiBaseUrl = apiBaseUrl;
        this.cache = Caches.builder("PwnedPasswords", this, Caches.CacheStrategy.valueOf(cacheStrategy))
                .maximumSize(MAXIMUM_CACHE_SIZE)
                .expireAfterWrite(Duration.ofDays(1))
                .build();
    }

    /**
     * Returns how many times the given password hash has been seen, according to the API.
     *
     * @param passwordHash hex-encoded password hash; must be longer than
     *                     {@value #PASSWORD_PREFIX_LENGTH} characters so a non-empty suffix remains
     * @return the count reported for this hash, or 0 if its suffix is not listed under its prefix
     * @throws IllegalArgumentException if {@code passwordHash} is null or too short
     * @throws IOException              if the prefix lookup (HTTP call or parsing) failed
     */
    @Override
    public int numMatches(String passwordHash) throws IOException, IllegalArgumentException {
        // A suffix of at least one character must remain after the prefix to have anything to match.
        if (passwordHash == null || passwordHash.length() <= PASSWORD_PREFIX_LENGTH) {
            throw new IllegalArgumentException(
                    "password hash must have length > " + PASSWORD_PREFIX_LENGTH);
        }
        final String prefix = passwordHash.substring(0, PASSWORD_PREFIX_LENGTH);
        final String suffix = passwordHash.substring(PASSWORD_PREFIX_LENGTH);

        final ArrayList<Pair<String, Integer>> searchResult;
        try {
            searchResult = this.cache.get(prefix);
        } catch (CacheEntryUpdateException e) {
            throw new IOException(e);
        }
        return countHashMatches(searchResult, suffix);
    }

    /**
     * Finds the count associated with {@code passwordSuffix} in a prefix's result list.
     *
     * <p>Hex digests are case-insensitive, so the comparison ignores case. Without this, a
     * lowercase caller hash could never match the API's suffixes if they differ in case.
     *
     * @return the matching count, or 0 if no entry matches
     */
    private int countHashMatches(final ArrayList<Pair<String, Integer>> matchedSuffixes, final String passwordSuffix) {
        for (Pair<String, Integer> suffixAndCount : matchedSuffixes) {
            if (passwordSuffix.equalsIgnoreCase(suffixAndCount.getLeft())) {
                return suffixAndCount.getRight();
            }
        }
        return 0;
    }

    /**
     * Queries the range API for one hash prefix and parses its {@code SUFFIX:COUNT} lines.
     *
     * <p>Lines that do not split into exactly two {@code ':'}-separated parts are skipped. They
     * are not kept as {@code null} entries, which previously caused an NPE during matching.
     *
     * @throws IOException on a non-200 response or an I/O failure reading the body
     */
    private ArrayList<Pair<String, Integer>> rangeSearch(final String passwordHashPrefix) throws IOException {
        final String url = apiBaseUrl + "/" + passwordHashPrefix;
        return new ArrayList<>(new SimpleHttpMethodExecutor(HttpMethodExecutor.Method.get, url)
                .execute(response -> {
                    final int status = response.getStatusLine().getStatusCode();
                    if (status != HttpStatus.SC_OK) {
                        throw new IOException("Response from 'haveibeenpwned' api had status: " + status + ", reason: " + response.getStatusLine().getReasonPhrase());
                    }
                    final String body = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
                    try (BufferedReader reader = new BufferedReader(new StringReader(body))) {
                        return reader.lines()
                                .map(line -> line.split(":"))
                                .filter(parts -> parts.length == 2)
                                // NumberFormatException (an IllegalArgumentException) is wrapped by create()
                                .map(parts -> Pair.of(parts[0].trim(), Integer.parseInt(parts[1].trim())))
                                .collect(Collectors.toList());
                    }
                }).getRight());
    }

    /**
     * Cache loader: fetches the suffix list for a hash prefix.
     *
     * @param key the hash prefix
     * @throws CacheEntryUpdateException wrapping any I/O or parse failure
     */
    @Override
    public ArrayList<Pair<String, Integer>> create(String key) throws CacheEntryUpdateException {
        try {
            return rangeSearch(key);
        } catch (IOException | IllegalArgumentException e) {
            throw new CacheEntryUpdateException(e);
        }
    }

    /** Every successfully fetched result is cached, including empty lists. */
    @Override
    public boolean shouldBeCached(ArrayList<Pair<String, Integer>> val) {
        return true;
    }
}
