#!/usr/bin/env python
# blockifyasm ----- Split disassembly into basic blocks ---------*- python -*-
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#
# ----------------------------------------------------------------------------
#
# Splits a disassembled function from lldb into basic blocks.
#
# Useful to show the control flow graph of a disassembled function.
# The control flow graph can then be viewed with the viewcfg utility:
#
# (lldb) disassemble
#    <copy-paste output to file.s>
# $ blockifyasm < file.s | viewcfg
#
# ----------------------------------------------------------------------------

from __future__ import print_function

import re
import sys
from collections import defaultdict


def help():
    """Print the command line usage summary to stdout."""
    usage_lines = (
        "Usage:",
        "",
        "blockifyasm [-<n>] < file",
        "",
        "-<n>: only match <n> significant digits of relative branch addresses",
        "",
    )
    # The trailing empty entry plus print's own newline leaves one blank
    # line after the text.
    print("\n".join(usage_lines))


def main():
    """Read lldb disassembly from stdin and print it split into basic blocks.

    Consecutive instruction lines are collected into one function. When a
    non-instruction line (or the end of input) is reached, the collected
    function is printed between "{" and "}" with a "bb<N>:" label in front
    of every basic block, the list of predecessor blocks where there are
    any, and with relative branch comments ("; <+NN>") replaced by the
    label of the target block. Non-instruction lines are echoed unchanged.

    An optional "-<n>" command line argument limits address comparison to
    the last <n> hex digits. Any other argument prints the usage text and
    returns without reading stdin.
    """
    addr_len = 16
    if len(sys.argv) >= 2:
        m = re.match(r"^-([0-9]+)$", sys.argv[1])
        if m:
            addr_len = int(m.group(1))
        else:
            help()
            return

    # A branch whose target is annotated with a relative offset: "; <+NN>".
    branch_re1 = re.compile(r"^\s[-\s>]*0x.*:\s.* 0x([0-9a-f]+)\s*;\s*<[+-]")
    # A "tb..." instruction (e.g. tbz/tbnz), with or without a comment.
    branch_re2 = re.compile(r"^\s[-\s>]*0x.*:\s+tb.* 0x([0-9a-f]+)\s*(;.*)?")
    # Any instruction line: captures the address and the mnemonic.
    inst_re = re.compile(r"^\s[-\s>]*0x([0-9a-f]+)[\s<>0-9+-]*:\s+([a-z0-9.]+)\s")
    # Instructions after which control never continues to the next one.
    non_fall_through_insts = frozenset(["b", "ret", "brk", "jmp", "retq", "ud2"])

    def get_branch_addr(line):
        """Return the full hex branch target of `line`, or None."""
        for branch_re in (branch_re1, branch_re2):
            bm = branch_re.match(line)
            if bm:
                return bm.group(1)
        return None

    def significant(addr):
        """Return the last `addr_len` digits of a hex address string."""
        return addr[-addr_len:]

    def print_blocks(insts):
        """Print one function given as (line, address, mnemonic, target)."""
        if not insts:
            return

        # Truncate the branch targets only now: addr_len may have shrunk
        # after some of them were read, and keys of mixed length would never
        # match the instruction addresses below.
        # -1 marks a target that is not (yet) known to be in this function.
        block_starts = {}
        for _, _, _, target in insts:
            if target is not None:
                block_starts[significant(target)] = -1

        predecessors = defaultdict(list)
        block_num = -1
        next_is_block = True
        prev_is_fallthrough = False

        # Number the blocks and collect the predecessors of each block.
        for _, full_addr, mnemonic, target in insts:
            addr = significant(full_addr)
            # A block starts after every branch and at every branch target.
            if next_is_block or addr in block_starts:
                if prev_is_fallthrough:
                    predecessors[addr].append(block_num)
                block_num += 1
                block_starts[addr] = block_num
                next_is_block = False

            if target is not None:
                next_is_block = True
                predecessors[significant(target)].append(block_num)

            prev_is_fallthrough = mnemonic not in non_fall_through_insts

        # Print the function with basic block labels.
        print("{")
        for line, full_addr, _, target in insts:
            addr = significant(full_addr)
            if addr in block_starts:
                label = "bb" + str(block_starts[addr]) + ":"
                preds = predecessors.get(addr)
                if preds:
                    pred_list = ", ".join("bb" + str(pred) for pred in preds)
                    print(label.ljust(55) + "; preds = " + pred_list)
                else:
                    print(label)

            if target is not None:
                target_block = block_starts[significant(target)]
                # Targets outside of this function keep their comment.
                if target_block >= 0:
                    line = re.sub(r";\s<[+-].*", "; bb" + str(target_block), line)

            print(line, end="")
        print("}")

    # Read disassembly code from stdin
    insts = []
    for line in sys.stdin:
        # let the line with the instruction pointer begin with a space
        line = re.sub("^-> ", " ->", line)

        m = inst_re.match(line)
        if m:
            target = get_branch_addr(line)
            # Never compare more digits than the shortest branch target has.
            if target is not None and len(target) < addr_len:
                addr_len = len(target)
            insts.append((line, m.group(1), m.group(2), target))
        else:
            print_blocks(insts)
            insts = []
            print(line, end="")

    print_blocks(insts)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
