use std::{
    collections::{HashMap, VecDeque},
    io::{Read, Write},
    path::Path,
    str::FromStr,
};

use keyfork_derive_openpgp::derive_util::{
    request::{DerivationAlgorithm, DerivationRequest},
    DerivationPath,
};
use openpgp::{
    armor::{Kind, Writer},
    cert::{Cert, CertParser, ValidCert},
    packet::{Packet, Tag, UserID, PKESK, SEIP},
    parse::{
        stream::{DecryptionHelper, DecryptorBuilder, VerificationHelper},
        Parse,
    },
    policy::{NullPolicy, Policy, StandardPolicy},
    serialize::{
        stream::{ArbitraryWriter, Encryptor, LiteralWriter, Message, Recipient, Signer},
        Marshal,
    },
    types::KeyFlags,
    KeyID, PacketPile,
};
pub use sequoia_openpgp as openpgp;
use sharks::{Share, Sharks};

mod keyring;
use keyring::Keyring;

mod smartcard;
use smartcard::SmartcardManager;

// TODO: better error handling

/// Error type that carries nothing but a human-readable message.
///
/// Used to turn string-like failures into something that implements
/// [`std::error::Error`] and can therefore be boxed into this module's `Result`.
#[derive(Debug, Clone)]
pub struct WrappedError(String);

impl std::fmt::Display for WrappedError {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        write!(formatter, "{message}")
    }
}

// No source error is tracked, so the default trait methods are sufficient.
impl std::error::Error for WrappedError {}

/// Result alias whose error type defaults to a boxed [`std::error::Error`] trait object.
pub type Result<T, E = Box<dyn std::error::Error>> = std::result::Result<T, E>;

/// One encrypted OpenPGP message, kept as its individual packets.
#[derive(Debug, Clone)]
pub struct EncryptedMessage {
    /// The PKESK packets that were read ahead of the SEIP packet.
    pkesks: Vec<PKESK>,
    /// The SEIP packet holding the encrypted payload.
    message: SEIP,
}

impl EncryptedMessage {
    pub fn with_swap(pkesks: &mut Vec<PKESK>, seip: SEIP) -> Self {
        Self {
            pkesks: std::mem::take(pkesks),
            message: seip,
        }
    }

    pub fn decrypt_with<H>(&self, policy: &'_ dyn Policy, decryptor: H) -> Result<Vec<u8>>
    where
        H: VerificationHelper + DecryptionHelper,
    {
        let mut packets = vec![];

        for pkesk in &self.pkesks {
            let mut packet = vec![];
            pkesk.serialize(&mut packet)?;
            let message = Message::new(&mut packets);
            let mut message = ArbitraryWriter::new(message, Tag::PKESK)?;
            message.write_all(&packet)?;
            message.finalize()?;
        }
        let mut packet = vec![];
        self.message.serialize(&mut packet)?;
        let message = Message::new(&mut packets);
        let mut message = ArbitraryWriter::new(message, Tag::SEIP)?;
        message.write_all(&packet)?;
        message.finalize()?;

        let mut decryptor =
            DecryptorBuilder::from_bytes(&packets)?.with_policy(policy, None, decryptor)?;

        let mut content = vec![];
        decryptor.read_to_end(&mut content)?;
        Ok(content)
    }
}

pub fn discover_certs(path: impl AsRef<Path>) -> Result<Vec<Cert>> {
    let path = path.as_ref();

    if path.is_file() {
        let mut vec = vec![];
        for cert in CertParser::from_file(path)? {
            vec.push(cert?);
        }
        Ok(vec)
    } else {
        let mut vec = vec![];
        for entry in path
            .read_dir()?
            .filter_map(Result::ok)
            .filter(|p| p.path().is_file())
        {
            vec.push(Cert::from_file(entry.path())?);
        }
        Ok(vec)
    }
}

pub fn parse_messages(reader: impl Read + Send + Sync) -> Result<VecDeque<EncryptedMessage>> {
    let mut pkesks = Vec::new();
    let mut encrypted_messages = VecDeque::new();

    for packet in PacketPile::from_reader(reader)?.into_children() {
        match packet {
            Packet::PKESK(p) => pkesks.push(p),
            Packet::SEIP(s) => {
                encrypted_messages.push_back(EncryptedMessage::with_swap(&mut pkesks, s));
            }
            s => {
                panic!("Invalid variant found: {}", s.tag());
            }
        }
    }

    Ok(encrypted_messages)
}

/// Returns an iterator over the keys of `cert` that are selected for encrypting shares.
///
/// Only keys that are alive, not revoked, supported, and flagged for storage encryption
/// are yielded.
fn get_encryption_keys<'a>(
    cert: &'a ValidCert,
) -> openpgp::cert::prelude::ValidKeyAmalgamationIter<
    'a,
    openpgp::packet::key::PublicParts,
    openpgp::packet::key::UnspecifiedRole,
> {
    let all_keys = cert.keys();
    let usable_keys = all_keys.alive().revoked(false).supported();
    usable_keys.for_storage_encryption()
}

/// Returns an iterator over the storage-encryption keys of `cert` that have secret key
/// material.
fn get_decryption_keys<'a>(
    cert: &'a ValidCert,
) -> openpgp::cert::prelude::ValidKeyAmalgamationIter<
    'a,
    openpgp::packet::key::SecretParts,
    openpgp::packet::key::UnspecifiedRole,
> {
    // Unlike `get_encryption_keys`, the `alive()`, `revoked(false)` and `supported()`
    // filters are not applied here.
    let storage_keys = cert.keys().for_storage_encryption();
    storage_keys.secret()
}

/// Recovers a secret from encrypted shares and writes it, encoded with `smex::encode`,
/// to `output`.
///
/// The work happens in stages:
///
/// 1. `metadata` is decrypted with a keyring built from `certs`. Its content is parsed
///    as a sequence of certificates: the first is the root certificate, the remaining
///    ones are paired, in order, with `messages`.
/// 2. Every message whose certificate has secret storage-encryption keys in the keyring
///    is decrypted with the keyring. A failed decryption is reported on stderr and the
///    message is kept for the next stage.
/// 3. While fewer than `threshold` shares are decrypted, the remaining messages are
///    decrypted through the smartcard manager.
/// 4. The shares are combined, a certificate is derived from the recovered secret, and
///    its fingerprint is compared against the root certificate before anything is
///    written.
///
/// # Errors
///
/// Returns an error if the metadata cannot be decrypted or contains no valid root
/// certificate, if a smartcard decryption fails, if no undecrypted message is left while
/// the threshold is still unmet, if the shares cannot be parsed or combined, if the
/// derived certificate does not match the root certificate, or if writing fails.
pub fn combine(
    threshold: u8,
    certs: Vec<Cert>,
    metadata: EncryptedMessage,
    messages: Vec<EncryptedMessage>,
    mut output: impl Write,
) -> Result<()> {
    // Be as liberal as possible when decrypting.
    // We don't want to invalidate someone's keys just because the old sig expired.
    let policy = NullPolicy::new();
    let required_shares = usize::from(threshold);

    let mut keyring = Keyring::new(certs);
    let content = metadata.decrypt_with(&policy, &mut keyring)?;

    let mut cert_parser = CertParser::from_bytes(&content)?;
    // The metadata is decrypted input; report malformed content instead of panicking.
    let root_cert = cert_parser
        .next()
        .ok_or_else(|| WrappedError("No certs found in cert parser".to_string()))?
        .map_err(|e| WrappedError(format!("Could not find root (first) certificate: {e}")))?;
    let certs = cert_parser.collect::<openpgp::Result<Vec<_>>>()?;
    keyring.set_root_cert(root_cert);
    let mut messages: HashMap<KeyID, EncryptedMessage> =
        certs.iter().map(|c| c.keyid()).zip(messages).collect();
    let mut decrypted_messages: HashMap<KeyID, Vec<u8>> = HashMap::new();

    // NOTE: This is ONLY stable because we control the generation of PKESK packets and
    // encode the policy to ourselves.
    for valid_cert in certs.iter().map(|cert| cert.with_policy(&policy, None)) {
        let valid_cert = valid_cert?;
        // get keys from keyring for cert
        let Some(secret_cert) = keyring.get_cert_for_primary_keyid(&valid_cert.keyid()) else {
            continue;
        };
        let secret_cert = secret_cert.with_policy(&policy, None)?;
        let keys = get_decryption_keys(&secret_cert).collect::<Vec<_>>();
        if !keys.is_empty() {
            if let Some(message) = messages.get_mut(&valid_cert.keyid()) {
                // Overwrite each PKESK's recipient with the key id of the matching
                // decryption key, pairing PKESKs and keys in order.
                for (pkesk, key) in message.pkesks.iter_mut().zip(keys) {
                    pkesk.set_recipient(key.keyid());
                }
                // we have a pkesk, decrypt via keyring
                let result = message.decrypt_with(&policy, &mut keyring);
                match result {
                    Ok(message) => {
                        decrypted_messages.insert(valid_cert.keyid(), message);
                    }
                    Err(e) => {
                        eprintln!(
                            "Could not decrypt with fingerprint {}: {}",
                            valid_cert.keyid(),
                            e
                        );
                        // do nothing, key will be retained
                    }
                }
            }
        }
    }

    // clean decrypted messages from encrypted messages
    messages.retain(|k, _v| !decrypted_messages.contains_key(k));

    // Saturate: the keyring may have decrypted more shares than the threshold requires,
    // in which case a plain subtraction would underflow.
    let left_from_threshold = required_shares.saturating_sub(decrypted_messages.len());
    if left_from_threshold > 0 {
        eprintln!("remaining keys: {left_from_threshold}, prompting yubikeys");
        // TODO: allow decrypt metadata with Yubikey, avoid require stage 1
        let mut manager =
            SmartcardManager::new(keyring.root_cert().expect("stage 1 decrypt").clone());
        let mut remaining_usable_certs = certs
            .iter()
            .filter(|cert| messages.contains_key(&cert.keyid()))
            .collect::<Vec<_>>();

        while decrypted_messages.len() < required_shares {
            remaining_usable_certs.retain(|cert| messages.contains_key(&cert.keyid()));
            if remaining_usable_certs.is_empty() {
                // Nothing is left to decrypt, so the threshold can never be reached;
                // without this check the loop would spin forever.
                return Err(WrappedError(format!(
                    "Not enough shares could be decrypted: have {}, need {threshold}",
                    decrypted_messages.len()
                ))
                .into());
            }
            let mut fingerprints = HashMap::new();
            for valid_cert in remaining_usable_certs
                .iter()
                .map(|cert| cert.with_policy(&policy, None))
            {
                let valid_cert = valid_cert?;
                fingerprints.insert(
                    valid_cert.keyid(),
                    valid_cert
                        .keys()
                        .for_storage_encryption()
                        .map(|k| k.fingerprint())
                        .collect::<Vec<_>>(),
                );
            }
            for (cert_id, fingerprints) in fingerprints {
                if manager.load_any_fingerprint(fingerprints)?.is_some() {
                    // manager is loaded with a Card<Open>, utilize in tx
                    let message = messages.remove(&cert_id);
                    if let Some(message) = message {
                        let message = message.decrypt_with(&policy, &mut manager)?;
                        decrypted_messages.insert(cert_id, message);
                    }
                }
            }
        }
    }

    let shares = decrypted_messages
        .values()
        .map(|message| Share::try_from(message.as_slice()))
        .collect::<Result<Vec<_>, &str>>()
        .map_err(|e| WrappedError(e.to_string()))?;
    let secret = Sharks(threshold).recover(&shares)?;

    // Re-derive the certificate from the recovered secret so it can be checked against
    // the root certificate stored in the metadata.
    let userid = UserID::from("keyfork-sss");
    let kdr = DerivationRequest::new(
        DerivationAlgorithm::Ed25519,
        &DerivationPath::from_str("m/7366512'/0'")?,
    )
    .derive_with_master_seed(secret.to_vec())?;
    let derived_cert = keyfork_derive_openpgp::derive(
        kdr,
        &[KeyFlags::empty().set_certification().set_signing()],
        userid,
    )?;

    // NOTE: Signatures on certs will be different. Compare fingerprints instead.
    if Some(derived_cert.fingerprint()) != keyring.root_cert().map(Cert::fingerprint) {
        return Err(WrappedError(format!(
            "Derived {} != expected {}",
            derived_cert.fingerprint(),
            keyring
                .root_cert()
                .expect("cert was previously set")
                .fingerprint()
        ))
        .into());
    }

    output.write_all(smex::encode(&secret).as_bytes())?;

    Ok(())
}

pub fn split(threshold: u8, certs: Vec<Cert>, secret: &[u8], output: impl Write) -> Result<()> {
    // build cert to sign encrypted shares
    let userid = UserID::from("keyfork-sss");
    let kdr = DerivationRequest::new(
        DerivationAlgorithm::Ed25519,
        &DerivationPath::from_str("m/7366512'/0'")?,
    )
    .derive_with_master_seed(secret.to_vec())?;
    let derived_cert = keyfork_derive_openpgp::derive(
        kdr,
        &[KeyFlags::empty().set_certification().set_signing()],
        userid,
    )?;
    let signing_key = derived_cert
        .primary_key()
        .parts_into_secret()?
        .key()
        .clone()
        .into_keypair()?;

    let sharks = Sharks(threshold);
    let dealer = sharks.dealer(secret);
    let shares = dealer.map(|s| Vec::from(&s)).collect::<Vec<_>>();
    let policy = StandardPolicy::new();
    let mut writer = Writer::new(output, Kind::Message)?;

    let mut total_recipients = vec![];
    let mut messages = vec![];

    for (share, cert) in shares.iter().zip(certs) {
        total_recipients.push(cert.clone());
        let valid_cert = cert.with_policy(&policy, None)?;
        let encryption_keys = get_encryption_keys(&valid_cert).collect::<Vec<_>>();

        let mut message_output = vec![];
        let message = Message::new(&mut message_output);
        let message = Encryptor::for_recipients(
            message,
            encryption_keys
                .iter()
                .map(|k| Recipient::new(KeyID::wildcard(), k.key())),
        )
        .build()?;
        let message = Signer::new(message, signing_key.clone()).build()?;
        let mut message = LiteralWriter::new(message).build()?;
        message.write_all(share)?;
        message.finalize()?;

        messages.push(message_output);
    }

    let mut pp = vec![];
    // store derived cert to verify provided shares
    derived_cert.serialize(&mut pp)?;
    for recipient in &total_recipients {
        recipient.serialize(&mut pp)?;
    }

    // verify packet pile
    for (packet_cert, cert) in openpgp::cert::CertParser::from_bytes(&pp)?
        .skip(1)
        .zip(total_recipients.iter())
    {
        if packet_cert? != *cert {
            panic!(
                "packet pile could not recreate cert: {}",
                cert.fingerprint()
            );
        }
    }

    let valid_certs = total_recipients
        .iter()
        .map(|c| c.with_policy(&policy, None))
        .collect::<openpgp::Result<Vec<_>>>()?;

    let total_recipients = valid_certs.iter().flat_map(|vc| {
        get_encryption_keys(vc).map(|key| Recipient::new(KeyID::wildcard(), key.key()))
    });

    // metadata
    let mut message_output = vec![];
    let message = Message::new(&mut message_output);
    let message = Encryptor::for_recipients(message, total_recipients).build()?;
    let mut message = LiteralWriter::new(message).build()?;
    message.write_all(&pp)?;
    message.finalize()?;
    writer.write_all(&message_output)?;

    for message in messages {
        writer.write_all(&message)?;
    }

    writer.finalize()?;

    Ok(())
}