treeviz: improve layout of unbalanced trees

This commit is contained in:
Sebastien Bourdeauducq 2013-08-07 18:32:02 +02:00
parent 7a243171bd
commit ceddd8afa4
1 changed files with 25 additions and 15 deletions

View File

@ -1,9 +1,11 @@
import cairo
import math
def _cairo_draw_node(ctx, radius, color, outer_color, s):
def _cairo_draw_node(ctx, dx, radius, color, outer_color, s):
ctx.save()
ctx.translate(dx, 0)
ctx.set_line_width(0.0)
gradient_color = cairo.RadialGradient(0, 0, 0, 0, 0, radius)
gradient_color.add_color_stop_rgb(0, *color)
@ -40,36 +42,44 @@ class RenderNode:
self.radius = radius
self.pitch = self.radius*3
def get_extents(self):
def get_dimensions(self):
if self.children:
cw, ch = zip(*[c.get_extents() for c in self.children])
w = max(cw)*len(self.children)
h = self.pitch + max(ch)
cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children])
w = sum(cws)
h = self.pitch + max(chs)
dx = cws[0]/4 - cws[-1]/4
else:
w = h = self.pitch
return w, h
dx = 0
return w, h, dx
def render(self, ctx):
_cairo_draw_node(ctx, self.radius, self.color, self.outer_color, self.label)
if self.children:
cpitch = max([c.get_extents()[0] for c in self.children])
first_child_x = -(cpitch*(len(self.children) - 1))/2
cws, chs, cdxs = zip(*[c.get_dimensions() for c in self.children])
first_child_x = -sum(cws)/2
ctx.save()
ctx.translate(first_child_x, self.pitch)
for c in self.children:
for c, w in zip(self.children, cws):
ctx.translate(w/2, 0)
c.render(ctx)
ctx.translate(cpitch, 0)
ctx.translate(w/2, 0)
ctx.restore()
dx = cws[0]/4 - cws[-1]/4
current_x = first_child_x
for c in self.children:
for c, w, cdx in zip(self.children, cws, cdxs):
current_y = self.pitch - c.radius
_cairo_draw_connection(ctx, 0, self.radius, self.outer_color, current_x, current_y, c.outer_color)
current_x += cpitch
current_x += w/2
_cairo_draw_connection(ctx, dx, self.radius, self.outer_color, current_x+cdx, current_y, c.outer_color)
current_x += w/2
else:
dx = 0
_cairo_draw_node(ctx, dx, self.radius, self.color, self.outer_color, self.label)
def to_svg(self, name):
w, h = self.get_extents()
w, h, dx = self.get_dimensions()
surface = cairo.SVGSurface(name, w, h)
ctx = cairo.Context(surface)
ctx.translate(w/2, self.pitch/2)