#!/usr/bin/python3
"""
    Copyright (C) 2023  Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import signal
import logging
import argparse
from argparse import Namespace
from typing import List, Union, BinaryIO, Any
from datetime import datetime
from functools import partial
from threading import current_thread
from concurrent.futures import ThreadPoolExecutor, as_completed
from libvirt import virDomain
from libvirtnbdbackup import sighandle
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import nbdcli
from libvirtnbdbackup import extenthandler
from libvirtnbdbackup.qemu import util as qemu
from libvirtnbdbackup import virt
from libvirtnbdbackup.virt.client import DomainDisk
from libvirtnbdbackup.virt import checkpoint
from libvirtnbdbackup import output
from libvirtnbdbackup.output import stream
from libvirtnbdbackup import common as lib
from libvirtnbdbackup.processinfo import processInfo
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup import metadata
from libvirtnbdbackup import exceptions
from libvirtnbdbackup import chunk
from libvirtnbdbackup import block
from libvirtnbdbackup import partialfile
from libvirtnbdbackup.ssh.exceptions import sshError
from libvirtnbdbackup.nbdcli.exceptions import NbdClientException
from libvirtnbdbackup.qemu.exceptions import ProcessError
from libvirtnbdbackup.virt.exceptions import (
    startBackupFailed,
    domainNotFound,
    connectionFailed,
)
from libvirtnbdbackup.output.exceptions import OutputException


def setOfflineArguments(args: Namespace, domObj: virDomain) -> None:
    """Flag on `args` whether the domain to be saved is shut off.

    `args.offline` is set to True if `domObj.isActive()` reports 0 and
    to False otherwise. A requested "full" backup of an offline domain
    is downgraded to backup level "copy"."""
    args.offline = False
    if domObj.isActive() != 0:
        # Domain is running: requested options stay untouched.
        return

    if args.level == "full":
        logging.warning("Domain is offline, resetting backup options.")
        args.level = "copy"
        logging.info("Backup level: [%s].", args.level)
    args.offline = True


def startBackupJob(
    args: Namespace,
    virtClient: virt.client,
    domObj: virDomain,
    disks: List[DomainDisk],
) -> bool:
    """Ask libvirt to start the backup job for the given disks.

    Returns True once `virtClient.startBackup` succeeded; a
    `startBackupFailed` error is logged and reported as False."""
    logging.info("Starting backup job.")
    try:
        virtClient.startBackup(args, domObj, disks)
    except startBackupFailed as errmsg:
        logging.error(errmsg)
        return False

    logging.debug("Backup job started.")
    return True


def main() -> None:
    """Handle backup operation

    Parses the command line, validates the option combination against the
    state of the backup target and of the domain, starts the libvirt backup
    job (for running domains), saves all disks concurrently and finally
    stores the metadata files.

    The process exits with code 1 on errors, with code 2 if warnings were
    counted and --strict is set, and with code 0 for the early exits
    (--killonly, --printonly, --startonly, threshold not reached)."""
    # Command line definition; the epilog carries usage examples.
    parser = argparse.ArgumentParser(
        description="Backup libvirt/qemu virtual machines",
        epilog=(
            "Examples:\n"
            "   # full backup of domain 'webvm' with all attached disks:\n"
            "\t%(prog)s -d webvm -l full -o /backup/\n"
            "   # incremental backup:\n"
            "\t%(prog)s -d webvm -l inc -o /backup/\n"
            "   # differential backup:\n"
            "\t%(prog)s -d webvm -l diff -o /backup/\n"
            "   # full backup, exclude disk 'vda':\n"
            "\t%(prog)s -d webvm -l full -x vda -o /backup/\n"
            "   # full backup, backup only disk 'vdb':\n"
            "\t%(prog)s -d webvm -l full -i vdb -o /backup/\n"
            "   # full backup, compression enabled:\n"
            "\t%(prog)s -d webvm -l full -z -o /backup/\n"
            "   # full backup, create archive:\n"
            "\t%(prog)s -d webvm -l full -o - > backup.zip\n"
            "   # full backup of vm operating on remote libvirtd:\n"
            "\t%(prog)s -U qemu+ssh://root@remotehost/system "
            "--ssh-user root -d webvm -l full -o /backup/\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    opt = parser.add_argument_group("General options")
    opt.add_argument("-d", "--domain", required=True, type=str, help="Domain to backup")
    opt.add_argument(
        "-l",
        "--level",
        default="copy",
        choices=["copy", "full", "inc", "diff", "auto"],
        type=str,
        help="Backup level. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--type",
        default="stream",
        type=str,
        choices=["stream", "raw"],
        help="Output type: stream or raw. (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Include full provisioned disk images in backup. (default: %(default)s)",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Output target directory"
    )
    opt.add_argument(
        "-C",
        "--checkpointdir",
        required=False,
        default=None,
        type=str,
        help="Persistent libvirt checkpoint storage directory",
    )
    opt.add_argument(
        "-S",
        "--scratchdir",
        default="/var/tmp",
        required=False,
        type=str,
        help="Target dir for temporary scratch file. (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--include",
        default=None,
        type=str,
        help="Backup only disk with target dev name (-i vda)",
    )
    opt.add_argument(
        "-x",
        "--exclude",
        default=None,
        type=str,
        help="Exclude disk(s) with target dev name (-x vda,vdb)",
    )
    # Default socket path carries the PID of this process.
    opt.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        default=False,
        help="Disable progress bar",
        action="store_true",
    )
    # -z without a value selects level 2 (const); omitting -z leaves
    # compression disabled (False).
    opt.add_argument(
        "-z",
        "--compress",
        default=False,
        type=int,
        const=2,
        nargs="?",
        help="Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    opt.add_argument(
        "-w",
        "--worker",
        type=int,
        default=None,
        help=(
            "Amount of concurrent workers used "
            "to backup multiple disks. (default: amount of disks)"
        ),
    )
    opt.add_argument(
        "-F",
        "--freeze-mountpoint",
        type=str,
        default=None,
        help=(
            "If qemu agent available, freeze only filesystems on specified mountpoints within"
            " virtual machine (default: all)"
        ),
    )
    opt.add_argument(
        "-e",
        "--strict",
        default=False,
        help=(
            "Change exit code if warnings occur during backup operation. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "-T",
        "--threshold",
        type=int,
        default=None,
        help=("Execute backup only if threshold is reached."),
    )
    remopt = parser.add_argument_group("Remote Backup options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    logopt.add_argument(
        "-L",
        "--syslog",
        default=False,
        action="store_true",
        help="Additionally send log messages to syslog (default: %(default)s)",
    )
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-q",
        "--qemu",
        default=False,
        action="store_true",
        help="Use Qemu tools to query extents.",
    )
    debopt.add_argument(
        "-s",
        "--startonly",
        default=False,
        help="Only initialize backup job via libvirt, do not backup any data",
        action="store_true",
    )
    debopt.add_argument(
        "-k",
        "--killonly",
        default=False,
        help="Kill any running block job",
        action="store_true",
    )
    debopt.add_argument(
        "-p",
        "--printonly",
        default=False,
        help="Quit after printing estimated checkpoint size.",
        action="store_true",
    )
    argopt.addDebugArgs(debopt)

    repository = output.target()
    args = lib.argparse(parser)

    # Output "-" selects stdout as backup target; the remaining
    # attributes are runtime state shared via the args namespace.
    args.stdout = args.output == "-"
    args.sshClient = None
    args.diskInfo = []

    # NOTE(review): the logger is configured further below, so this error
    # is emitted with whatever logging setup is active at this point.
    try:
        fileStream = stream.get(args, repository)
    except OutputException as e:
        logging.error("Can't open output file: [%s]", e)
        sys.exit(1)

    # Worker counts below one fall back to a single worker.
    if args.worker is not None and args.worker < 1:
        args.worker = 1

    # Log file name carries the backup level and a timestamp; a falsy
    # return value of getLogFile aborts the run.
    now = datetime.now().strftime("%m%d%Y%H%M%S")
    logFile = f"{args.output}/backup.{args.level}.{now}.log"
    fileLog = lib.getLogFile(logFile) or sys.exit(1)

    # The counter is consulted at the end to derive the exit code from
    # the amount of logged errors and warnings.
    counter = logCount()
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    logging.info("Backup level: [%s]", args.level)
    if args.compress:
        logging.info("Compression enabled, level [%s]", args.compress)

    # Neither raw output type nor raw disk images can go to stdout.
    if args.stdout is True and args.type == "raw":
        logging.error("Output type raw not supported to stdout.")
        sys.exit(1)

    if args.stdout is True and args.raw is True:
        logging.error("Saving raw images to stdout is not supported.")
        sys.exit(1)

    # Level "auto" resolves to "full" for an empty target and to "inc"
    # if the target already holds a full backup. For explicitly requested
    # levels the run aborts if lib.targetIsEmpty reports existing data
    # (check skipped for stdout, start-only and kill-only runs).
    if lib.targetIsEmpty(args) and args.level == "auto":
        logging.info("Backup mode auto, target folder is empty: executing full backup.")
        args.level = "full"
    elif not lib.targetIsEmpty(args) and args.level == "auto":
        if not lib.hasFullBackup(args):
            logging.error(
                "Can't execute switch to auto incremental backup: "
                "target folder seems not to include full backup."
            )
            sys.exit(1)
        logging.info("Backup mode auto: executing incremental backup.")
        args.level = "inc"
    elif not args.stdout and not args.startonly and not args.killonly:
        if not lib.targetIsEmpty(args):
            logging.error("Target directory already contains full or copy backup.")
            sys.exit(1)

    # Raw disks are silently dropped from inc/diff backups, while raw
    # output type combined with inc/diff is a hard error.
    if args.raw is True and args.level in ("inc", "diff"):
        logging.warning(
            "Raw disks can't be included during incremental or differential backup."
        )
        logging.warning("Excluding raw disks.")
        args.raw = False

    if args.type == "raw" and args.level in ("inc", "diff"):
        logging.error(
            "Stream format raw does not support incremental or differential backup."
        )
        sys.exit(1)

    # presumably detects leftover partial files of an interrupted
    # backup; verify against partialfile.exists
    if partialfile.exists(args):
        sys.exit(1)

    # Without -C the checkpoints are stored beneath the output directory.
    if not args.checkpointdir:
        args.checkpointdir = f"{args.output}/checkpoints"
    else:
        logging.info("Store checkpoints in: [%s]", args.checkpointdir)

    # presumably creates the directory if missing; confirm in output.target
    repository.Directory(args.checkpointdir)

    # Connect to libvirt and look up the domain to be saved.
    try:
        virtClient = virt.client(args)
        domObj = virtClient.getDomain(args.domain)
    except domainNotFound as e:
        logging.error("%s", e)
        sys.exit(1)
    except connectionFailed as e:
        logging.error("Can't connect libvirt daemon: [%s]", e)
        sys.exit(1)

    logging.info("Libvirt library version: [%s]", virtClient.libvirtVersion)

    if virtClient.hasIncrementalEnabled(domObj) is False:
        logging.error("Domain is missing required incremental-backup capability.")
        sys.exit(1)

    # checkForeign is expected to log the reason itself; only the exit
    # happens here.
    try:
        checkpoint.checkForeign(args, domObj)
    except exceptions.CheckpointException:
        sys.exit(1)

    # Sets args.offline and may downgrade level "full" to "copy".
    setOfflineArguments(args, domObj)
    if args.offline is True and args.startonly is True:
        logging.error("Virtual machine is currently offline")
        logging.error("Virtual machine must be active for this function.")
        sys.exit(1)

    # Route SIGINT to the backup signal handler, pre-bound with the
    # objects of the current job.
    signal.signal(
        signal.SIGINT,
        partial(sighandle.Backup.catch, args, domObj, virtClient, logging),
    )

    vmConfig = virtClient.getDomainConfig(domObj)
    disks: List[DomainDisk] = virtClient.getDomainDisks(args, vmConfig)
    args.info = virtClient.getDomainInfo(vmConfig)

    # Without any qcow disk the backup level falls back to "copy".
    if args.level != "copy" and lib.hasQcowDisks(disks) is False:
        args.level = "copy"
        logging.info("Only raw disks attached, switching to backup mode copy.")

    # Even without any disk to save, the metadata files are stored
    # before exiting with an error.
    if not disks:
        logging.error("Unable to detect disks suitable for backup.")
        metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)
        sys.exit(1)
    # Refuse to start while another block job is active on a running
    # domain, unless the job is about to be killed via -k.
    if (
        not args.killonly
        and not args.offline
        and virtClient.blockJobActive(domObj, disks)
    ):
        logging.error("Detected an active backup operation for running domain.")
        logging.error(
            "Check with [virsh domjobinfo %s] or use option -k to kill the active job.",
            args.domain,
        )
        sys.exit(1)

    logging.info(
        "Backup will save [%s] attached disks.",
        len(disks),
    )
    # Use one worker per disk by default and never more workers than disks.
    if args.worker is None or args.worker > int(len(disks)):
        args.worker = int(len(disks))
    logging.info("Concurrent backup processes: [%s]", args.worker)

    # Kill-only mode: stop the running backup job and exit.
    if args.killonly is True:
        logging.info("Stopping backup job")
        if not virtClient.stopBackup(domObj):
            sys.exit(1)
        sys.exit(0)

    # checkpoint.create is expected to populate args.cpt, which is read
    # below (args.cpt.parent).
    try:
        checkpoint.create(args, domObj)
    except exceptions.CheckpointException as errmsg:
        logging.error(errmsg)
        sys.exit(1)

    # Print-only mode: report the estimated size relative to the parent
    # checkpoint of a running domain and exit.
    if args.printonly and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        logging.info("Estimated checkpoint backup size: [%s] Bytes", size)
        sys.exit(0)

    # Skip the backup if the estimated size is below the threshold.
    if args.threshold and args.cpt.parent and not args.offline:
        size = checkpoint.getSize(domObj, args.cpt.parent)
        if size < args.threshold:
            logging.info(
                "Backup size [%s] does not meet required threshold [%s], skipping backup.",
                size,
                args.threshold,
            )
            sys.exit(0)

    # Remote libvirt host: an ssh session is required; for offline
    # domains the port range used by the NBD servers is logged.
    if virtClient.remoteHost != "":
        args.sshClient = lib.sshSession(args, virtClient.remoteHost)
        if not args.sshClient:
            logging.error("Remote backup detected but ssh session setup failed")
            sys.exit(1)
        logging.info(
            "Remote NBD Endpoint host: [%s]",
            virtClient.remoteHost,
        )
        if args.offline is True:
            logging.info(
                "Remote ports used for backup: [%s-%s]",
                args.nbd_port,
                args.nbd_port + args.worker,
            )
    else:
        logging.info("Local NBD Endpoint socket: [%s]", args.socketfile)

    # Only running domains get a libvirt backup job (and a scratch dir).
    if args.offline is not True:
        logging.info("Temporary scratch file target directory: [%s]", args.scratchdir)
        repository.Directory(args.scratchdir)
        if not startBackupJob(args, virtClient, domObj, disks):
            sys.exit(1)

    # Levels other than copy/diff on a running domain: save the
    # checkpoint information; if backing up the checkpoint fails the
    # just started backup job is stopped again.
    if args.level not in ("copy", "diff") and args.offline is False:
        logging.info("Started backup job with checkpoint, saving information.")
        try:
            checkpoint.save(args)
        except exceptions.CheckpointException as e:
            logging.error("Extending checkpoint file failed: [%s]", e)
            sys.exit(1)
        if not checkpoint.backup(args, domObj):
            virtClient.stopBackup(domObj)
            sys.exit(1)

    if args.startonly is True:
        logging.info("Started backup job for debugging, exiting.")
        sys.exit(0)

    # One task per disk; the enumeration index is passed to backupDisk as
    # count. Failures are only logged here (not re-raised), so the
    # cleanup and metadata steps below run in any case; the final exit
    # code is derived from the log counter.
    try:
        with ThreadPoolExecutor(max_workers=args.worker) as executor:
            futures = {
                executor.submit(
                    backupDisk, args, disk, count, fileStream, virtClient
                ): disk
                for count, disk in enumerate(disks)
            }
            for future in as_completed(futures):
                if future.result() is not True:
                    raise exceptions.DiskBackupFailed("Backup of one disk failed")
    except exceptions.BackupException as e:
        logging.error("Disk backup failed: [%s]", e)
    except sshError as e:
        logging.error("Remote Disk backup failed: [%s]", e)
    except Exception as e:  # pylint: disable=broad-except
        logging.fatal("Unknown Exception during backup: %s", e)
        logging.exception(e)

    if args.offline is False:
        logging.info("Backup jobs finished, stopping backup task.")
        virtClient.stopBackup(domObj)

    metadata.saveFiles(args, vmConfig, disks, fileStream, logFile)

    # Remember the autostart setting of the domain as part of the backup.
    if domObj.autostart() == 1:
        metadata.backupAutoStart(args)

    # NOTE(review): this exit happens before the ssh session below is
    # disconnected.
    if counter.count.errors > 0:
        logging.error("Error during backup")
        sys.exit(1)

    if args.sshClient:
        args.sshClient.disconnect()

    # With --strict, warnings change the exit code to 2.
    if counter.count.warnings > 0 and args.strict is True:
        logging.info(
            "[%s] Warnings detected during backup operation, forcing exit code 2",
            counter.count.warnings,
        )
        sys.exit(2)

    logging.info("Finished successfully")


def setStreamType(args: Namespace, disk: DomainDisk) -> str:
    """Return the stream type used for the given disk: disks of format
    "raw" are always saved as "raw", every other format uses the
    requested output type (`args.type`)."""
    return "raw" if disk.format == "raw" else args.type


def setTargetFile(args: Namespace, disk: DomainDisk, ext: str = "data"):
    """Build the target file name for a disk, used for both data files
    and qemu disk info.

    Returns a tuple of the final file name and the name of the
    corresponding ".partial" file. Backup levels other than full, copy,
    inc or diff result in an empty final file name."""
    targetFile: str = ""
    if args.level in ("inc", "diff"):
        # Incremental and differential files carry the identifier
        # returned by lib.getIdent.
        cptName = lib.getIdent(args)
        targetFile = f"{args.output}/{disk.target}.{args.level}.{cptName}.{ext}"
    elif args.level in ("full", "copy"):
        # Raw disks are always stored with level "copy".
        level = "copy" if disk.format == "raw" else args.level
        targetFile = f"{args.output}/{disk.target}.{level}.{ext}"

    return targetFile, f"{targetFile}.partial"


def getWriter(
    args: Namespace, fileStream, targetFile: str, targetFilePartial: str
) -> BinaryIO:
    """Open the write handle for a disk via the output stream.

    Regular backups write to the ".partial" file name; for backups to
    stdout the handle is opened with the base name of the final target
    file instead."""
    if args.stdout is not True:
        logging.info("Write data to target file: [%s].", targetFilePartial)
        return fileStream.open(targetFilePartial)

    archiveMember = os.path.basename(targetFile)
    logging.info("Writing data to zip archive.")
    return fileStream.open(archiveMember)


def getExtentHandler(args: Namespace, nbdClient):
    """Return the extent handler used to query the blocks to be saved.

    By default the NBD client itself is handed to the handler; with
    option --qemu a qemu util object for the client's export name is
    used as query backend instead."""
    backend = nbdClient
    if args.qemu:
        logging.info("Using qemu tools to query extents")
        backend = qemu.util(nbdClient.cType.exportName)

    return extenthandler.ExtentHandler(backend, nbdClient.cType)


def startOfflineNBD(
    args: Namespace, disk: DomainDisk, remoteHost: str, port: int
) -> processInfo:
    """Start the background qemu-nbd process serving a disk of an
    offline domain and return its process information.

    With an empty `remoteHost` the server is started locally on a disk
    specific socket, otherwise on the remote system using the given
    port. For inc/diff backups the checkpoint name is passed as bitmap."""
    bitMap: str = args.cpt.name if args.level in ("inc", "diff") else ""
    socketPath = f"{args.socketfile}.{disk.target}"

    if remoteHost == "":
        logging.info(
            "Offline backup, starting local NBD server, socket: [%s]", socketPath
        )
        localProc = qemu.util(disk.target).startBackupNbdServer(
            disk.format, disk.path, socketPath, bitMap
        )
        logging.info("Local NBD Service started, PID: [%s]", localProc.pid)
        return localProc

    logging.info(
        "Offline backup, starting remote NBD server, socket: [%s:%s], port: [%s]",
        remoteHost,
        socketPath,
        port,
    )
    remoteProc = qemu.util(disk.target).startRemoteBackupNbdServer(
        args, disk, bitMap, port
    )
    logging.info("Remote NBD server started, PID: [%s].", remoteProc.pid)
    return remoteProc


def connectNbd(
    args: Namespace,
    disk: DomainDisk,
    metaContext: str,
    remoteIP: str,
    port: int,
    virtClient: virt.client,
):
    """Connect to started nbd endpoint

    Backups via a remote libvirt host connect by TCP to `remoteIP` and
    `port`, local backups use the unix socket file. For offline domains
    the socket name is disk specific (suffixed with the disk target).

    Returns the result of the NBD client's connect() call.

    Raises:
        exceptions.DiskBackupFailed: if the NBD client is unable to
            connect; the original NbdClientException is attached as
            cause.
    """
    socket = args.socketfile
    if args.offline is True:
        socket = f"{args.socketfile}.{disk.target}"

    cType: Union[nbdcli.TCP, nbdcli.Unix]
    if virtClient.remoteHost != "":
        cType = nbdcli.TCP(disk.target, metaContext, remoteIP, args.tls, port)
    else:
        cType = nbdcli.Unix(disk.target, metaContext, socket)

    nbdClient = nbdcli.client(cType)

    try:
        return nbdClient.connect()
    except NbdClientException as e:
        # Chain explicitly so the traceback marks the client error as the
        # direct cause instead of an error raised during handling.
        raise exceptions.DiskBackupFailed(
            f"NBD endpoint: [{cType}]: connection failed: [{e}]"
        ) from e


def _stopOfflineNbd(
    args: Namespace, nbdProc: processInfo, virtClient: virt.client
) -> None:
    """Clean up the NBD server started for the backup of an offline
    domain: kill the process if it was started locally and remove its
    PID file."""
    if virtClient.remoteHost == "":
        logging.info("Stopping NBD Service.")
        lib.killProc(nbdProc.pid)

    lib.remove(args, nbdProc.pidFile)


def backupDisk(
    args: Namespace,
    disk: DomainDisk,
    count: int,
    fileStream,
    virtClient: virt.client,
):
    """Backup domain disk data.

    Executed once per disk by the worker threads. For offline domains a
    dedicated NBD server is started first (port offset by `count`), then
    the extents reported by the extent handler are written to the target
    file, either as sparse stream or as raw image.

    Returns True after the disk has been processed; this includes the
    case that no extents could be queried, which is reported via error
    log only.

    Raises:
        exceptions.DiskBackupFailed: if the offline NBD server can't be
            started or the connection to the NBD endpoint fails.
    """
    # NOTE: the local name shadows the imported output "stream" module
    # within this function.
    stream = streamer.SparseStream(types)
    sTypes = types.SparseStreamTypes()
    # Name the worker thread after the disk target.
    current_thread().name = disk.target
    streamType = setStreamType(args, disk)
    metaContext = nbdcli.context.get(args, disk)
    nbdProc: processInfo
    remoteIP: str = virtClient.remoteHost
    port: int = args.nbd_port
    if args.nbd_ip != "":
        remoteIP = args.nbd_ip

    if args.offline is True:
        # Each disk gets its own NBD server and therefore its own port.
        port = args.nbd_port + count
        try:
            nbdProc = startOfflineNBD(args, disk, virtClient.remoteHost, port)
        except ProcessError as errmsg:
            logging.error(errmsg)
            raise exceptions.DiskBackupFailed(
                "Failed to start NBD server."
            ) from errmsg

    connection = connectNbd(args, disk, metaContext, remoteIP, port, virtClient)

    extentHandler = getExtentHandler(args, connection)
    extents = extentHandler.queryBlockStatus()
    diskSize = connection.nbd.get_size()

    if extents is None:
        logging.error("No extents found.")
        # Release the NBD connection and the offline NBD server before
        # returning, exactly as the regular exit path does.
        connection.disconnect()
        if args.offline is True:
            _stopOfflineNbd(args, nbdProc, virtClient)
        return True

    thinBackupSize = sum(extent.length for extent in extents if extent.data is True)
    logging.info("Got %s extents to backup.", len(extents))
    logging.debug("%s", lib.dumpExtentJson(extents))
    logging.info("%s bytes disk size", diskSize)
    logging.info("%s bytes of data extents to backup", thinBackupSize)

    if args.level in ("inc", "diff") and thinBackupSize == 0:
        logging.info("No dirty blocks found")
        # NOTE(review): args is shared between all worker threads, so
        # this also changes the setting for the other disks.
        args.noprogress = True

    targetFile, targetFilePartial = setTargetFile(args, disk)
    writer = getWriter(args, fileStream, targetFile, targetFilePartial)

    if streamType == "raw":
        logging.info("Creating full provisioned raw backup image")
        writer.truncate(diskSize)
    else:
        logging.info("Creating thin provisioned stream backup image")
        header = stream.dumpMetadata(
            args,
            diskSize,
            thinBackupSize,
            disk,
        )
        # The stream starts with a META frame followed by the metadata
        # header and a terminator.
        stream.writeFrame(writer, sTypes.META, 0, len(header))
        writer.write(header)
        writer.write(sTypes.TERM)

    progressBar = lib.progressBar(
        thinBackupSize, f"saving disk {disk.target}", args, count=count
    )
    # Sizes of the compressed blocks, handed to the compression trailer.
    compressedSizes: List[Any] = []
    for save in extents:
        if save.data is True:
            if streamType == "stream":
                stream.writeFrame(writer, sTypes.DATA, save.offset, save.length)
                logging.debug(
                    "Read data from: start %s, length: %s", save.offset, save.length
                )

            cSizes = None

            # Extents reaching the maximum NBD request size are
            # transferred in chunks.
            if save.length >= connection.maxRequestSize:
                logging.debug(
                    "Chunked data read from: start %s, length: %s",
                    save.offset,
                    save.length,
                )
                size, cSizes = chunk.write(
                    writer,
                    save,
                    connection,
                    streamType,
                    args.compress,
                )
            else:
                size = block.write(
                    writer,
                    save,
                    connection,
                    streamType,
                    args.compress,
                )
                if streamType == "raw":
                    size = writer.seek(save.offset)

            if streamType == "stream":
                writer.write(sTypes.TERM)
                if args.compress:
                    logging.debug("Compressed size: %s", size)
                    if cSizes:
                        # Chunked extents are recorded as mapping of the
                        # total size to the list of chunk sizes.
                        blockList = {}
                        blockList[size] = cSizes
                        compressedSizes.append(blockList)
                    else:
                        compressedSizes.append(size)
                else:
                    # Sanity check: uncompressed data must match the
                    # extent length (not evaluated with python -O).
                    assert size == save.length

            progressBar.update(save.length)
        else:
            if streamType == "raw":
                writer.seek(save.offset)
            elif streamType == "stream" and args.level not in ("inc", "diff"):
                # ZERO frames are only written for full and copy backups.
                stream.writeFrame(writer, sTypes.ZERO, save.offset, save.length)
    if streamType == "stream":
        stream.writeFrame(writer, sTypes.STOP, 0, 0)
        if args.compress:
            stream.writeCompressionTrailer(writer, compressedSizes)

    progressBar.close()
    logging.debug("Closing write handle.")
    writer.close()
    connection.disconnect()
    if args.offline is True:
        _stopOfflineNbd(args, nbdProc, virtClient)

    if args.stdout is False:
        if args.noprogress is True:
            logging.info(
                "Backup of disk [%s] finished, file: [%s]", disk.target, targetFile
            )
        # The data was written to the ".partial" file; rename it to the
        # final file name.
        partialfile.rename(targetFilePartial, targetFile)

    return True


# Script entry point: run the backup only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    main()
