#!/usr/bin/env python3
import copy
import time
from itertools import chain, groupby
from multiprocessing import Manager, Process, cpu_count

import networkx as nx

from osaca.semantics import INSTR_FLAGS, ArchSemantics, MachineModel
from osaca.parser.memory import MemoryOperand
from osaca.parser.register import RegisterOperand
from osaca.parser.immediate import ImmediateOperand
from osaca.parser.flag import FlagOperand


class KernelDG(nx.DiGraph):
    """
    Dependency graph of a parsed assembly kernel.

    On construction the instruction dependency graph is built (stored in ``self.dg``) and the
    loop-carried dependencies are computed (stored in ``self.loopcarried_deps``).
    """

    # threshold for checking dependency graph sequential or in parallel
    INSTRUCTION_THRESHOLD = 50

    def __init__(
        self,
        parsed_kernel,
        parser,
        hw_model: MachineModel,
        semantics: ArchSemantics,
        timeout=10,
        flag_dependencies=False,
    ):
        """
        Build the dependency graph and the loop-carried dependencies of ``parsed_kernel``.

        :param parsed_kernel: Parsed asm kernel with assigned semantic information
        :type parsed_kernel: list
        :param parser: parser object; its ``is_reg_dependend_of`` and ``is_flag_dependend_of``
                       methods are used for the dependency checks
        :param hw_model: machine model queried for ``store_to_load_forward_latency`` and
                         ``p_index_latency``; may be `None`
        :param semantics: architecture semantics queried for register changes; may be `None`
        :param timeout: timeout in seconds handed to :func:`~KernelDG.check_for_loopcarried_dep`,
                        defaults to `10`
        :type timeout: int
        :param flag_dependencies: indicating if dependencies of flags should be considered,
                                  defaults to `False`
        :type flag_dependencies: boolean, optional
        """
        # NOTE(review): nx.DiGraph.__init__ is not invoked; the graph that is actually used by
        # all methods of this class lives in ``self.dg``.
        self.timed_out = False
        self.kernel = parsed_kernel
        self.parser = parser
        self.model = hw_model
        self.arch_sem = semantics
        self.dg = self.create_DG(self.kernel, flag_dependencies)
        self.loopcarried_deps = self.check_for_loopcarried_dep(
            self.kernel, timeout, flag_dependencies
        )

    @classmethod
    def _extend_path(cls, dst_list, kernel, dg, offset):
        """
        Append to ``dst_list`` all simple paths in ``dg`` that lead from an instruction of
        ``kernel`` to the node whose line number is larger by ``offset``.

        Used as the worker target of the parallel search in
        :func:`~KernelDG.check_for_loopcarried_dep`.
        """
        for instr in kernel:
            paths = nx.algorithms.simple_paths.all_simple_paths(
                dg, instr.line_number, instr.line_number + offset
            )
            # materialize first so that the destination list is extended in a single call
            dst_list.extend(list(paths))

    def create_DG(self, kernel, flag_dependencies=False):
        """
        Create directed graph from given kernel

        :param kernel: Parsed asm kernel with assigned semantic information
        :type kernel: list
        :param flag_dependencies: indicating if dependencies of flags should be considered,
                                  defaults to `False`
        :type flag_dependencies: boolean, optional
        :returns: :class:`~nx.DiGraph` -- directed graph object
        """
        # 1. go through kernel instruction forms and add them as node attribute
        # 2. find edges (to dependent further instruction)
        # 3. get LT value and set as edge weight
        dg = nx.DiGraph()
        for i, instruction_form in enumerate(kernel):
            line_no = instruction_form.line_number
            dg.add_node(line_no)
            dg.nodes[line_no]["instruction_form"] = instruction_form
            # add load as separate node if existent
            if (
                INSTR_FLAGS.HAS_LD in instruction_form.flags
                and INSTR_FLAGS.LD not in instruction_form.flags
            ):
                # the load part gets its own node with a fractional line number
                load_node = line_no + 0.1
                dg.add_node(load_node)
                dg.nodes[load_node]["instruction_form"] = instruction_form
                # and set LD latency as edge weight
                dg.add_edge(
                    load_node,
                    line_no,
                    latency=instruction_form.latency - instruction_form.latency_wo_load,
                )
            for dep, dep_flags in self.find_depending(
                instruction_form, kernel[i + 1 :], flag_dependencies
            ):
                edge_weight = (
                    instruction_form.latency
                    if "mem_dep" in dep_flags or instruction_form.latency_wo_load is None
                    else instruction_form.latency_wo_load
                )
                if "storeload_dep" in dep_flags and self.model is not None:
                    edge_weight += self.model.get("store_to_load_forward_latency", 0)
                if "p_indexed" in dep_flags and self.model is not None:
                    # pre-/post-indexed dependencies replace the weight instead of adding to it
                    edge_weight = self.model.get("p_index_latency", 1)
                dg.add_edge(
                    line_no,
                    dep.line_number,
                    latency=edge_weight,
                )
                dg.nodes[dep.line_number]["instruction_form"] = dep
        return dg

    def check_for_loopcarried_dep(self, kernel, timeout=10, flag_dependencies=False):
        """
        Try to find loop-carried dependencies in given kernel.

        The kernel is duplicated with shifted line numbers and every simple path from an
        instruction to its own copy in the second iteration is collected. Kernels with at least
        ``INSTRUCTION_THRESHOLD`` instructions are searched in parallel worker processes,
        smaller ones sequentially.

        :param kernel: Parsed asm kernel with assigned semantic information
        :type kernel: list
        :param timeout: Timeout in seconds for parallel execution, defaults to `10`.
                        Set to `-1` for no timeout. The sequential search ignores the timeout.
                        If the timeout is exceeded, ``self.timed_out`` is set to `True` and
                        still running worker processes are killed.
        :type timeout: int
        :param flag_dependencies: indicating if dependencies of flags should be considered,
                                  defaults to `False`
        :type flag_dependencies: boolean, optional
        :returns: `dict` -- dependency dictionary with all cyclic LCDs
        """
        # increase line number for second kernel loop
        # (``default=0`` keeps an empty kernel from raising in ``max``)
        offset = max(1000, max((i.line_number for i in kernel), default=0))
        tmp_kernel = list(kernel)
        for orig_iform in kernel:
            temp_iform = copy.copy(orig_iform)
            temp_iform.line_number += offset
            tmp_kernel.append(temp_iform)
        # get dependency graph
        dg = self.create_DG(tmp_kernel, flag_dependencies)

        # build cyclic loop-carried dependencies
        loopcarried_deps = []
        all_paths = []

        klen = len(kernel)
        if klen >= self.INSTRUCTION_THRESHOLD:
            # parallel execution with static scheduling
            num_cores = cpu_count()
            workload = int((klen - 1) / num_cores) + 1
            starts = [tid * workload for tid in range(num_cores)]
            ends = [min((tid + 1) * workload, klen) for tid in range(num_cores)]
            instrs = [kernel[s:e] for s, e in zip(starts, ends)]
            with Manager() as manager:
                all_paths = manager.list()
                processes = [
                    Process(
                        target=KernelDG._extend_path,
                        args=(all_paths, instr_section, dg, offset),
                    )
                    for instr_section in instrs
                ]
                for p in processes:
                    p.start()
                if timeout == -1:
                    # no timeout
                    for p in processes:
                        p.join()
                else:
                    start_time = time.time()
                    while time.time() - start_time <= timeout:
                        if any(p.is_alive() for p in processes):
                            time.sleep(0.2)
                        else:
                            # all procs done
                            for p in processes:
                                p.join()
                            break
                    else:
                        # loop ran out without ``break`` -> timeout exceeded
                        self.timed_out = True
                        # terminate running processes
                        for p in processes:
                            if p.is_alive():
                                p.kill()
                            p.join()
                # copy out of the proxy while the manager is still alive
                all_paths = list(all_paths)
        else:
            # sequential execution to avoid overhead when analyzing smaller kernels
            for instr in kernel:
                all_paths.extend(
                    nx.algorithms.simple_paths.all_simple_paths(
                        dg, instr.line_number, instr.line_number + offset
                    )
                )

        paths_set = set()
        for path in all_paths:
            lat_sum = 0.0
            # extend path by edge bound latencies (e.g., store-to-load latency)
            lat_path = []
            for s, d in nx.utils.pairwise(path):
                edge_lat = dg.edges[s, d]["latency"]
                # map source node back to original line numbers
                if s > offset:
                    s -= offset
                lat_path.append((s, edge_lat))
                lat_sum += edge_lat
            lat_path.sort()

            # Ignore duplicate paths which differ only in the root node
            path_key = tuple(lat_path)
            if path_key in paths_set:
                continue
            paths_set.add(path_key)

            loopcarried_deps.append((lat_sum, lat_path))
        loopcarried_deps.sort(reverse=True)

        # map lcd back to nodes
        loopcarried_deps_dict = {}
        for lat_sum, involved_lines in loopcarried_deps:
            dict_key = "-".join([str(il[0]) for il in involved_lines])
            loopcarried_deps_dict[dict_key] = {
                "root": self._get_node_by_lineno(involved_lines[0][0]),
                "dependencies": [
                    (self._get_node_by_lineno(ln), lat) for ln, lat in involved_lines
                ],
                "latency": lat_sum,
            }
        return loopcarried_deps_dict

    def _get_node_by_lineno(self, lineno, kernel=None, all=False):
        """
        Return instruction form with line number ``lineno`` from kernel

        :param lineno: line number to look up
        :param kernel: kernel to search, defaults to ``self.kernel``
        :param all: if `True`, return the list of all matches instead of the first match
        :raises IndexError: if ``all`` is `False` and no instruction form matches
        """
        if kernel is None:
            kernel = self.kernel
        result = [instr for instr in kernel if instr.line_number == lineno]
        if not all:
            return result[0]
        return result

    def get_critical_path(self):
        """
        Find and return critical path after the creation of a directed graph.

        As a side effect, ``latency_cp`` is set on the instruction forms along the path.

        :returns: list of instruction forms on the critical path; if a single instruction has a
                  higher latency than the summed edge latencies of the longest path, only that
                  instruction is returned
        :raises NotImplementedError: if the dependency graph is cyclic
        """
        max_latency_instr = max(self.kernel, key=lambda k: k.latency)
        if nx.algorithms.dag.is_directed_acyclic_graph(self.dg):
            longest_path = nx.algorithms.dag.dag_longest_path(self.dg, weight="latency")
            # TODO verify that we can remove the next two lines due to earlier initialization
            for line_number in longest_path:
                self._get_node_by_lineno(int(line_number)).latency_cp = 0
            # set cp latency to instruction
            path_latency = 0.0
            for s, d in nx.utils.pairwise(longest_path):
                node = self._get_node_by_lineno(int(s))
                node.latency_cp = self.dg.edges[(s, d)]["latency"]
                path_latency += node.latency_cp
            # add latency for last instruction
            node = self._get_node_by_lineno(int(longest_path[-1]))
            node.latency_cp = node.latency
            if max_latency_instr.latency > path_latency:
                max_latency_instr.latency_cp = float(max_latency_instr.latency)
                return [max_latency_instr]
            path_nodes = set(longest_path)
            return [x for x in self.kernel if x.line_number in path_nodes]
        else:
            # split to DAG
            raise NotImplementedError("Kernel is cyclic.")

    def get_loopcarried_dependencies(self):
        """
        Return all LCDs from kernel (after :func:`~KernelDG.check_for_loopcarried_dep` was run)

        :raises NotImplementedError: if the dependency graph is cyclic
        """
        if nx.algorithms.dag.is_directed_acyclic_graph(self.dg):
            return self.loopcarried_deps
        else:
            # split to DAG
            raise NotImplementedError("Kernel is cyclic.")

    def find_depending(self, instruction_form, instructions, flag_dependencies=False):
        """
        Find instructions in `instructions` depending on a given instruction form's results.

        :param instruction_form: instruction form to check for dependencies
        :param list instructions: instructions to check
        :param flag_dependencies: indicating if dependencies of flags should be considered,
                                  defaults to `False`
        :type flag_dependencies: boolean, optional
        :returns: iterator of all directly dependent instruction forms and according flags
        """
        if instruction_form.semantic_operands is None:
            return
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            # TODO instructions before must be considered as well, if they update registers
            # not used by instruction_form. E.g., validation/build/A64FX/gcc/O1/gs-2d-5pt.marked.s
            register_changes = self._update_reg_changes(instruction_form)
            for instr_form in instructions:
                self._update_reg_changes(instr_form, register_changes)
                if isinstance(dst, RegisterOperand):
                    # read of register
                    if self.is_read(dst, instr_form):
                        if (
                            dst.pre_indexed
                            or dst.post_indexed
                            or (isinstance(dst.post_indexed, dict))
                        ):
                            yield instr_form, ["p_indexed"]
                        else:
                            yield instr_form, []
                    # write to register -> abort
                    if self.is_written(dst, instr_form):
                        break
                if isinstance(dst, FlagOperand) and flag_dependencies:
                    # read of flag
                    if self.is_read(dst, instr_form):
                        yield instr_form, []
                    # write to flag -> abort
                    if self.is_written(dst, instr_form):
                        break
                if isinstance(dst, MemoryOperand):
                    # base register is altered during memory access
                    if dst.pre_indexed:
                        if self.is_written(dst.base, instr_form):
                            break
                    if dst.post_indexed:
                        # Check for read of base register until overwrite
                        if self.is_written(dst.base, instr_form):
                            break
                    # TODO record register changes
                    #      (e.g., mov, leaadd, sub, inc, dec) in instructions[:i]
                    #      and pass to is_memload and is_memstore to consider relevance.
                    # load from same location (presumed)
                    if self.is_memload(dst, instr_form, register_changes):
                        yield instr_form, ["storeload_dep"]
                    # store to same location (presumed)
                    if self.is_memstore(dst, instr_form, register_changes):
                        break
                self._update_reg_changes(instr_form, register_changes, only_postindexed=True)

    def _update_reg_changes(self, iform, reg_state=None, only_postindexed=False):
        """
        Fold the register changes of ``iform`` into ``reg_state`` and return it.

        ``reg_state`` maps a register to ``{"name": ..., "value": ...}`` or to `None` if its
        change cannot be reconstructed. The passed dictionary is updated in place.

        :param iform: instruction form whose register changes are applied
        :param reg_state: state to update, a new dictionary is created if `None`
        :param only_postindexed: forwarded to ``ArchSemantics.get_reg_changes``
        :returns: the updated state; an empty `dict` if no semantics are available
        """
        if self.arch_sem is None:
            # This analysis requires semantics to be available
            return {}
        if reg_state is None:
            reg_state = {}
        for reg, change in self.arch_sem.get_reg_changes(iform, only_postindexed).items():
            if change is None or reg_state.get(reg, {}) is None:
                # an unknown change poisons the register for the rest of the analysis
                reg_state[reg] = None
            else:
                reg_state.setdefault(reg, {"name": reg, "value": 0})
                if change["name"] != reg:
                    # renaming occurred, overwrite value with up-to-now change of source register
                    reg_state[reg]["name"] = change["name"]
                    src_reg_state = reg_state.get(change["name"], {"value": 0})
                    if src_reg_state is None:
                        # original register's state was changed beyond reconstruction
                        reg_state[reg] = None
                        continue
                    reg_state[reg]["value"] = src_reg_state["value"]
                reg_state[reg]["value"] += change["value"]
        return reg_state

    def get_dependent_instruction_forms(self, instr_form=None, line_number=None):
        """
        Return an iterator over the successors of an instruction in the dependency graph.

        :param instr_form: instruction form to look up, used if no ``line_number`` is given
        :param line_number: line number of the instruction to look up
        :returns: iterator over the successor nodes, empty if the node is not in the graph
        :raises ValueError: if neither ``instr_form`` nor ``line_number`` is given
        """
        if not instr_form and not line_number:
            raise ValueError("Either instruction form or line_number required.")
        if not line_number:
            # instruction forms expose ``line_number`` as attribute everywhere else in this
            # class; subscripting is kept as fallback for mapping-style instruction forms
            try:
                line_number = instr_form.line_number
            except AttributeError:
                line_number = instr_form["line_number"]
        if self.dg.has_node(line_number):
            return self.dg.successors(line_number)
        return iter([])

    def is_read(self, register, instruction_form):
        """Check if instruction form reads from given register"""
        is_read = False
        if instruction_form.semantic_operands is None:
            return is_read
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(src, RegisterOperand):
                is_read = self.parser.is_reg_dependend_of(register, src) or is_read
            if isinstance(src, FlagOperand):
                is_read = self.parser.is_flag_dependend_of(register, src) or is_read
            if isinstance(src, MemoryOperand):
                if src.base is not None:
                    is_read = self.parser.is_reg_dependend_of(register, src.base) or is_read
                if src.index is not None and isinstance(src.index, RegisterOperand):
                    is_read = self.parser.is_reg_dependend_of(register, src.index) or is_read
        # Check also if read in destination memory address
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, MemoryOperand):
                if dst.base is not None:
                    is_read = self.parser.is_reg_dependend_of(register, dst.base) or is_read
                if dst.index is not None:
                    is_read = self.parser.is_reg_dependend_of(register, dst.index) or is_read
        return is_read

    @staticmethod
    def _prefixed_reg_name(reg):
        """Return the register name with its prefix (if any) prepended."""
        return (reg.prefix if reg.prefix is not None else "") + reg.name

    def is_memload(self, mem, instruction_form, register_changes=None):
        """
        Check if instruction form loads from given location, assuming register_changes

        :param mem: memory operand describing the location
        :param instruction_form: instruction form whose memory sources are checked
        :param register_changes: register state as built by ``_update_reg_changes``,
                                 defaults to no known changes
        :type register_changes: dict, optional
        :returns: `True` if a memory source resolves to the same address as ``mem``
        """
        if register_changes is None:
            register_changes = {}
        if instruction_form.semantic_operands is None:
            return False
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            # Here we check for mem dependencies only
            if not isinstance(src, MemoryOperand):
                continue

            # determine absolute address change
            addr_change = 0
            if isinstance(src.offset, ImmediateOperand) and src.offset.value is not None:
                addr_change += src.offset.value
            if isinstance(mem.offset, ImmediateOperand) and mem.offset.value is not None:
                addr_change -= mem.offset.value
            if mem.base and src.base:
                src_base_name = self._prefixed_reg_name(src.base)
                base_change = register_changes.get(
                    src_base_name, {"name": src_base_name, "value": 0}
                )
                if base_change is None:
                    # Unknown change occurred
                    continue
                # The full prefixed name has to be compared. Without the helper, the
                # conditional expression binds weaker than ``+`` and ``!=`` and every
                # register with a non-empty prefix would be skipped here.
                if self._prefixed_reg_name(mem.base) != base_change["name"]:
                    # base registers do not match
                    continue
                addr_change += base_change["value"]
            elif mem.base or src.base:
                # base registers do not match
                continue
            if mem.index and src.index:
                src_index_name = self._prefixed_reg_name(src.index)
                index_change = register_changes.get(
                    src_index_name, {"name": src_index_name, "value": 0}
                )
                if index_change is None:
                    # Unknown change occurred
                    continue
                if mem.scale != src.scale:
                    # scale factors do not match
                    continue
                if self._prefixed_reg_name(mem.index) != index_change["name"]:
                    # index registers do not match
                    continue
                addr_change += index_change["value"] * src.scale
            elif mem.index or src.index:
                # index registers do not match
                continue
            if addr_change == 0:
                return True
        return False

    def is_written(self, register, instruction_form):
        """Check if instruction form writes in given register"""
        is_written = False
        if instruction_form.semantic_operands is None:
            return is_written
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, RegisterOperand):
                is_written = self.parser.is_reg_dependend_of(register, dst) or is_written
            if isinstance(dst, FlagOperand):
                is_written = self.parser.is_flag_dependend_of(register, dst) or is_written
            if isinstance(dst, MemoryOperand):
                if dst.pre_indexed or dst.post_indexed:
                    is_written = self.parser.is_reg_dependend_of(register, dst.base) or is_written
        # Check also for possible pre- or post-indexing in memory addresses
        for src in chain(
            instruction_form.semantic_operands["source"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(src, MemoryOperand):
                if src.pre_indexed or src.post_indexed:
                    is_written = self.parser.is_reg_dependend_of(register, src.base) or is_written
        return is_written

    def is_memstore(self, mem, instruction_form, register_changes=None):
        """
        Check if instruction form stores to given location, assuming unchanged registers

        ``register_changes`` is accepted for symmetry with :func:`~KernelDG.is_memload` but is
        not evaluated; a store is detected only if a destination memory operand equals ``mem``.
        """
        is_store = False
        if instruction_form.semantic_operands is None:
            return is_store
        for dst in chain(
            instruction_form.semantic_operands["destination"],
            instruction_form.semantic_operands["src_dst"],
        ):
            if isinstance(dst, MemoryOperand):
                is_store = mem == dst or is_store
        return is_store

    def export_graph(self, filepath=None):
        """
        Export graph with highlighted CP and LCDs as DOT file. Writes it to 'osaca_dg.dot'
        if no other path is given.

        :param filepath: path to write DOT file, defaults to None.
        :type filepath: str, optional
        """
        graph = copy.deepcopy(self.dg)
        cp = self.get_critical_path()
        cp_line_numbers = [x.line_number for x in cp]
        lcd = self.get_loopcarried_dependencies()
        lcd_line_numbers = {}
        for dep in lcd:
            lcd_line_numbers[dep] = [x.line_number for x, lat in lcd[dep]["dependencies"]]
        # create LCD edges
        for dep in lcd_line_numbers:
            min_line_number = min(lcd_line_numbers[dep])
            max_line_number = max(lcd_line_numbers[dep])
            graph.add_edge(min_line_number, max_line_number, dir="back")
            graph.edges[min_line_number, max_line_number]["latency"] = [
                lat for x, lat in lcd[dep]["dependencies"] if x.line_number == max_line_number
            ]

        # add label to edges
        for e in graph.edges:
            graph.edges[e]["label"] = graph.edges[e]["latency"]

        # add CP values to graph
        for n in cp:
            graph.nodes[n.line_number]["instruction_form"].latency_cp = n.latency_cp

        # Make the critical path bold.
        for n in graph.nodes:
            if n in cp_line_numbers:
                graph.nodes[n]["style"] = "bold"
                graph.nodes[n]["penwidth"] = 4

        # Make critical path edges bold.
        for e in graph.edges:
            if (
                graph.nodes[e[0]]["instruction_form"].line_number in cp_line_numbers
                and graph.nodes[e[1]]["instruction_form"].line_number in cp_line_numbers
                and e[0] < e[1]
            ):
                # only edges between two directly consecutive CP nodes are highlighted
                bold_edge = True
                for i in range(e[0] + 1, e[1]):
                    if i in cp_line_numbers:
                        bold_edge = False
                if bold_edge:
                    graph.edges[e]["style"] = "bold"
                    graph.edges[e]["penwidth"] = 3

        # Color the cycles created by loop-carried dependencies, longest first, never recoloring
        # any node or edge, so that the longest LCD and most long chains that are involved in the
        # loop are legible.
        lcd_by_latencies = sorted(
            (
                (latency, list(deps))
                for latency, deps in groupby(lcd, lambda dep: lcd[dep]["latency"])
            ),
            reverse=True,
        )
        node_colors = {}
        edge_colors = {}
        colors_used = 0
        for latency, deps in lcd_by_latencies:
            color = None
            for dep in deps:
                path = lcd_line_numbers[dep]
                for n in path:
                    if n not in node_colors:
                        if not color:
                            color = colors_used + 1
                            colors_used += 1
                        node_colors[n] = color
                for u, v in zip(path, path[1:] + [path[0]]):
                    if (u, v) not in edge_colors:
                        # Don't introduce a color just for an edge.
                        if not color:
                            color = colors_used
                        edge_colors[u, v] = color
        max_color = min(11, colors_used)
        colorscheme = f"spectral{max(3, max_color)}"
        graph.graph["node"] = {"colorscheme": colorscheme}
        graph.graph["edge"] = {"colorscheme": colorscheme}
        for n, color in node_colors.items():
            if "style" not in graph.nodes[n]:
                graph.nodes[n]["style"] = "filled"
            else:
                graph.nodes[n]["style"] += ",filled"
            graph.nodes[n]["fillcolor"] = color
            if (max_color >= 4 and color in (1, max_color)) or (
                max_color >= 10 and color in (1, 2, max_color - 1, max_color)
            ):
                graph.nodes[n]["fontcolor"] = "white"
        for (u, v), color in edge_colors.items():
            # The backward edge of the cycle is represented as the corresponding forward
            # edge with the attribute dir=back.
            edge = graph.edges[u, v] if (u, v) in graph.edges else graph.edges[v, u]
            edge["color"] = color

        # rename node from [idx] to [idx mnemonic] and add shape
        mapping = {}
        for n in graph.nodes:
            if int(n) != n:
                # fractional line numbers mark the separate load nodes
                mapping[n] = "{}: LOAD".format(int(n))
                graph.nodes[n]["fontname"] = "italic"
                graph.nodes[n]["fontsize"] = 11.0
            else:
                node = graph.nodes[n]["instruction_form"]
                if node.mnemonic is not None:
                    mapping[n] = "{}: {}".format(n, node.mnemonic)
                else:
                    label = "label" if node.label is not None else None
                    label = "directive" if node.directive is not None else label
                    label = "comment" if node.comment is not None and label is None else label
                    mapping[n] = "{}: {}".format(n, label)
                    graph.nodes[n]["fontname"] = "italic"
                    graph.nodes[n]["fontsize"] = 11.0
                graph.nodes[n]["shape"] = "rectangle"

        nx.relabel.relabel_nodes(graph, mapping, copy=False)
        if filepath:
            nx.drawing.nx_agraph.write_dot(graph, filepath)
        else:
            nx.drawing.nx_agraph.write_dot(graph, "osaca_dg.dot")