diff --git a/migen/graph/treeviz.py b/migen/graph/treeviz.py index 4e4c7750b..9e5a44fd7 100644 --- a/migen/graph/treeviz.py +++ b/migen/graph/treeviz.py @@ -1,9 +1,11 @@ import cairo import math -def _cairo_draw_node(ctx, radius, color, outer_color, s): +def _cairo_draw_node(ctx, dx, radius, color, outer_color, s): ctx.save() + ctx.translate(dx, 0) + ctx.set_line_width(0.0) gradient_color = cairo.RadialGradient(0, 0, 0, 0, 0, radius) gradient_color.add_color_stop_rgb(0, *color) @@ -40,36 +42,44 @@ class RenderNode: self.radius = radius self.pitch = self.radius*3 - def get_extents(self): + def get_dimensions(self): if self.children: - cw, ch = zip(*[c.get_extents() for c in self.children]) - w = max(cw)*len(self.children) - h = self.pitch + max(ch) + cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children]) + w = sum(cws) + h = self.pitch + max(chs) + dx = cws[0]/4 - cws[-1]/4 else: w = h = self.pitch - return w, h + dx = 0 + return w, h, dx def render(self, ctx): - _cairo_draw_node(ctx, self.radius, self.color, self.outer_color, self.label) if self.children: - cpitch = max([c.get_extents()[0] for c in self.children]) - first_child_x = -(cpitch*(len(self.children) - 1))/2 + cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children]) + first_child_x = -sum(cws)/2 ctx.save() ctx.translate(first_child_x, self.pitch) - for c in self.children: + for c, w in zip(self.children, cws): + ctx.translate(w/2, 0) c.render(ctx) - ctx.translate(cpitch, 0) + ctx.translate(w/2, 0) ctx.restore() + dx = cws[0]/4 - cws[-1]/4 + current_x = first_child_x - for c in self.children: + for c, w, cdx in zip(self.children, cws, cdxs): current_y = self.pitch - c.radius - _cairo_draw_connection(ctx, 0, self.radius, self.outer_color, current_x, current_y, c.outer_color) - current_x += cpitch + current_x += w/2 + _cairo_draw_connection(ctx, dx, self.radius, self.outer_color, current_x+cdx, current_y, c.outer_color) + current_x += w/2 + else: + dx = 0 + _cairo_draw_node(ctx, dx, self.radius, self.color, self.outer_color, self.label) def to_svg(self, name): - w, h = self.get_extents() + w, h, dx = self.get_dimensions() surface = cairo.SVGSurface(name, w, h) ctx = cairo.Context(surface) ctx.translate(w/2, self.pitch/2)