bus: replace simple bus module with new bidirectional Record

This commit is contained in:
Sebastien Bourdeauducq 2013-04-01 21:54:21 +02:00
parent 6a3c413717
commit 29b468529f
5 changed files with 149 additions and 228 deletions

View File

@ -1,23 +1,24 @@
from migen.fhdl.structure import * from migen.fhdl.structure import *
from migen.fhdl.specials import Memory from migen.fhdl.specials import Memory
from migen.fhdl.module import Module from migen.fhdl.module import Module
from migen.bus.simple import *
from migen.bus.transactions import * from migen.bus.transactions import *
from migen.bank.description import CSRStorage from migen.bank.description import CSRStorage
from migen.genlib.record import *
from migen.genlib.misc import chooser from migen.genlib.misc import chooser
data_width = 8 data_width = 8
class Interface(SimpleInterface): class Interface(Record):
def __init__(self): def __init__(self):
SimpleInterface.__init__(self, Description( Record.__init__(self, [
(M_TO_S, "adr", 14), ("adr", 14, DIR_M_TO_S),
(M_TO_S, "we", 1), ("we", 1, DIR_M_TO_S),
(M_TO_S, "dat_w", data_width), ("dat_w", data_width, DIR_M_TO_S),
(S_TO_M, "dat_r", data_width))) ("dat_r", data_width, DIR_S_TO_M)])
class Interconnect(SimpleInterconnect): class Interconnect(Module):
pass def __init__(self, master, slaves):
self.comb += master.connect(*slaves)
class Initiator(Module): class Initiator(Module):
def __init__(self, generator, bus=None): def __init__(self, generator, bus=None):
@ -53,80 +54,74 @@ def _compute_page_bits(nwords):
else: else:
return 0 return 0
class SRAM: class SRAM(Module):
def __init__(self, mem_or_size, address, read_only=None, bus=None): def __init__(self, mem_or_size, address, read_only=None, bus=None):
if isinstance(mem_or_size, Memory): if isinstance(mem_or_size, Memory):
self.mem = mem_or_size mem = mem_or_size
else: else:
self.mem = Memory(data_width, mem_or_size//(data_width//8)) mem = Memory(data_width, mem_or_size//(data_width//8))
self.address = address if mem.width > data_width:
if self.mem.width > data_width: csrw_per_memw = (self.mem.width + data_width - 1)//data_width
self.csrw_per_memw = (self.mem.width + data_width - 1)//data_width word_bits = bits_for(csrw_per_memw-1)
self.word_bits = bits_for(self.csrw_per_memw-1)
else: else:
self.csrw_per_memw = 1 csrw_per_memw = 1
self.word_bits = 0 word_bits = 0
page_bits = _compute_page_bits(self.mem.depth + self.word_bits) page_bits = _compute_page_bits(mem.depth + word_bits)
if page_bits: if page_bits:
self._page = CSRStorage(page_bits, name=self.mem.name_override + "_page") self._page = CSRStorage(page_bits, name=self.mem.name_override + "_page")
else: else:
self._page = None self._page = None
if read_only is None: if read_only is None:
if hasattr(self.mem, "bus_read_only"): if hasattr(mem, "bus_read_only"):
read_only = self.mem.bus_read_only read_only = mem.bus_read_only
else: else:
read_only = False read_only = False
self.read_only = read_only
if bus is None: if bus is None:
bus = Interface() bus = Interface()
self.bus = bus self.bus = bus
###
self.specials += mem
port = mem.get_port(write_capable=not read_only,
we_granularity=data_width if not read_only and word_bits else 0)
sel = Signal()
sel_r = Signal()
self.sync += sel_r.eq(sel)
self.comb += sel.eq(self.bus.adr[9:] == address)
if word_bits:
word_index = Signal(word_bits)
word_expanded = Signal(csrw_per_memw*data_width)
sync.append(word_index.eq(self.bus.adr[:word_bits]))
self.comb += [
word_expanded.eq(port.dat_r),
If(sel_r,
chooser(word_expanded, word_index, self.bus.dat_r, n=csrw_per_memw, reverse=True)
)
]
if not read_only:
self.comb += [
If(sel & self.bus.we, port.we.eq((1 << word_bits) >> self.bus.adr[:self.word_bits])),
port.dat_w.eq(Replicate(self.bus.dat_w, csrw_per_memw))
]
else:
self.comb += If(sel_r, self.bus.dat_r.eq(port.dat_r))
if not read_only:
self.comb += [
port.we.eq(sel & self.bus.we),
port.dat_w.eq(self.bus.dat_w)
]
if self._page is None:
self.comb += port.adr.eq(self.bus.adr[word_bits:len(port.adr)])
else:
pv = self._page.storage
self.comb += port.adr.eq(Cat(self.bus.adr[word_bits:len(port.adr)-len(pv)], pv))
def get_csrs(self): def get_csrs(self):
if self._page is None: if self._page is None:
return [] return []
else: else:
return [self._page] return [self._page]
def get_fragment(self):
port = self.mem.get_port(write_capable=not self.read_only,
we_granularity=data_width if not self.read_only and self.word_bits else 0)
sel = Signal()
sel_r = Signal()
sync = [sel_r.eq(sel)]
comb = [sel.eq(self.bus.adr[9:] == self.address)]
if self.word_bits:
word_index = Signal(self.word_bits)
word_expanded = Signal(self.csrw_per_memw*data_width)
sync.append(word_index.eq(self.bus.adr[:self.word_bits]))
comb += [
word_expanded.eq(port.dat_r),
If(sel_r,
chooser(word_expanded, word_index, self.bus.dat_r, n=self.csrw_per_memw, reverse=True)
)
]
if not self.read_only:
comb += [
If(sel & self.bus.we, port.we.eq((1 << self.word_bits) >> self.bus.adr[:self.word_bits])),
port.dat_w.eq(Replicate(self.bus.dat_w, self.csrw_per_memw))
]
else:
comb += [
If(sel_r,
self.bus.dat_r.eq(port.dat_r)
)
]
if not self.read_only:
comb += [
port.we.eq(sel & self.bus.we),
port.dat_w.eq(self.bus.dat_w)
]
if self._page is None:
comb.append(port.adr.eq(self.bus.adr[self.word_bits:len(port.adr)]))
else:
pv = self._page.storage
comb.append(port.adr.eq(Cat(self.bus.adr[self.word_bits:len(port.adr)-len(pv)], pv)))
return Fragment(comb, sync, specials={self.mem})

View File

@ -1,29 +1,31 @@
from migen.fhdl.structure import * from migen.fhdl.structure import *
from migen.bus.simple import * from migen.fhdl.module import Module
from migen.genlib.record import *
def phase_description(a, ba, d): def phase_description(a, ba, d):
return Description( return [
(M_TO_S, "address", a), ("address", a, DIR_M_TO_S),
(M_TO_S, "bank", ba), ("bank", ba, DIR_M_TO_S),
(M_TO_S, "cas_n", 1), ("cas_n", 1, DIR_M_TO_S),
(M_TO_S, "cke", 1), ("cke", 1, DIR_M_TO_S),
(M_TO_S, "cs_n", 1), ("cs_n", 1, DIR_M_TO_S),
(M_TO_S, "ras_n", 1), ("ras_n", 1, DIR_M_TO_S),
(M_TO_S, "we_n", 1), ("we_n", 1, DIR_M_TO_S),
(M_TO_S, "wrdata", d), ("wrdata", d, DIR_M_TO_S),
(M_TO_S, "wrdata_en", 1), ("wrdata_en", 1, DIR_M_TO_S),
(M_TO_S, "wrdata_mask", d//8), ("wrdata_mask", d//8, DIR_M_TO_S),
(M_TO_S, "rddata_en", 1), ("rddata_en", 1, DIR_M_TO_S),
(S_TO_M, "rddata", d), ("rddata", d, DIR_S_TO_M),
(S_TO_M, "rddata_valid", 1) ("rddata_valid", 1, DIR_S_TO_M)
) ]
class Interface: class Interface(Record):
def __init__(self, a, ba, d, nphases=1): def __init__(self, a, ba, d, nphases=1):
self.pdesc = phase_description(a, ba, d) layout = [("p"+str(i), phase_description(a, ba, d)) for i in range(nphases)]
self.phases = [SimpleInterface(self.pdesc) for i in range(nphases)] Record.__init__(self, layout)
self.phases = [getattr(self, "p"+str(i)) for i in range(nphases)]
for p in self.phases: for p in self.phases:
p.cas_n.reset = 1 p.cas_n.reset = 1
p.cs_n.reset = 1 p.cs_n.reset = 1
@ -35,28 +37,18 @@ class Interface:
r = [] r = []
add_suffix = len(self.phases) > 1 add_suffix = len(self.phases) > 1
for n, phase in enumerate(self.phases): for n, phase in enumerate(self.phases):
for signal in self.pdesc.desc: for field, size, direction in phase.layout:
if (m2s and signal[0] == M_TO_S) or (s2m and signal[0] == S_TO_M): if (m2s and direction == DIR_M_TO_S) or (s2m and direction == DIR_S_TO_M):
if add_suffix: if add_suffix:
if signal[0] == M_TO_S: if direction == DIR_M_TO_S:
suffix = "_p" + str(n) suffix = "_p" + str(n)
else: else:
suffix = "_w" + str(n) suffix = "_w" + str(n)
else: else:
suffix = "" suffix = ""
r.append(("dfi_" + signal[1] + suffix, getattr(phase, signal[1]))) r.append(("dfi_" + field + suffix, getattr(phase, field)))
return r return r
def interconnect_stmts(master, slave): class Interconnect(Module):
r = []
for pm, ps in zip(master.phases, slave.phases):
r += simple_interconnect_stmts(master.pdesc, pm, [ps])
return r
class Interconnect:
def __init__(self, master, slave): def __init__(self, master, slave):
self.master = master self.comb += master.connect(slave)
self.slave = slave
def get_fragment(self):
return Fragment(interconnect_stmts(self.master, self.slave))

View File

@ -1,47 +0,0 @@
from migen.fhdl.structure import *
from migen.genlib.misc import optree
(S_TO_M, M_TO_S) = range(2)
# desc is a list of tuples, each made up of:
# 0) S_TO_M/M_TO_S: data direction
# 1) string: name
# 2) int: width
class Description:
def __init__(self, *desc):
self.desc = desc
def get_names(self, direction, *exclude_list):
exclude = set(exclude_list)
return [signal[1]
for signal in self.desc
if signal[0] == direction and signal[1] not in exclude]
class SimpleInterface:
def __init__(self, desc):
self.desc = desc
modules = self.__module__.split(".")
busname = modules[len(modules)-1]
for signal in self.desc.desc:
signame = signal[1]
setattr(self, signame, Signal(signal[2], busname + "_" + signame))
def simple_interconnect_stmts(desc, master, slaves):
s2m = desc.get_names(S_TO_M)
m2s = desc.get_names(M_TO_S)
sl = [getattr(slave, name).eq(getattr(master, name))
for name in m2s for slave in slaves]
sl += [getattr(master, name).eq(
optree("|", [getattr(slave, name) for slave in slaves])
)
for name in s2m]
return sl
class SimpleInterconnect:
def __init__(self, master, slaves):
self.master = master
self.slaves = slaves
def get_fragment(self):
return Fragment(simple_interconnect_stmts(self.master.desc, self.master, self.slaves))

View File

@ -2,64 +2,59 @@ from migen.fhdl.structure import *
from migen.fhdl.specials import Memory from migen.fhdl.specials import Memory
from migen.fhdl.module import Module from migen.fhdl.module import Module
from migen.genlib import roundrobin from migen.genlib import roundrobin
from migen.genlib.record import *
from migen.genlib.misc import optree from migen.genlib.misc import optree
from migen.bus.simple import *
from migen.bus.transactions import * from migen.bus.transactions import *
from migen.sim.generic import Proxy from migen.sim.generic import Proxy
_desc = Description( _layout = [
(M_TO_S, "adr", 30), ("adr", 30, DIR_M_TO_S),
(M_TO_S, "dat_w", 32), ("dat_w", 32, DIR_M_TO_S),
(S_TO_M, "dat_r", 32), ("dat_r", 32, DIR_S_TO_M),
(M_TO_S, "sel", 4), ("sel", 4, DIR_M_TO_S),
(M_TO_S, "cyc", 1), ("cyc", 1, DIR_M_TO_S),
(M_TO_S, "stb", 1), ("stb", 1, DIR_M_TO_S),
(S_TO_M, "ack", 1), ("ack", 1, DIR_S_TO_M),
(M_TO_S, "we", 1), ("we", 1, DIR_M_TO_S),
(M_TO_S, "cti", 3), ("cti", 3, DIR_M_TO_S),
(M_TO_S, "bte", 2), ("bte", 2, DIR_M_TO_S),
(S_TO_M, "err", 1) ("err", 1, DIR_S_TO_M)
) ]
class Interface(SimpleInterface): class Interface(Record):
def __init__(self): def __init__(self):
SimpleInterface.__init__(self, _desc) Record.__init__(self, _layout)
class InterconnectPointToPoint(SimpleInterconnect): class InterconnectPointToPoint(Module):
def __init__(self, master, slave): def __init__(self, master, slave):
SimpleInterconnect.__init__(self, master, [slave]) self.comb += master.connect(slave)
class Arbiter: class Arbiter(Module):
def __init__(self, masters, target): def __init__(self, masters, target):
self.masters = masters self.submodules.rr = roundrobin.RoundRobin(len(masters))
self.target = target
self.rr = roundrobin.RoundRobin(len(self.masters))
def get_fragment(self):
comb = []
# mux master->slave signals # mux master->slave signals
for name in _desc.get_names(M_TO_S): for name, size, direction in _layout:
choices = Array(getattr(m, name) for m in self.masters) if direction == DIR_M_TO_S:
comb.append(getattr(self.target, name).eq(choices[self.rr.grant])) choices = Array(getattr(m, name) for m in masters)
self.comb += getattr(target, name).eq(choices[self.rr.grant])
# connect slave->master signals # connect slave->master signals
for name in _desc.get_names(S_TO_M): for name, size, direction in _layout:
source = getattr(self.target, name) if direction == DIR_S_TO_M:
for i, m in enumerate(self.masters): source = getattr(target, name)
for i, m in enumerate(masters):
dest = getattr(m, name) dest = getattr(m, name)
if name == "ack" or name == "err": if name == "ack" or name == "err":
comb.append(dest.eq(source & (self.rr.grant == i))) self.comb += dest.eq(source & (self.rr.grant == i))
else: else:
comb.append(dest.eq(source)) self.comb += dest.eq(source)
# connect bus requests to round-robin selector # connect bus requests to round-robin selector
reqs = [m.cyc for m in self.masters] reqs = [m.cyc for m in masters]
comb.append(self.rr.request.eq(Cat(*reqs))) self.comb += self.rr.request.eq(Cat(*reqs))
return Fragment(comb) + self.rr.get_fragment() class Decoder(Module):
class Decoder:
# slaves is a list of pairs: # slaves is a list of pairs:
# 0) function that takes the address signal and returns a FHDL expression # 0) function that takes the address signal and returns a FHDL expression
# that evaluates to 1 when the slave is selected and 0 otherwise. # that evaluates to 1 when the slave is selected and 0 otherwise.
@ -67,55 +62,43 @@ class Decoder:
# register adds flip-flops after the address comparators. Improves timing, # register adds flip-flops after the address comparators. Improves timing,
# but breaks Wishbone combinatorial feedback. # but breaks Wishbone combinatorial feedback.
def __init__(self, master, slaves, register=False): def __init__(self, master, slaves, register=False):
self.master = master ns = len(slaves)
self.slaves = slaves
self.register = register
def get_fragment(self):
comb = []
sync = []
ns = len(self.slaves)
slave_sel = Signal(ns) slave_sel = Signal(ns)
slave_sel_r = Signal(ns) slave_sel_r = Signal(ns)
# decode slave addresses # decode slave addresses
comb += [slave_sel[i].eq(fun(self.master.adr)) self.comb += [slave_sel[i].eq(fun(master.adr))
for i, (fun, bus) in enumerate(self.slaves)] for i, (fun, bus) in enumerate(slaves)]
if self.register: if register:
sync.append(slave_sel_r.eq(slave_sel)) self.sync += slave_sel_r.eq(slave_sel)
else: else:
comb.append(slave_sel_r.eq(slave_sel)) self.comb += slave_sel_r.eq(slave_sel)
# connect master->slaves signals except cyc # connect master->slaves signals except cyc
m2s_names = _desc.get_names(M_TO_S, "cyc") for slave in slaves:
comb += [getattr(slave[1], name).eq(getattr(self.master, name)) for name, size, direction in _layout:
for name in m2s_names for slave in self.slaves] if direction == DIR_M_TO_S and name != "cyc":
self.comb += getattr(slave[1], name).eq(getattr(master, name))
# combine cyc with slave selection signals # combine cyc with slave selection signals
comb += [slave[1].cyc.eq(self.master.cyc & slave_sel[i]) self.comb += [slave[1].cyc.eq(master.cyc & slave_sel[i])
for i, slave in enumerate(self.slaves)] for i, slave in enumerate(slaves)]
# generate master ack (resp. err) by ORing all slave acks (resp. errs) # generate master ack (resp. err) by ORing all slave acks (resp. errs)
comb += [ self.comb += [
self.master.ack.eq(optree("|", [slave[1].ack for slave in self.slaves])), master.ack.eq(optree("|", [slave[1].ack for slave in slaves])),
self.master.err.eq(optree("|", [slave[1].err for slave in self.slaves])) master.err.eq(optree("|", [slave[1].err for slave in slaves]))
] ]
# mux (1-hot) slave data return # mux (1-hot) slave data return
masked = [Replicate(slave_sel_r[i], len(self.master.dat_r)) & self.slaves[i][1].dat_r for i in range(len(self.slaves))] masked = [Replicate(slave_sel_r[i], len(master.dat_r)) & slaves[i][1].dat_r for i in range(ns)]
comb.append(self.master.dat_r.eq(optree("|", masked))) self.comb += master.dat_r.eq(optree("|", masked))
return Fragment(comb, sync) class InterconnectShared(Module):
class InterconnectShared:
def __init__(self, masters, slaves, register=False): def __init__(self, masters, slaves, register=False):
self._shared = Interface() shared = Interface()
self._arbiter = Arbiter(masters, self._shared) self.submodules += Arbiter(masters, shared)
self._decoder = Decoder(self._shared, slaves, register) self.submodules += Decoder(shared, slaves, register)
def get_fragment(self):
return self._arbiter.get_fragment() + self._decoder.get_fragment()
class Tap(Module): class Tap(Module):
def __init__(self, bus, handler=print): def __init__(self, bus, handler=print):
@ -200,34 +183,33 @@ class Target(Module):
else: else:
bus.ack = 0 bus.ack = 0
class SRAM: class SRAM(Module):
def __init__(self, mem_or_size, bus=None): def __init__(self, mem_or_size, bus=None):
if isinstance(mem_or_size, Memory): if isinstance(mem_or_size, Memory):
assert(mem_or_size.width <= 32) assert(mem_or_size.width <= 32)
self.mem = mem_or_size mem = mem_or_size
else: else:
self.mem = Memory(32, mem_or_size//4) mem = Memory(32, mem_or_size//4)
if bus is None: if bus is None:
bus = Interface() bus = Interface()
self.bus = bus self.bus = bus
def get_fragment(self): ###
# memory # memory
port = self.mem.get_port(write_capable=True, we_granularity=8) self.specials += mem
port = mem.get_port(write_capable=True, we_granularity=8)
# generate write enable signal # generate write enable signal
comb = [port.we[i].eq(self.bus.cyc & self.bus.stb & self.bus.we & self.bus.sel[i]) self.comb += [port.we[i].eq(self.bus.cyc & self.bus.stb & self.bus.we & self.bus.sel[i])
for i in range(4)] for i in range(4)]
# address and data # address and data
comb += [ self.comb += [
port.adr.eq(self.bus.adr[:len(port.adr)]), port.adr.eq(self.bus.adr[:len(port.adr)]),
port.dat_w.eq(self.bus.dat_w), port.dat_w.eq(self.bus.dat_w),
self.bus.dat_r.eq(port.dat_r) self.bus.dat_r.eq(port.dat_r)
] ]
# generate ack # generate ack
sync = [ self.sync += [
self.bus.ack.eq(0), self.bus.ack.eq(0),
If(self.bus.cyc & self.bus.stb & ~self.bus.ack, If(self.bus.cyc & self.bus.stb & ~self.bus.ack, self.bus.ack.eq(1))
self.bus.ack.eq(1)
)
] ]
return Fragment(comb, sync, specials={self.mem})

View File

@ -3,7 +3,7 @@ from migen.fhdl.specials import Memory
from migen.bus import wishbone from migen.bus import wishbone
from migen.genlib.fsm import FSM from migen.genlib.fsm import FSM
from migen.genlib.misc import split, displacer, chooser from migen.genlib.misc import split, displacer, chooser
from migen.genlib.record import Record from migen.genlib.record import Record, layout_len
# cachesize (in 32-bit words) is the size of the data store, must be a power of 2 # cachesize (in 32-bit words) is the size of the data store, must be a power of 2
class WB2ASMI: class WB2ASMI:
@ -60,15 +60,14 @@ class WB2ASMI:
] ]
# Tag memory # Tag memory
tag_mem = Memory(tagbits+1, 2**linebits)
tag_port = tag_mem.get_port(write_capable=True)
tag_layout = [("tag", tagbits), ("dirty", 1)] tag_layout = [("tag", tagbits), ("dirty", 1)]
tag_mem = Memory(layout_len(tag_layout), 2**linebits)
tag_port = tag_mem.get_port(write_capable=True)
tag_do = Record(tag_layout) tag_do = Record(tag_layout)
tag_di = Record(tag_layout) tag_di = Record(tag_layout)
comb += [ comb += [
Cat(*tag_do.flatten()).eq(tag_port.dat_r), tag_do.raw_bits().eq(tag_port.dat_r),
tag_port.dat_w.eq(Cat(*tag_di.flatten())) tag_port.dat_w.eq(tag_di.raw_bits())
] ]
comb += [ comb += [