soc/interconnect/stream: use new Converter/StrideConverter

This commit is contained in:
Florent Kermarrec 2016-03-16 17:00:58 +01:00
parent 8c272c1f6f
commit c860581b86

View file

@ -29,11 +29,11 @@ class EndpointDescription:
attributed.add(f[0])
full_layout = [
("payload", _make_m2s(self.payload_layout)),
("param", _make_m2s(self.param_layout)),
("stb", 1, DIR_M_TO_S),
("ack", 1, DIR_S_TO_M),
("eop", 1, DIR_M_TO_S)
("eop", 1, DIR_M_TO_S),
("payload", _make_m2s(self.payload_layout)),
("param", _make_m2s(self.param_layout))
]
return full_layout
@ -62,9 +62,8 @@ class Sink(Endpoint):
class _FIFOWrapper(Module):
def __init__(self, fifo_class, layout, depth):
self.sink = Sink(layout)
self.source = Source(layout)
self.busy = Signal()
self.sink = Endpoint(layout)
self.source = Endpoint(layout)
# # #
@ -112,10 +111,10 @@ class AsyncFIFO(_FIFOWrapper):
class Multiplexer(Module):
def __init__(self, layout, n):
self.source = Source(layout)
self.source = Endpoint(layout)
sinks = []
for i in range(n):
sink = Sink(layout)
sink = Endpoint(layout)
setattr(self, "sink"+str(i), sink)
sinks.append(sink)
self.sel = Signal(max=n)
@ -130,10 +129,10 @@ class Multiplexer(Module):
class Demultiplexer(Module):
def __init__(self, layout, n):
self.sink = Sink(layout)
self.sink = Endpoint(layout)
sources = []
for i in range(n):
source = Source(layout)
source = Endpoint(layout)
setattr(self, "source"+str(i), source)
sources.append(source)
self.sel = Signal(max=n)
@ -145,6 +144,206 @@ class Demultiplexer(Module):
cases[i] = self.sink.connect(source)
self.comb += Case(self.sel, cases)
class _UpConverter(Module):
def __init__(self, nbits_from, nbits_to, ratio, reverse):
self.sink = sink = Endpoint([("data", nbits_from)])
self.source = source = Endpoint([("data", nbits_to),
("valid_token_count", bits_for(ratio))])
self.latency = 1
# # #
# control path
demux = Signal(max=ratio)
load_part = Signal()
strobe_all = Signal()
self.comb += [
sink.ack.eq(~strobe_all | source.ack),
source.stb.eq(strobe_all),
load_part.eq(sink.stb & sink.ack)
]
demux_last = ((demux == (ratio - 1)) | sink.eop)
self.sync += [
If(source.ack, strobe_all.eq(0)),
If(load_part,
If(demux_last,
demux.eq(0),
strobe_all.eq(1)
).Else(
demux.eq(demux + 1)
)
),
If(source.stb & source.ack,
source.eop.eq(sink.eop),
).Elif(sink.stb & sink.ack,
source.eop.eq(sink.eop | source.eop)
)
]
# data path
cases = {}
for i in range(ratio):
n = ratio-i-1 if reverse else i
cases[i] = source.data[n*nbits_from:(n+1)*nbits_from].eq(sink.data)
self.sync += If(load_part, Case(demux, cases))
# valid token count
self.sync += If(load_part, source.valid_token_count.eq(demux + 1))
class _DownConverter(Module):
def __init__(self, nbits_from, nbits_to, ratio, reverse):
self.sink = sink = Endpoint([("data", nbits_from)])
self.source = source = Endpoint([("data", nbits_to),
("valid_token_count", 1)])
self.latency = 0
# # #
# control path
mux = Signal(max=ratio)
last = Signal()
self.comb += [
last.eq(mux == (ratio-1)),
source.stb.eq(sink.stb),
source.eop.eq(sink.eop & last),
sink.ack.eq(last & source.ack)
]
self.sync += \
If(source.stb & source.ack,
If(last,
mux.eq(0)
).Else(
mux.eq(mux + 1)
)
)
# data path
cases = {}
for i in range(ratio):
n = ratio-i-1 if reverse else i
cases[i] = source.data.eq(sink.data[n*nbits_to:(n+1)*nbits_to])
self.comb += Case(mux, cases).makedefault()
# valid token count
self.comb += source.valid_token_count.eq(last)
class _IdentityConverter(Module):
def __init__(self, nbits_from, nbits_to, ratio, reverse):
self.sink = sink = Endpoint([("data", nbits_from)])
self.source = source = Endpoint([("data", nbits_to),
("valid_token_count", 1)])
self.latency = 0
# # #
self.comb += [
sink.connect(source),
source.valid_token_count.eq(1)
]
def _get_converter_ratio(nbits_from, nbits_to):
if nbits_from > nbits_to:
converter_cls = _DownConverter
if nbits_from % nbits_to:
raise ValueError("Ratio must be an int")
ratio = nbits_from//nbits_to
elif nbits_from < nbits_to:
converter_cls = _UpConverter
if nbits_to % nbits_from:
raise ValueError("Ratio must be an int")
ratio = nbits_to//nbits_from
else:
converter_cls = _IdentityConverter
ratio = 1
return converter_cls, ratio
class Converter(Module):
def __init__(self, nbits_from, nbits_to, reverse=False,
report_valid_token_count=False):
self.cls, self.ratio = _get_converter_ratio(nbits_from, nbits_to)
# # #
converter = self.cls(nbits_from, nbits_to, self.ratio, reverse)
self.submodules += converter
self.latency = converter.latency
self.sink = converter.sink
if report_valid_token_count:
self.source = converter.source
else:
self.source = Endpoint([("data", nbits_to)])
self.comb += converter.source.connect(self.source,
leave_out=set(["valid_token_count"]))
class StrideConverter(Module):
def __init__(self, description_from, description_to, reverse=False):
self.sink = sink = Endpoint(description_from)
self.source = source = Endpoint(description_to)
# # #
nbits_from = len(sink.payload.raw_bits())
nbits_to = len(source.payload.raw_bits())
converter = Converter(nbits_from, nbits_to, reverse)
self.submodules += converter
# cast sink to converter.sink (user fields --> raw bits)
self.comb += [
converter.sink.stb.eq(sink.stb),
converter.sink.eop.eq(sink.eop),
sink.ack.eq(converter.sink.ack)
]
if converter.cls == _DownConverter:
ratio = converter.ratio
for i in range(ratio):
j = 0
for name, width in source.description.payload_layout:
src = getattr(sink, name)[i*width:(i+1)*width]
dst = converter.sink.data[i*nbits_to+j:i*nbits_to+j+width]
self.comb += dst.eq(src)
j += width
else:
self.comb += converter.sink.data.eq(sink.payload.raw_bits())
# cast converter.source to source (raw bits --> user fields)
self.comb += [
source.stb.eq(converter.source.stb),
source.eop.eq(converter.source.eop),
converter.source.ack.eq(source.ack)
]
if converter.cls == _UpConverter:
ratio = converter.ratio
for i in range(ratio):
j = 0
for name, width in sink.description.payload_layout:
src = converter.source.data[i*nbits_from+j:i*nbits_from+j+width]
dst = getattr(source, name)[i*width:(i+1)*width]
self.comb += dst.eq(src)
j += width
else:
self.comb += source.payload.raw_bits().eq(converter.source.data)
# connect params
if converter.latency == 0:
self.comb += source.param.eq(sink.param)
elif converter.latency == 1:
self.sync += source.param.eq(sink.param)
else:
raise ValueError
# TODO: clean up code below
# XXX
@ -351,99 +550,6 @@ class Pack(Module):
]
class Chunkerize(CombinatorialActor):
def __init__(self, layout_from, layout_to, n, reverse=False):
self.sink = Sink(layout_from)
if isinstance(layout_to, EndpointDescription):
layout_to = copy(layout_to)
layout_to.payload_layout = pack_layout(layout_to.payload_layout, n)
else:
layout_to = pack_layout(layout_to, n)
self.source = Source(layout_to)
CombinatorialActor.__init__(self)
# # #
for i in range(n):
chunk = n-i-1 if reverse else i
for f in self.sink.description.payload_layout:
src = getattr(self.sink, f[0])
dst = getattr(getattr(self.source, "chunk"+str(chunk)), f[0])
self.comb += dst.eq(src[i*len(src)//n:(i+1)*len(src)//n])
for f in self.sink.description.param_layout:
src = getattr(self.sink, f[0])
dst = getattr(self.source, f[0])
self.comb += dst.eq(src)
class Unchunkerize(CombinatorialActor):
def __init__(self, layout_from, n, layout_to, reverse=False):
if isinstance(layout_from, EndpointDescription):
fields = layout_from.payload_layout
layout_from = copy(layout_from)
layout_from.payload_layout = pack_layout(layout_from.payload_layout, n)
else:
fields = layout_from
layout_from = pack_layout(layout_from, n)
self.sink = Sink(layout_from)
self.source = Source(layout_to)
CombinatorialActor.__init__(self)
# # #
for i in range(n):
chunk = n-i-1 if reverse else i
for f in fields:
src = getattr(getattr(self.sink, "chunk"+str(chunk)), f[0])
dst = getattr(self.source, f[0])
self.comb += dst[i*len(dst)//n:(i+1)*len(dst)//n].eq(src)
for f in self.sink.description.param_layout:
src = getattr(self.sink, f[0])
dst = getattr(self.source, f[0])
self.comb += dst.eq(src)
class Converter(Module):
def __init__(self, layout_from, layout_to, reverse=False):
self.sink = Sink(layout_from)
self.source = Source(layout_to)
# # #
width_from = len(self.sink.payload.raw_bits())
width_to = len(self.source.payload.raw_bits())
# downconverter
if width_from > width_to:
if width_from % width_to:
raise ValueError
ratio = width_from//width_to
self.submodules.chunkerize = Chunkerize(layout_from, layout_to, ratio, reverse)
self.submodules.unpack = Unpack(ratio, layout_to)
self.submodules += Pipeline(self.sink,
self.chunkerize,
self.unpack,
self.source)
# upconverter
elif width_to > width_from:
if width_to % width_from:
raise ValueError
ratio = width_to//width_from
self.submodules.pack = Pack(layout_from, ratio)
self.submodules.unchunkerize = Unchunkerize(layout_from, ratio, layout_to, reverse)
self.submodules += Pipeline(self.sink,
self.pack,
self.unchunkerize,
self.source)
# direct connection
else:
self.comb += self.sink.connect(self.source)
class Pipeline(Module):
def __init__(self, *modules):
n = len(modules)