From 94770f8e0b15fefdcd16072aef4fdcfe9981d58b Mon Sep 17 00:00:00 2001 From: Charles Papon Date: Wed, 22 Mar 2017 18:29:34 +0100 Subject: [PATCH] Add MachineCsr (untested) --- .../scala/SpinalRiscv/Plugin/MachineCsr.scala | 280 ++++++++++++++---- .../SpinalRiscv/Plugin/ShiftPlugins.scala | 2 +- src/main/scala/SpinalRiscv/Riscv.scala | 19 ++ src/main/scala/SpinalRiscv/TopLevel.scala | 20 +- src/test/cpp/testA/main.cpp | 10 +- src/test/cpp/testA/makefile | 7 + 6 files changed, 274 insertions(+), 64 deletions(-) diff --git a/src/main/scala/SpinalRiscv/Plugin/MachineCsr.scala b/src/main/scala/SpinalRiscv/Plugin/MachineCsr.scala index 723fa51..d5ca473 100644 --- a/src/main/scala/SpinalRiscv/Plugin/MachineCsr.scala +++ b/src/main/scala/SpinalRiscv/Plugin/MachineCsr.scala @@ -6,17 +6,68 @@ import SpinalRiscv._ import SpinalRiscv.Riscv._ import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable /** * Created by spinalvm on 21.03.17. */ - +trait CsrAccess +object CsrAccess { + object WRITE_ONLY extends CsrAccess + object READ_ONLY extends CsrAccess + object READ_WRITE extends CsrAccess + object NONE extends CsrAccess +} case class ExceptionPortInfo(port : Flow[UInt],stage : Stage) -case class MachineCsrConfig() +case class MachineCsrConfig( + mvendorid : BigInt, + marchid : BigInt, + mimpid : BigInt, + mhartid : BigInt, + misaExtensions : Int, + misaAccess : CsrAccess, + mtvecAccess : CsrAccess, + mtvecInit : BigInt, + mepcAccess : CsrAccess, + mscratchGen : Boolean, + mcauseAccess : CsrAccess, + mbadaddrAccess : CsrAccess -class MachineCsr extends Plugin[VexRiscv] with ExceptionService{ +) + + +case class CsrWrite(that : Data, bitOffset : Int) +case class CsrRead(that : Data , bitOffset : Int) +case class CsrMapping(){ + val mapping = mutable.HashMap[Int,ArrayBuffer[Any]]() + def addMappingAt(address : Int,that : Any) = mapping.getOrElseUpdate(address,new ArrayBuffer[Any]) += that + def r(csrAddress : Int, bitOffset : Int, that : Data): Unit = addMappingAt(csrAddress, CsrRead(that,bitOffset)) + def w(csrAddress : Int, bitOffset : Int, that : Data): Unit = addMappingAt(csrAddress, CsrWrite(that,bitOffset)) + def rw(csrAddress : Int, bitOffset : Int,that : Data): Unit ={ + r(csrAddress,bitOffset,that) + w(csrAddress,bitOffset,that) + } + + def rw(csrAddress : Int, thats : (Int, Data)*) : Unit = for(that <- thats) rw(csrAddress,that._1, that._2) + def r [T <: Data](csrAddress : Int, thats : (Int, Data)*) : Unit = for(that <- thats) r(csrAddress,that._1, that._2) + def rw[T <: Data](csrAddress : Int, that : T): Unit = rw(csrAddress,0,that) + def r [T <: Data](csrAddress : Int, that : T): Unit = r(csrAddress,0,that) + def rx [T <: Data](csrAddress : Int, thats : (Int, Data)*)(writable : Boolean) : Unit = + if(writable) + for(that <- thats) rw(csrAddress,that._1, that._2) + else + for(that <- thats) r(csrAddress,that._1, that._2) +} + + + +class MachineCsr(config : MachineCsrConfig) extends Plugin[VexRiscv] with ExceptionService { + import config._ + import CsrAccess._ + + //Mannage ExceptionService calls val exceptionPortsInfos = ArrayBuffer[ExceptionPortInfo]() def exceptionCodeWidth = 4 override def newExceptionPort(stage : Stage) = { @@ -26,6 +77,7 @@ class MachineCsr extends Plugin[VexRiscv] with ExceptionService{ } var jumpInterface : Flow[UInt] = null + var pluginExceptionPort : Flow[UInt] = null object EnvCtrlEnum extends SpinalEnum(binarySequential){ val NONE, EBREAK, ECALL, MRET = newElement() @@ -33,6 +85,7 @@ class MachineCsr extends Plugin[VexRiscv] with ExceptionService{ object ENV_CTRL extends Stageable(EnvCtrlEnum()) object EXCEPTION extends Stageable(Bool) + object IS_CSR extends Stageable(Bool) override def setup(pipeline: VexRiscv): Unit = { import pipeline.config._ @@ -41,22 +94,25 @@ class MachineCsr extends Plugin[VexRiscv] with ExceptionService{ LEGAL_INSTRUCTION -> True ) - val defaultActions = List[(Stageable[_ <: BaseType],Any)]( + val defaultCsrActions = List[(Stageable[_ <: BaseType],Any)]( LEGAL_INSTRUCTION -> True, + IS_CSR -> True, REGFILE_WRITE_VALID -> True, BYPASSABLE_EXECUTE_STAGE -> False, BYPASSABLE_MEMORY_STAGE -> False ) - val nonImmediatActions = defaultActions ++ List( - SRC1_CTRL -> Src1CtrlEnum.RS + val nonImmediatActions = defaultCsrActions ++ List( + SRC1_CTRL -> Src1CtrlEnum.RS, + REG1_USE -> True ) - val immediatActions = defaultActions + val immediatActions = defaultCsrActions val decoderService = pipeline.service(classOf[DecoderService]) decoderService.addDefault(ENV_CTRL, EnvCtrlEnum.NONE) + decoderService.addDefault(IS_CSR, False) decoderService.add(List( CSRRW -> nonImmediatActions, CSRRS -> nonImmediatActions, @@ -69,80 +125,188 @@ class MachineCsr extends Plugin[VexRiscv] with ExceptionService{ MRET -> (defaultEnv ++ List(ENV_CTRL -> EnvCtrlEnum.MRET)) )) - pipeline.fetch.insert(EXCEPTION) := False val pcManagerService = pipeline.service(classOf[JumpService]) jumpInterface = pcManagerService.createJumpInterface(pipeline.execute) jumpInterface.valid := False jumpInterface.payload.assignDontCare() - } + pluginExceptionPort = newExceptionPort(pipeline.execute) + pluginExceptionPort.valid := False + pluginExceptionPort.payload.assignDontCare() + } + def xlen = 32 override def build(pipeline: VexRiscv): Unit = { import pipeline._ import pipeline.config._ - - - val mtvec = UInt(32 bits) - val mepc = Reg(UInt(32 bits)) - - val mstatus = new Area{ - val MIE, MPIE = Reg(Bool) init(False) + //Manage ECALL instructions + when(execute.arbitration.isValid && execute.input(ENV_CTRL) === EnvCtrlEnum.ECALL){ + pluginExceptionPort.valid := True + pluginExceptionPort.payload := 3 } - val pipelineLiberator = new Area{ - val enable = False - decode.arbitration.haltIt setWhen(enable) - val done = ! List(fetch, decode, execute, memory, writeBack).map(_.arbitration.isValid).orR - } + pipeline plug new Area{ + //Define CSR registers + val csrMapping = new CsrMapping() + implicit class CsrAccessPimper(csrAccess : CsrAccess){ + def apply(csrAddress : Int, thats : (Int, Data)*) : Unit = csrAccess match{ + case `WRITE_ONLY` | `READ_WRITE` => for(that <- thats) csrMapping.w(csrAddress,that._1, that._2) + case `READ_ONLY` | `READ_WRITE` => for(that <- thats) csrMapping.r(csrAddress,that._1, that._2) + } + def apply(csrAddress : Int, that : Data) : Unit = csrAccess match{ + case `WRITE_ONLY` | `READ_WRITE` => csrMapping.w(csrAddress, 0, that) + case `READ_ONLY` | `READ_WRITE` => csrMapping.r(csrAddress, 0, that) + } + } - val exceptionPortCtrl = new Area{ - val pipelineHasException = List(fetch, decode, execute, memory, writeBack).map(s => s.arbitration.isValid && s.input(EXCEPTION)).orR - decode.arbitration.haltIt setWhen(pipelineHasException) + //Define CSR registers + val mtvec = RegInit(U(mtvecInit,xlen bits)) + val mepc = Reg(UInt(xlen bits)) + val mstatus = new Area{ + val MIE, MPIE = RegInit(False) + } + val mip = new Area{ + val MEIP, MTIP, MSIP = False //TODO + } + val mie = new Area{ + val MEIE, MTIE, MSIE = RegInit(False) + } + val mscratch = if(mscratchGen) Reg(Bits(xlen bits)) else null + val mcause = new Area{ + val interrupt = Reg(Bool) + val exceptionCode = Reg(UInt(exceptionCodeWidth bits)) + } + val mbadaddr = Reg(UInt(xlen bits)) - val groupedByStage = exceptionPortsInfos.map(_.stage).distinct.map(s => { - val stagePortsInfos = exceptionPortsInfos.filter(_.stage == s) - val stagePort = stagePortsInfos.length match{ - case 1 => stagePortsInfos.head.port - case _ => { - val groupedPort = Flow(UInt(exceptionCodeWidth bits)) - val valids = stagePortsInfos.map(_.port.valid) - val codes = stagePortsInfos.map(_.port.payload) - groupedPort.valid := valids.orR - groupedPort.payload := MuxOH(stagePortsInfos.map(_.port.valid), codes) - groupedPort + //Define CSR registers accessibility + if(mvendorid != null) READ_ONLY(CSR.MVENDORID, U(mvendorid)) + if(marchid != null) READ_ONLY(CSR.MARCHID , U(marchid )) + if(mimpid != null) READ_ONLY(CSR.MIMPID , U(mimpid )) + if(mhartid != null) READ_ONLY(CSR.MHARTID , U(mhartid )) + + misaAccess(CSR.MISA, xlen-2 -> U"01" , 0 -> U(misaExtensions)) + READ_ONLY(CSR.MIP, 11 -> mip.MEIP, 7 -> mip.MTIP, 3 -> mip.MSIP) + READ_WRITE(CSR.MIE, 11 -> mie.MEIE, 7 -> mie.MTIE, 3 -> mie.MSIE) + + mtvecAccess(CSR.MTVEC , mtvec) + mepcAccess(CSR.MEPC , mepc) + READ_ONLY(CSR.MSTATUS, 7 -> mstatus.MPIE, 3 -> mstatus.MIE) + if(mscratchGen) READ_WRITE(CSR.MSCRATCH, mscratch) + mcauseAccess(CSR.MCAUSE, xlen-1 -> mcause.interrupt, 0 -> mcause.exceptionCode) + mbadaddrAccess(CSR.MBADADDR, mbadaddr) + + + + //Used to make the pipeline empty softly (for interrupts) + val pipelineLiberator = new Area{ + val enable = False + decode.arbitration.haltIt setWhen(enable) + val done = ! List(fetch, decode, execute, memory, writeBack).map(_.arbitration.isValid).orR + } + + //Aggregate all exception port and remove required instructions + val exceptionPortCtrl = if(exceptionPortsInfos.nonEmpty) new Area{ + val firstStageIndexWithExceptionPort = exceptionPortsInfos.map(i => indexOf(i.stage)).min + val pipelineHasException = stages.drop(firstStageIndexWithExceptionPort).map(s => s.arbitration.isValid && s.input(EXCEPTION)).orR + decode.arbitration.haltIt setWhen(pipelineHasException) + + val groupedByStage = exceptionPortsInfos.map(_.stage).distinct.map(s => { + val stagePortsInfos = exceptionPortsInfos.filter(_.stage == s) + val stagePort = stagePortsInfos.length match{ + case 1 => stagePortsInfos.head.port + case _ => { + val groupedPort = Flow(UInt(exceptionCodeWidth bits)) + val valids = stagePortsInfos.map(_.port.valid) + val codes = stagePortsInfos.map(_.port.payload) + groupedPort.valid := valids.orR + groupedPort.payload := MuxOH(stagePortsInfos.map(_.port.valid), codes) + groupedPort + } + } + ExceptionPortInfo(stagePort,s) + }) + val sortedByStage = groupedByStage.sortWith((a, b) => pipeline.indexOf(a.stage) > pipeline.indexOf(b.stage)) + + sortedByStage.head.stage.insert(EXCEPTION) := False + for(portInfo <- sortedByStage; port = portInfo.port ; stage = portInfo.stage){ + when(port.valid){ + stages(indexOf(stage) - 1).arbitration.flushIt := True + stage.input(EXCEPTION) := True } } - ExceptionPortInfo(stagePort,s) - }) - val sortedByStage = groupedByStage.sortWith((a, b) => pipeline.indexOf(a.stage) > pipeline.indexOf(b.stage)) + } else null - for(portInfo <- sortedByStage; port = portInfo.port ; stage = portInfo.stage){ - when(port.valid){ - stages(indexOf(stage)).arbitration.flushIt := True - stage.input(EXCEPTION) := True + val interrupt = False + val exception = if(exceptionPortsInfos.nonEmpty) + writeBack.arbitration.isValid && writeBack.input(EXCEPTION) + else + False + + when(mstatus.MIE){ + pipelineLiberator.enable := interrupt + when(exception || (interrupt && pipelineLiberator.done)){ + jumpInterface.valid := True + jumpInterface.payload := mtvec + mstatus.MIE := False + mstatus.MPIE := mstatus.MIE + mepc := exception ? writeBack.input(PC) | prefetch.input(PC_CALC_WITHOUT_JUMP) } } - } - val interrupt = False - val exception = writeBack.arbitration.isValid && writeBack.input(EXCEPTION) - - when(mstatus.MIE){ - pipelineLiberator.enable := True - when(exception || (interrupt && pipelineLiberator.done)){ + when(memory.arbitration.isFiring && memory.input(ENV_CTRL) === EnvCtrlEnum.MRET){ jumpInterface.valid := True - jumpInterface.payload := mtvec - mstatus.MIE := False - mstatus.MPIE := mstatus.MIE - mepc := exception ? writeBack.input(PC) | prefetch.input(PC_CALC_WITHOUT_JUMP) + jumpInterface.payload := mepc + mstatus.MIE := mstatus.MPIE } - } - when(memory.arbitration.isFiring && memory.input(ENV_CTRL) === EnvCtrlEnum.MRET){ - jumpInterface.valid := True - jumpInterface.payload := mepc - mstatus.MIE := mstatus.MPIE + + + execute plug new Area { + + import execute._ + + val imm = IMM(input(INSTRUCTION)) + + val writeEnable = !((input(INSTRUCTION)(14 downto 13) === "01" && input(INSTRUCTION)(rs1Range) === 0) + || (input(INSTRUCTION)(14 downto 13) === "10" && imm.z === 0)) + val readEnable = input(INSTRUCTION)(rdRange) =/= 0 + + val writeSrc = input(INSTRUCTION)(14) ? imm.z.asBits.resized | input(SRC1) + val readData = B(0, 32 bits) + val writeData = input(INSTRUCTION)(12).mux( + False -> writeSrc, + True -> Mux(input(INSTRUCTION)(13), readData & ~writeSrc, readData | writeSrc) + ) + + when(arbitration.isValid && input(IS_CSR)) { + when(writeEnable) { + output(REGFILE_WRITE_DATA) := writeData + } + } + + + val csrAddress = input(INSTRUCTION)(csrRange) + when(arbitration.isValid && input(IS_CSR)) { + for ((address, jobs) <- csrMapping.mapping) { + when(csrAddress === address) { + when(writeEnable) { + for (element <- jobs) element match { + case element: CsrWrite => element.that.assignFromBits(writeData(element.bitOffset, element.that.getBitsWidth bits)) + // case element: BusSlaveFactoryOnWriteAtAddress => element.doThat() + case _ => + } + } + + for (element <- jobs) element match { + case element: CsrRead => readData(element.bitOffset, element.that.getBitsWidth bits) := element.that.asBits + // case element: BusSlaveFactoryOnReadAtAddress => when(readEnable) { element.doThat() } + case _ => + } + } + } + } + } } } } diff --git a/src/main/scala/SpinalRiscv/Plugin/ShiftPlugins.scala b/src/main/scala/SpinalRiscv/Plugin/ShiftPlugins.scala index 920a0c2..97b5af3 100644 --- a/src/main/scala/SpinalRiscv/Plugin/ShiftPlugins.scala +++ b/src/main/scala/SpinalRiscv/Plugin/ShiftPlugins.scala @@ -151,7 +151,7 @@ class LightShifterPlugin extends Plugin[VexRiscv]{ when(arbitration.isValid && isShift && input(SRC2)(4 downto 0) =/= 0){ - insert(REGFILE_WRITE_DATA) := input(SHIFT_CTRL).mux( + output(REGFILE_WRITE_DATA) := input(SHIFT_CTRL).mux( ShiftCtrlEnum.SLL -> (shiftInput |<< 1), default -> (((input(SHIFT_CTRL) === ShiftCtrlEnum.SRA && shiftInput.msb) ## shiftInput).asSInt >> 1).asBits //ALU.SRL,ALU.SRA ) diff --git a/src/main/scala/SpinalRiscv/Riscv.scala b/src/main/scala/SpinalRiscv/Riscv.scala index e7c83c3..a8bc584 100644 --- a/src/main/scala/SpinalRiscv/Riscv.scala +++ b/src/main/scala/SpinalRiscv/Riscv.scala @@ -9,6 +9,7 @@ object Riscv{ def funct3Range = 14 downto 12 def rs1Range = 19 downto 15 def rs2Range = 24 downto 20 + def csrRange = 31 downto 20 case class IMM(instruction : Bits) extends Area{ // immediates @@ -82,4 +83,22 @@ object Riscv{ def ECALL = M"00000000000000000000000001110011" def EBREAK = M"00000000000100000000000001110011" def MRET = M"00110000001000000000000001110011" + + object CSR{ + def MVENDORID = 0xF11 // MRO Vendor ID. + def MARCHID = 0xF12 // MRO Architecture ID. + def MIMPID = 0xF13 // MRO Implementation ID. + def MHARTID = 0xF14 // MRO Hardware thread ID.Machine Trap Setup + def MSTATUS = 0x300 // MRW Machine status register. + def MISA = 0x301 // MRW ISA and extensions + def MEDELEG = 0x302 // MRW Machine exception delegation register. + def MIDELEG = 0x303 // MRW Machine interrupt delegation register. + def MIE = 0x304 // MRW Machine interrupt-enable register. + def MTVEC = 0x305 // MRW Machine trap-handler base address. Machine Trap Handling + def MSCRATCH = 0x340 // MRW Scratch register for machine trap handlers. + def MEPC = 0x341 // MRW Machine exception program counter. + def MCAUSE = 0x342 // MRW Machine trap cause. + def MBADADDR = 0x343 // MRW Machine bad address. + def MIP = 0x344 // MRW Machine interrupt pending. + } } diff --git a/src/main/scala/SpinalRiscv/TopLevel.scala b/src/main/scala/SpinalRiscv/TopLevel.scala index 0193c4c..226edce 100644 --- a/src/main/scala/SpinalRiscv/TopLevel.scala +++ b/src/main/scala/SpinalRiscv/TopLevel.scala @@ -30,8 +30,25 @@ object TopLevel { pcWidth = 32 ) + + import CsrAccess._ + val csrConfig = MachineCsrConfig( + mvendorid = 11, + marchid = 22, + mimpid = 33, + mhartid = 44, + misaExtensions = 66, + misaAccess = READ_WRITE, + mtvecAccess = READ_WRITE, + mtvecInit = 0x80000004l, + mepcAccess = READ_WRITE, + mscratchGen = true, + mcauseAccess = READ_WRITE, + mbadaddrAccess = READ_WRITE + ) + config.plugins ++= List( - new PcManagerSimplePlugin(0, false), + new PcManagerSimplePlugin(0x00000000l, false), new IBusSimplePlugin(true), new DecoderSimplePlugin, new RegFilePlugin(Plugin.SYNC), @@ -45,6 +62,7 @@ object TopLevel { // new HazardSimplePlugin(false, false, false, false), new MulPlugin, new DivPlugin, + new MachineCsr(csrConfig), new BranchPlugin(false, DYNAMIC) ) diff --git a/src/test/cpp/testA/main.cpp b/src/test/cpp/testA/main.cpp index f4a3aea..917e7ca 100644 --- a/src/test/cpp/testA/main.cpp +++ b/src/test/cpp/testA/main.cpp @@ -502,23 +502,25 @@ int main(int argc, char **argv, char **env) { #ifndef REF TestA().run(); for(const string &name : riscvTestMain){ - redo(5,RiscvTest(name).run();) + redo(REDO,RiscvTest(name).run();) } for(const string &name : riscvTestMemory){ - redo(5,RiscvTest(name).run();) + redo(REDO,RiscvTest(name).run();) } for(const string &name : riscvTestMul){ - redo(5,RiscvTest(name).run();) + redo(REDO,RiscvTest(name).run();) } for(const string &name : riscvTestDiv){ - redo(5,RiscvTest(name).run();) + redo(REDO,RiscvTest(name).run();) } #endif + #ifdef DHRYSTONE Dhrystone("dhrystoneO3",true,true).run(1e6); Dhrystone("dhrystoneO3M",true,true).run(0.8e6); Dhrystone("dhrystoneO3M",false,false).run(0.8e6); // Dhrystone("dhrystoneO3ML",false,false).run(8e6); // Dhrystone("dhrystoneO3MLL",false,false).run(80e6); + #endif } uint64_t duration = timer_end(startedAt); diff --git a/src/test/cpp/testA/makefile b/src/test/cpp/testA/makefile index 1023702..f4b4318 100644 --- a/src/test/cpp/testA/makefile +++ b/src/test/cpp/testA/makefile @@ -1,4 +1,11 @@ TRACE=no +DHRYSTONE=yes +REDO=5 + +ADDCFLAGS += -CFLAGS -DREDO=${REDO} +ifeq ($(DHRYSTONE),yes) + ADDCFLAGS += -CFLAGS -DDHRYSTONE +endif ifeq ($(TRACE),yes) VERILATOR_ARGS += --trace