diff --git a/src/main/scala/vexriscv/VexRiscv.scala b/src/main/scala/vexriscv/VexRiscv.scala index 7077096..6698586 100644 --- a/src/main/scala/vexriscv/VexRiscv.scala +++ b/src/main/scala/vexriscv/VexRiscv.scala @@ -6,14 +6,20 @@ import spinal.core._ import scala.collection.mutable.ArrayBuffer object VexRiscvConfig{ - def apply(plugins : Seq[Plugin[VexRiscv]] = ArrayBuffer()) : VexRiscvConfig = { + def apply(withMemoryStage : Boolean, withWriteBackStage : Boolean, plugins : Seq[Plugin[VexRiscv]]): VexRiscvConfig = { val config = VexRiscvConfig() config.plugins ++= plugins + config.withMemoryStage = withMemoryStage + config.withWriteBackStage = withWriteBackStage config } + + def apply(plugins : Seq[Plugin[VexRiscv]] = ArrayBuffer()) : VexRiscvConfig = apply(true,true,plugins) } case class VexRiscvConfig(){ + var withMemoryStage = true + var withWriteBackStage = true val plugins = ArrayBuffer[Plugin[VexRiscv]]() //Default Stageables @@ -76,8 +82,14 @@ class VexRiscv(val config : VexRiscvConfig) extends Component with Pipeline{ type T = VexRiscv import config._ - stages ++= List.fill(4)(new Stage()) - val /*prefetch :: fetch :: */decode :: execute :: memory :: writeBack :: Nil = stages.toList + stages ++= List.fill(2 + (if(withMemoryStage) 1 else 0) + (if(withWriteBackStage) 1 else 0))(new Stage()) + val decode = stages(0) + val execute = stages(1) + val memory = ifGen(withMemoryStage) (stages(2)) + val writeBack = ifGen(withWriteBackStage) (stages(3)) + + def stagesFromExecute = stages.dropWhile(_ != execute) + plugins ++= config.plugins //regression usage @@ -86,12 +98,16 @@ class VexRiscv(val config : VexRiscvConfig) extends Component with Pipeline{ decode.arbitration.isValid.addAttribute(Verilator.public) decode.arbitration.flushAll.addAttribute(Verilator.public) decode.arbitration.haltItself.addAttribute(Verilator.public) - writeBack.input(config.INSTRUCTION) keep() addAttribute(Verilator.public) - writeBack.input(config.PC) keep() addAttribute(Verilator.public) - writeBack.arbitration.isValid keep() addAttribute(Verilator.public) - writeBack.arbitration.isFiring keep() addAttribute(Verilator.public) + if(withWriteBackStage) { + writeBack.input(config.INSTRUCTION) keep() addAttribute (Verilator.public) + writeBack.input(config.PC) keep() addAttribute (Verilator.public) + writeBack.arbitration.isValid keep() addAttribute (Verilator.public) + writeBack.arbitration.isFiring keep() addAttribute (Verilator.public) + } decode.arbitration.removeIt.noBackendCombMerge //Verilator perf - memory.arbitration.removeIt.noBackendCombMerge + if(withMemoryStage){ + memory.arbitration.removeIt.noBackendCombMerge + } execute.arbitration.flushAll.noBackendCombMerge this(RVC_GEN) = false diff --git a/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/src/main/scala/vexriscv/plugin/BranchPlugin.scala index 6cb4649..287e0e2 100644 --- a/src/main/scala/vexriscv/plugin/BranchPlugin.scala +++ b/src/main/scala/vexriscv/plugin/BranchPlugin.scala @@ -150,7 +150,7 @@ class BranchPlugin(earlyBranch : Boolean, decode.output(INSTRUCTION)(12) := False decode.output(INSTRUCTION)(22) := True } - execute.arbitration.haltByOther setWhen(execute.arbitration.isValid && execute.input(IS_FENCEI) && List(memory,writeBack).map(_.arbitration.isValid).orR) + execute.arbitration.haltByOther setWhen(execute.arbitration.isValid && execute.input(IS_FENCEI) && stagesFromExecute.tail.map(_.arbitration.isValid).asBits.orR) } } diff --git a/src/main/scala/vexriscv/plugin/CsrPlugin.scala b/src/main/scala/vexriscv/plugin/CsrPlugin.scala index c003c01..8bd32ae 100644 --- a/src/main/scala/vexriscv/plugin/CsrPlugin.scala +++ b/src/main/scala/vexriscv/plugin/CsrPlugin.scala @@ -317,7 +317,7 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception if(ebreakGen) decoderService.add(EBREAK, defaultEnv ++ List(ENV_CTRL -> EnvCtrlEnum.EBREAK, HAS_SIDE_EFFECT -> True)) val pcManagerService = pipeline.service(classOf[JumpService]) - jumpInterface = pcManagerService.createJumpInterface(pipeline.writeBack) + jumpInterface = pcManagerService.createJumpInterface(pipeline.stages.last) jumpInterface.valid := False jumpInterface.payload.assignDontCare() @@ -489,10 +489,13 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception import machineCsr._ import supervisorCsr._ + val lastStage = pipeline.stages.last + val beforeLastStage = pipeline.stages(pipeline.stages.size-2) + val stagesFromExecute = pipeline.stages.dropWhile(_ != execute) //Manage counters mcycle := mcycle + 1 - when(writeBack.arbitration.isFiring) { + when(lastStage.arbitration.isFiring) { minstret := minstret + 1 } @@ -541,7 +544,7 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception if(medelegAccess.canWrite) exceptionDelegators += DelegatorModel(medeleg,3, 1) - val mepcCaptureStage = if(exceptionPortsInfos.nonEmpty) writeBack else decode + val mepcCaptureStage = if(exceptionPortsInfos.nonEmpty) lastStage else decode //Aggregate all exception port and remove required instructions @@ -626,7 +629,7 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception interrupt.clearWhen(!allowInterrupts) val exception = if(exceptionPortCtrl != null) exceptionPortCtrl.exceptionValids.last && allowException else False - val writeBackWasWfi = if(wfiGenAsWait) RegNext(writeBack.arbitration.isFiring && writeBack.input(ENV_CTRL) === EnvCtrlEnum.WFI) init(False) else False + val lastStageWasWfi = if(wfiGenAsWait) RegNext(lastStage.arbitration.isFiring && lastStage.input(ENV_CTRL) === EnvCtrlEnum.WFI) init(False) else False @@ -636,7 +639,7 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception decode.arbitration.haltByOther := True } - val done = !List(execute, memory, writeBack).map(_.arbitration.isValid).orR && fetcher.pcValid(mepcCaptureStage) + val done = !stagesFromExecute.map(_.arbitration.isValid).orR && fetcher.pcValid(mepcCaptureStage) if(exceptionPortCtrl != null) done.clearWhen(exceptionPortCtrl.exceptionValidsRegs.tail.orR) } @@ -672,7 +675,7 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception when(hadException || interruptJump){ jumpInterface.valid := True jumpInterface.payload := (if(!mtvecModeGen) mtvec.base @@ "00" else (mtvec.mode === 0 || hadException) ? (mtvec.base @@ "00") | ((mtvec.base + trapCause) @@ "00") ) - memory.arbitration.flushAll := True + beforeLastStage.arbitration.flushAll := True switch(targetPrivilege){ if(supervisorGen) is(1) { @@ -699,14 +702,14 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception } } - writeBack plug new Area{ - import writeBack._ - def previousStage = memory + lastStage plug new Area{ + import lastStage._ + //Manage MRET / SRET instructions when(arbitration.isValid && input(ENV_CTRL) === EnvCtrlEnum.XRET) { jumpInterface.payload := mepc jumpInterface.valid := True - previousStage.arbitration.flushAll := True + beforeLastStage.arbitration.flushAll := True switch(input(INSTRUCTION)(29 downto 28)){ is(3){ mstatus.MIE := mstatus.MPIE @@ -750,12 +753,12 @@ class CsrPlugin(config: CsrPluginConfig) extends Plugin[VexRiscv] with Exception } } - decode.arbitration.haltByOther setWhen(List(execute,memory).map(s => s.arbitration.isValid && s.input(ENV_CTRL) === EnvCtrlEnum.XRET).orR) + decode.arbitration.haltByOther setWhen(stagesFromExecute.dropRight(1).map(s => s.arbitration.isValid && s.input(ENV_CTRL) === EnvCtrlEnum.XRET).asBits.orR) execute plug new Area { import execute._ def previousStage = decode - val blockedBySideEffects = List(memory, writeBack).map(s => s.arbitration.isValid).orR // && s.input(HAS_SIDE_EFFECT) to improve be less pessimistic + val blockedBySideEffects = stagesFromExecute.tail.map(s => s.arbitration.isValid).asBits().orR // && s.input(HAS_SIDE_EFFECT) to improve be less pessimistic val illegalAccess = arbitration.isValid && input(IS_CSR) val illegalInstruction = False diff --git a/src/main/scala/vexriscv/plugin/DBusSimplePlugin.scala b/src/main/scala/vexriscv/plugin/DBusSimplePlugin.scala index a388287..fd8bc3d 100644 --- a/src/main/scala/vexriscv/plugin/DBusSimplePlugin.scala +++ b/src/main/scala/vexriscv/plugin/DBusSimplePlugin.scala @@ -317,7 +317,7 @@ class DBusSimplePlugin(catchAddressMisaligned : Boolean = false, } //Reformat read responses, REGFILE_WRITE_DATA overriding - val injectionStage = if(earlyInjection) memory else writeBack + val injectionStage = if(earlyInjection) memory else stages.last injectionStage plug new Area { import injectionStage._ @@ -340,7 +340,7 @@ class DBusSimplePlugin(catchAddressMisaligned : Boolean = false, output(REGFILE_WRITE_DATA) := (if(!onlyLoadWords) rspFormated else input(MEMORY_READ_DATA)) } - if(!earlyInjection && !emitCmdInMemoryStage) + if(!earlyInjection && !emitCmdInMemoryStage && config.withWriteBackStage) assert(!(arbitration.isValid && input(MEMORY_ENABLE) && !input(INSTRUCTION)(5) && arbitration.isStuck),"DBusSimplePlugin doesn't allow writeback stage stall when read happend") //formal diff --git a/src/main/scala/vexriscv/plugin/Fetcher.scala b/src/main/scala/vexriscv/plugin/Fetcher.scala index 5aac4e7..27a22e0 100644 --- a/src/main/scala/vexriscv/plugin/Fetcher.scala +++ b/src/main/scala/vexriscv/plugin/Fetcher.scala @@ -31,7 +31,7 @@ abstract class IBusFetcherImpl(val resetVector : BigInt, assert(!(compressedGen && !decodePcGen)) var fetcherHalt : Bool = null var fetcherflushIt : Bool = null - lazy val pcValids = Vec(Bool, 4) + var pcValids : Vec[Bool] = null def pcValid(stage : Stage) = pcValids(pipeline.indexOf(stage)) var incomingInstruction : Bool = null override def incoming() = incomingInstruction @@ -50,6 +50,7 @@ abstract class IBusFetcherImpl(val resetVector : BigInt, case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int) val jumpInfos = ArrayBuffer[JumpInfo]() override def createJumpInterface(stage: Stage, priority : Int = 0): Flow[UInt] = { + assert(stage != null) val interface = Flow(UInt(32 bits)) jumpInfos += JumpInfo(interface,stage, priority) interface @@ -78,6 +79,8 @@ abstract class IBusFetcherImpl(val resetVector : BigInt, } } } + + pcValids = Vec(Bool, pipeline.stages.size) } @@ -320,12 +323,13 @@ abstract class IBusFetcherImpl(val resetVector : BigInt, }).tail } + val stagesFromExecute = stages.dropWhile(_ != execute).toList val nextPcCalc = if (decodePcGen) new Area{ - val valids = pcUpdatedGen(True, False :: List(execute, memory, writeBack).map(_.arbitration.isStuck), true) - pcValids := Vec(valids.takeRight(4)) + val valids = pcUpdatedGen(True, False :: stagesFromExecute.map(_.arbitration.isStuck), true) + pcValids := Vec(valids.takeRight(stages.size)) } else new Area{ - val valids = pcUpdatedGen(True, iBusRsp.stages.tail.map(!_.input.ready) ++ (if (injectorStage) List(!decodeInput.ready) else Nil) ++ List(execute, memory, writeBack).map(_.arbitration.isStuck), false) - pcValids := Vec(valids.takeRight(4)) + val valids = pcUpdatedGen(True, iBusRsp.stages.tail.map(!_.input.ready) ++ (if (injectorStage) List(!decodeInput.ready) else Nil) ++ stagesFromExecute.map(_.arbitration.isStuck), false) + pcValids := Vec(valids.takeRight(stages.size)) } val decodeRemoved = RegInit(False) setWhen(decode.arbitration.isRemoved) clearWhen(flush) //!decode.arbitration.isStuck || decode.arbitration.isFlushed diff --git a/src/main/scala/vexriscv/plugin/HazardPessimisticPlugin.scala b/src/main/scala/vexriscv/plugin/HazardPessimisticPlugin.scala index 9fb11a6..a16324f 100644 --- a/src/main/scala/vexriscv/plugin/HazardPessimisticPlugin.scala +++ b/src/main/scala/vexriscv/plugin/HazardPessimisticPlugin.scala @@ -18,7 +18,7 @@ class HazardPessimisticPlugin() extends Plugin[VexRiscv] { import pipeline._ import pipeline.config._ - val writesInPipeline = List(execute,memory,writeBack).map(s => s.arbitration.isValid && s.input(REGFILE_WRITE_VALID)) :+ RegNext(writeBack.arbitration.isValid && writeBack.input(REGFILE_WRITE_VALID)) + val writesInPipeline = stages.dropWhile(_ != execute).map(s => s.arbitration.isValid && s.input(REGFILE_WRITE_VALID)) :+ RegNext(stages.last.arbitration.isValid && stages.last.input(REGFILE_WRITE_VALID)) decode.arbitration.haltItself.setWhen(decode.arbitration.isValid && writesInPipeline.orR) } } diff --git a/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala b/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala index b18ae8e..51d059d 100644 --- a/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala +++ b/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala @@ -59,9 +59,9 @@ class HazardSimplePlugin(bypassExecute : Boolean = false, val address = Bits(5 bits) val data = Bits(32 bits) })) - writeBackWrites.valid := writeBack.output(REGFILE_WRITE_VALID) && writeBack.arbitration.isFiring - writeBackWrites.address := writeBack.output(INSTRUCTION)(rdRange) - writeBackWrites.data := writeBack.output(REGFILE_WRITE_DATA) + writeBackWrites.valid := stages.last.output(REGFILE_WRITE_VALID) && stages.last.arbitration.isFiring + writeBackWrites.address := stages.last.output(INSTRUCTION)(rdRange) + writeBackWrites.data := stages.last.output(REGFILE_WRITE_DATA) val writeBackBuffer = writeBackWrites.stage() val addr0Match = if(pessimisticAddressMatch) True else writeBackBuffer.address === decode.input(INSTRUCTION)(rs1Range) @@ -84,9 +84,9 @@ class HazardSimplePlugin(bypassExecute : Boolean = false, } } - trackHazardWithStage(writeBack,bypassWriteBack,null) - trackHazardWithStage(memory ,bypassMemory ,BYPASSABLE_MEMORY_STAGE) - trackHazardWithStage(execute ,bypassExecute ,BYPASSABLE_EXECUTE_STAGE) + if(withWriteBackStage) trackHazardWithStage(writeBack,bypassWriteBack,null) + if(withMemoryStage) trackHazardWithStage(memory ,bypassMemory ,BYPASSABLE_MEMORY_STAGE) + trackHazardWithStage(execute ,bypassExecute , if(stages.last == execute) null else BYPASSABLE_EXECUTE_STAGE) if(!pessimisticUseSrc) { diff --git a/src/main/scala/vexriscv/plugin/RegFilePlugin.scala b/src/main/scala/vexriscv/plugin/RegFilePlugin.scala index 568b100..74d0050 100644 --- a/src/main/scala/vexriscv/plugin/RegFilePlugin.scala +++ b/src/main/scala/vexriscv/plugin/RegFilePlugin.scala @@ -67,7 +67,7 @@ class RegFilePlugin(regFileReadyKind : RegFileReadKind, } //Write register file - val writeStage = if(writeRfInMemoryStage) memory else writeBack + val writeStage = if(writeRfInMemoryStage) memory else stages.last writeStage plug new Area { import writeStage._