package net.gorillagroove.track

import net.gorillagroove.util.GGLog.logCrit
import kotlin.math.abs
import kotlin.math.pow
import kotlin.random.Random

internal object TrackShuffler {
    /**
     * Produces a shuffled play order for [tracks].
     *
     * Tracks whose play count falls outside [minimumPlayCount]..[maximumPlayCount] (inclusive) are dropped, except
     * for [currentTrack], which is never removed. When [leanValue] is exactly 0 the order is uniformly random;
     * otherwise tracks are drawn by weight, favoring more-played tracks when [leanValue] is positive and less-played
     * tracks when it is negative, with a stronger preference as |leanValue| grows.
     *
     * If [currentTrack] is non-null it is moved to the front of the returned list.
     *
     * @param currentTrack the track currently playing, or null if nothing is playing.
     * @param tracks the candidate tracks to shuffle.
     * @param leanValue shuffle bias; must be within [-1.0, 1.0].
     * @param minimumPlayCount inclusive lower bound on play count; must not be negative.
     * @param maximumPlayCount inclusive upper bound on play count; must not be negative.
     * @return the shuffled tracks (empty if no track passes the play-count filter).
     * @throws IllegalArgumentException if [leanValue] or either play-count bound is out of range.
     */
    fun shuffleTracks(
        currentTrack: NowPlayingTrack?,
        tracks: List<NowPlayingTrack>,
        leanValue: Double = NowPlayingService.shuffleLean,
        minimumPlayCount: Int = NowPlayingService.shuffleMinimumPlayCount,
        maximumPlayCount: Int = NowPlayingService.shuffleMaximumPlayCount,
    ): List<NowPlayingTrack> {
        require(leanValue >= -1.0 && leanValue <= 1.0) {
            "leanValue must be between 1 and -1"
        }
        require(minimumPlayCount >= 0 && maximumPlayCount >= 0) {
            "minimum and maximum play count parameters must not be negative!!"
        }

        val validTracks = tracks.filter { track ->
            track.track.playCount in minimumPlayCount..maximumPlayCount ||
                    // Don't remove the currently playing track, ever
                    track.nowPlayingTrackId == currentTrack?.nowPlayingTrackId
        }

        val newList = when {
            validTracks.isEmpty() -> emptyList()
            leanValue == 0.0 -> noLeanShuffle(validTracks)
            else -> leanShuffle(validTracks, leanValue)
        }

        return moveCurrentTrackToFront(newList, currentTrack)
    }

    /**
     * Weighted shuffle: draws every track, without replacement, from a [TrackShuffleTree] whose weights favor
     * higher play counts (positive [leanValue]) or lower play counts (negative [leanValue]).
     */
    private fun leanShuffle(tracks: List<NowPlayingTrack>, leanValue: Double): List<NowPlayingTrack> {
        val lean = abs(leanValue)
        // (10 * lean)^(1 + lean) / (15 - 7 * lean): the per-position weight bump, which increases with |lean|.
        val weight = (lean * 10).pow((2.0 - (1 - lean))) / (15 - (lean * 7))
        val tree = TrackShuffleTree(
            tracks = tracks,
            weightBump = weight,
            reversed = leanValue < 0,
        )

        return List(tracks.size) { tree.retrieveTrack() }
    }

    /**
     * Returns [tracks] with the entry matching [currentTrack]'s `nowPlayingTrackId` moved to index 0.
     * Logs a critical error and returns [tracks] unchanged if no entry matches.
     */
    private fun moveCurrentTrackToFront(
        tracks: List<NowPlayingTrack>,
        currentTrack: NowPlayingTrack?,
    ): List<NowPlayingTrack> {
        if (currentTrack == null) return tracks

        val index = tracks.indexOfFirst { track -> track.nowPlayingTrackId == currentTrack.nowPlayingTrackId }
        if (index < 0) {
            logCrit("Failed to find current track when shuffling tracks!")
            return tracks
        }

        return tracks.toMutableList().apply {
            removeAt(index)
            add(0, currentTrack)
        }
    }

    // The weighted-tree algorithm would also work here, but when every track is equally likely a plain shuffle is
    // cheaper, so don't spend the extra CPU cycles.
    // `shuffled(random)` is an unbiased Fisher-Yates shuffle: each position only swaps with itself or a position that
    // has not been fixed yet. Drawing every swap index from the whole list instead makes some orders more likely.
    private fun noLeanShuffle(tracks: List<NowPlayingTrack>): List<NowPlayingTrack> =
        tracks.shuffled(Random.Default)

    /**
     * Builds a human-readable description of the given shuffle settings.
     *
     * @param lean shuffle bias; 0.0 means fully random, a negative sign favors "less" played music and a
     *   positive sign "more", and |lean| selects the "slightly"/"moderately"/"heavily" wording.
     * @param minimumPlayCount inclusive lower play-count bound; 0 is treated as "no lower bound".
     * @param maximumPlayCount inclusive upper play-count bound; values at or above
     *   [NowPlayingService.MAXIMUM_ALLOWED_SHUFFLE_PLAY_COUNT_LIMIT] are treated as "no upper bound".
     * @return the lean sentence followed by the play-count sentence.
     */
    fun generateShuffleDescription(lean: Double, minimumPlayCount: Int, maximumPlayCount: Int): String {
        val leanText = if (lean != 0.0) {
            val amplitude = abs(lean)
            val amplitudeText = when {
                (amplitude > 0.0) && (amplitude <= 0.33) -> "slightly"
                (amplitude < 0.667) -> "moderately"
                else -> "heavily"
            }
            val lessPlayedText = if (lean < 0) "less" else "more"
            "Shuffle will $amplitudeText favor playing $lessPlayedText played music. "
        } else {
            "Shuffle will be entirely random. "
        }

        val hasMinimum = minimumPlayCount > 0
        val hasMaximum = maximumPlayCount < NowPlayingService.MAXIMUM_ALLOWED_SHUFFLE_PLAY_COUNT_LIMIT
        val playCountText = when {
            hasMinimum && hasMaximum ->
                "Tracks must have between $minimumPlayCount and $maximumPlayCount plays (inclusive) to be chosen."
            hasMinimum -> {
                // "more than 1 play" reads correctly; every other count needs "plays".
                val pluralPlays = if (minimumPlayCount == 2) "play" else "plays"
                "Tracks must have more than ${minimumPlayCount - 1} $pluralPlays to be chosen."
            }
            hasMaximum -> {
                // "fewer than 1 play" reads correctly; every other count needs "plays".
                val pluralPlays = if (maximumPlayCount == 0) "play" else "plays"
                "Tracks must have fewer than ${maximumPlayCount + 1} $pluralPlays to be chosen."
            }
            else -> "All tracks will be eligible to be played."
        }

        return leanText + playCountText
    }
}

/**
 * Weighted "draw without replacement" structure backing the leaning shuffle in `TrackShuffler`.
 *
 * Tracks are sorted by play count (ascending, or descending when [reversed] is true). Each track gets the weight
 * `1 + |weightBump| * i`, where `i` is the sorted position of the first track sharing its play count. Tracks with
 * equal play counts therefore share a weight, and tracks later in the sort get larger weights.
 *
 * Nodes form an implicit binary tree: the children of index `i` sit at `2i + 1` and `2i + 2`. Every node caches the
 * total weight of its subtree, so a draw and its removal each walk a single root-to-node path.
 *
 * @throws IllegalArgumentException if [tracks] is empty.
 */
private class TrackShuffleTree(tracks: List<NowPlayingTrack>, weightBump: Double, reversed: Boolean) {
    private val nodes: List<Node>

    init {
        require(tracks.isNotEmpty()) { "Cannot build a TrackShuffleTree without any tracks" }

        // Later positions receive larger weights. When reversed, sort descending so the least-played tracks end up
        // last and are therefore the favored ones.
        val sortedTracks = tracks.sortedBy { track ->
            if (reversed) -track.track.playCount else track.track.playCount
        }

        var lastPlayCount = 0
        var lastWeight = 0.0
        nodes = sortedTracks.mapIndexed { index, track ->
            val playCount = track.track.playCount
            // Tracks sharing a play count share the weight of the first track in their run.
            if (index == 0 || playCount != lastPlayCount) {
                lastPlayCount = playCount
                lastWeight = 1 + (abs(weightBump) * index)
            }
            Node(track = track, weight = lastWeight)
        }

        nodes.forEachIndexed { index, node ->
            node.leftChild = nodes.getOrNull(2 * index + 1)?.also { child -> child.parent = node }
            node.rightChild = nodes.getOrNull(2 * index + 2)?.also { child -> child.parent = node }
        }

        // Children always sit at higher indices than their parent, so a reverse pass fills the sums bottom-up.
        nodes.asReversed().forEach { node -> node.recalculateSubWeight() }
    }

    /**
     * Randomly draws one track, with probability proportional to its weight, and removes it from future draws.
     *
     * @throws IllegalStateException if every track has already been drawn.
     */
    fun retrieveTrack(): NowPlayingTrack {
        val root = nodes.first()
        check(root.subWeight > 0.0) { "Every track has already been retrieved from this TrackShuffleTree" }

        val selected = findNode(root, Random.Default.nextDouble(root.subWeight))

        // Remove the track from future draws. The sums are recomputed from the children instead of decremented, so
        // floating-point error cannot accumulate: an exhausted subtree sums to exactly 0.0 and is never entered again.
        selected.weight = 0.0
        var node: Node? = selected
        while (node != null) {
            node.recalculateSubWeight()
            node = node.parent
        }

        return selected.track
    }

    /**
     * Walks down from [root] to the node whose weight interval, taken in in-order position, contains [target]
     * (expected to satisfy `0 <= target < root.subWeight`). Only nodes with a positive weight are returned.
     */
    private fun findNode(root: Node, target: Double): Node {
        var node = root
        var remaining = target
        while (true) {
            val left = node.leftChild
            // `remaining` is kept non-negative, so an exhausted (0.0) left subtree can never be entered.
            if (left != null && remaining < left.subWeight) {
                node = left
                continue
            }
            remaining = maxOf(0.0, remaining - (left?.subWeight ?: 0.0))

            if (remaining < node.weight) return node
            remaining = maxOf(0.0, remaining - node.weight)

            val right = node.rightChild
            if (right == null || right.subWeight <= 0.0) {
                // Rounding pushed the target just past the end of this subtree; pick its last available node instead.
                return node.lastAvailable()
            }
            node = right
        }
    }

    private class Node(
        val track: NowPlayingTrack,
        var weight: Double,
    ) {
        /** Sum of [weight] over this node and all of its descendants. */
        var subWeight: Double = 0.0
        var parent: Node? = null
        var leftChild: Node? = null
        var rightChild: Node? = null

        /** Recomputes [subWeight] from this node's own weight and its children's cached sums. */
        fun recalculateSubWeight() {
            subWeight = weight + (leftChild?.subWeight ?: 0.0) + (rightChild?.subWeight ?: 0.0)
        }

        /**
         * Returns the in-order last node in this subtree that still has a positive weight.
         *
         * @throws IllegalStateException if the subtree has already been fully drawn.
         */
        fun lastAvailable(): Node {
            var node = this
            while (true) {
                val right = node.rightChild
                val left = node.leftChild
                node = when {
                    right != null && right.subWeight > 0.0 -> right
                    node.weight > 0.0 -> return node
                    left != null && left.subWeight > 0.0 -> left
                    else -> error("Attempted to draw from an exhausted TrackShuffleTree subtree")
                }
            }
        }
    }
}
