package techla.guard

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.*
import techla.base.*

/**
 * Closed set of token kinds.
 *
 * Every variant is tagged with a discriminator ([rawValue]) taken from [TokenSerializer]'s constants,
 * and the whole hierarchy is (de)serialized through [TokenSerializer].
 *
 * @property rawValue the discriminator string identifying this variant (e.g. [TokenSerializer.USER]).
 */
@Serializable(with = TokenSerializer::class)
sealed class Token(val rawValue: String) {
    /** Placeholder for tokens whose discriminator was missing or not recognised. */
    data object Unknown : Token(TokenSerializer.UNKNOWN)

    data class Application(val token: String, val expiresAt: Date) : Token(TokenSerializer.APPLICATION)
    data class User(val token: String, val expiresAt: Date, val group: Key<Group>) : Token(TokenSerializer.USER)
    data class Admin(val token: String, val expiresAt: Date) : Token(TokenSerializer.ADMIN)
    data class AutoStart(val token: String) : Token(TokenSerializer.AUTO_START)
    data class QR(val code: String) : Token(TokenSerializer.QR)

    /** Token of a custom kind identified by [key]; unlike the other timed variants, [expiresAt] is optional. */
    data class Other(val token: String, val expiresAt: Date?, val key: Key<Token>) : Token(TokenSerializer.OTHER)

    /**
     * Returns only the variant name (e.g. `Token.User`), never the token material.
     *
     * Declared `final` on purpose: without it every data-class/data-object subtype generates its own
     * `toString()`, which prints all properties (including the raw `token`/`code` strings) and makes
     * this override unreachable.
     */
    final override fun toString(): String =
        when (this) {
            is Unknown -> "Token.Unknown"
            is Application -> "Token.Application"
            is User -> "Token.User"
            is Admin -> "Token.Admin"
            is AutoStart -> "Token.AutoStart"
            is Other -> "Token.Other"
            is QR -> "Token.QR"
        }
}

/**
 * Flat, all-nullable representation of a [Token], produced by [flatten] and consumed by [unflatten].
 *
 * Only the fields relevant to the variant named by [rawValue] are populated; the others stay `null`.
 *
 * @property rawValue the variant discriminator (one of the [TokenSerializer] variant constants).
 * @property token the token string; for [Token.QR] this slot carries [Token.QR.code].
 * @property expiresAt expiry of Application/User/Admin tokens; optional for [Token.Other].
 * @property group the group of a [Token.User]; `null` for every other variant.
 * @property key the key of a [Token.Other]; `null` for every other variant.
 */
data class TokenComponents(
    val rawValue: String,
    val token: String? = null,
    val expiresAt: Date? = null,
    val group: Key<Group>? = null,
    val key: Key<Token>? = null,
)

/**
 * Decomposes this [Token] into a flat [TokenComponents] record; the inverse of [unflatten].
 *
 * Only the fields the concrete variant carries are set. [Token.QR.code] is stored in
 * [TokenComponents.token], since QR is the only variant without a `token` property.
 *
 * @return the components describing this token.
 * @throws TechlaError.PreconditionFailed if this is [Token.Unknown], which has no data to flatten.
 */
fun Token.flatten(): TokenComponents =
    when (this) {
        is Token.Application -> TokenComponents(rawValue, token = token, expiresAt = expiresAt)
        is Token.User -> TokenComponents(rawValue, token = token, expiresAt = expiresAt, group = group)
        is Token.Admin -> TokenComponents(rawValue, token = token, expiresAt = expiresAt)
        is Token.AutoStart -> TokenComponents(rawValue, token = token)
        is Token.Other -> TokenComponents(rawValue, token = token, expiresAt = expiresAt, key = key)
        // QR has no `token` property; its code travels in the shared `token` slot.
        is Token.QR -> TokenComponents(rawValue, token = code)
        is Token.Unknown ->
            throw TechlaError.PreconditionFailed("Token.Unknown cannot be flattened")
    }

/**
 * Rebuilds a [Token] from its flat [TokenComponents] form; the inverse of [flatten].
 *
 * The variant is chosen by [TokenComponents.rawValue]. Unrecognised values, including
 * [TokenSerializer.UNKNOWN], yield [Token.Unknown] instead of failing.
 *
 * @param components the flattened token, typically decoded from JSON by [TokenSerializer].
 * @return the reconstructed token.
 * @throws IllegalArgumentException if a field required by the selected variant is `null`.
 */
fun Token.Companion.unflatten(components: TokenComponents): Token {
    // Components usually come from external JSON, so a missing field should fail with a message
    // naming the variant and field rather than a bare NullPointerException from `!!`.
    fun <T : Any> required(value: T?, field: String): T =
        requireNotNull(value) { "Token '${components.rawValue}' is missing required field '$field'" }

    return when (components.rawValue) {
        TokenSerializer.APPLICATION -> Token.Application(
            token = required(components.token, "token"),
            expiresAt = required(components.expiresAt, "expiresAt"),
        )
        TokenSerializer.USER -> Token.User(
            token = required(components.token, "token"),
            expiresAt = required(components.expiresAt, "expiresAt"),
            group = required(components.group, "group"),
        )
        TokenSerializer.ADMIN -> Token.Admin(
            token = required(components.token, "token"),
            expiresAt = required(components.expiresAt, "expiresAt"),
        )
        TokenSerializer.AUTO_START -> Token.AutoStart(token = required(components.token, "token"))
        TokenSerializer.OTHER -> Token.Other(
            token = required(components.token, "token"),
            expiresAt = components.expiresAt,
            key = required(components.key, "key"),
        )
        // The QR code is carried in the shared `token` slot (see flatten).
        TokenSerializer.QR -> Token.QR(code = required(components.token, "token"))
        else -> Token.Unknown
    }
}

/**
 * JSON-only [KSerializer] for [Token].
 *
 * A token is written as a JSON object with a `_discriminator` entry (one of the variant constants
 * below) plus whichever of `token`, `expires_at`, `group` and `key` the variant carries; absent
 * fields are omitted. The object is built from [flatten] and read back through [unflatten].
 *
 * Decoding returns [Token.Unknown] when the discriminator is missing or not a JSON primitive.
 * Encoding [Token.Unknown] is not supported, because [flatten] rejects it.
 */
object TokenSerializer : KSerializer<Token> {
    const val UNKNOWN = "unknown"
    const val APPLICATION = "application"
    const val USER = "user"
    const val ADMIN = "admin"
    const val AUTO_START = "auto_start"
    const val OTHER = "other"
    const val QR = "qr"
    private const val DISCRIMINATOR = "_discriminator"
    private const val TOKEN = "token"
    private const val EXPIRES_AT = "expires_at"
    private const val GROUP = "group"
    private const val KEY = "key"

    // NOTE(review): the wire format is a JSON object, but this descriptor describes a primitive
    // String (serial name "kotlin.String"). It is kept unchanged for compatibility; consider a
    // dedicated, uniquely named descriptor.
    override val descriptor: SerialDescriptor =
        String.serializer().descriptor

    /**
     * Writes [value] as a JSON object.
     *
     * @throws IllegalArgumentException if [encoder] is not a [JsonEncoder].
     * @throws TechlaError.PreconditionFailed if [value] is [Token.Unknown] (raised by [flatten]).
     */
    override fun serialize(encoder: Encoder, value: Token) {
        require(encoder is JsonEncoder) { "TokenSerializer supports only Json encoding" }

        // Flatten first so an unsupported token fails before any JSON is assembled.
        val components = value.flatten()
        val jsonObject = buildJsonObject {
            put(DISCRIMINATOR, components.rawValue)
            components.token?.let { put(TOKEN, it) }
            components.expiresAt?.let { put(EXPIRES_AT, DateSerializer.serialize(it)) }
            components.group?.let { put(GROUP, it.rawValue) }
            components.key?.let { put(KEY, it.rawValue) }
        }

        encoder.encodeJsonElement(jsonObject)
    }

    /**
     * Reads a [Token] from a JSON object.
     *
     * @throws IllegalArgumentException if [decoder] is not a [JsonDecoder], the input is not a JSON
     *   object, or a field required by the discriminated variant is missing (see [unflatten]).
     */
    override fun deserialize(decoder: Decoder): Token {
        require(decoder is JsonDecoder) { "TokenSerializer supports only Json decoding" }
        val jsonObject = decoder.decodeJsonElement().jsonObject

        // A missing or structured (object/array) discriminator cannot name a variant.
        val discriminator = jsonObject[DISCRIMINATOR] as? JsonPrimitive ?: return Token.Unknown

        val components = TokenComponents(
            rawValue = discriminator.content,
            token = jsonObject[TOKEN].parseString(),
            expiresAt = jsonObject[EXPIRES_AT].parseDate(),
            group = jsonObject[GROUP].parseKey(),
            key = jsonObject[KEY].parseKey(),
        )

        return Token.unflatten(components)
    }
}
