diff --git a/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/src/main/scala/vexriscv/plugin/BranchPlugin.scala index fc9dde3..1011937 100644 --- a/src/main/scala/vexriscv/plugin/BranchPlugin.scala +++ b/src/main/scala/vexriscv/plugin/BranchPlugin.scala @@ -11,25 +11,71 @@ object STATIC extends BranchPrediction object DYNAMIC extends BranchPrediction object DYNAMIC_TARGET extends BranchPrediction +object BranchCtrlEnum extends SpinalEnum(binarySequential){ + val INC,B,JAL,JALR = newElement() +} +object BRANCH_CTRL extends Stageable(BranchCtrlEnum()) + + +case class DecodePredictionCmd() extends Bundle { + val hadBranch = Bool +} +case class DecodePredictionRsp(stage : Stage) extends Bundle { + val wasWrong = Bool +} +case class DecodePredictionBus(stage : Stage) extends Bundle { + val cmd = DecodePredictionCmd() + val rsp = DecodePredictionRsp(stage) +} + +case class FetchPredictionCmd() extends Bundle{ + val hadBranch = Bool + val targetPc = UInt(32 bits) +} +case class FetchPredictionRsp(stage : Stage) extends Bundle{ + val wasRight = Bool + val targetPc = UInt(32 bits) +} +case class FetchPredictionBus(stage : Stage) extends Bundle { + val cmd = FetchPredictionCmd() + val rsp = FetchPredictionRsp(stage) +} + + +trait PredictionInterface{ + def askFetchPrediction() : FetchPredictionBus + def askDecodePrediction() : DecodePredictionBus +} + class BranchPlugin(earlyBranch : Boolean, catchAddressMisaligned : Boolean, prediction : BranchPrediction, historyRamSizeLog2 : Int = 10, - historyWidth : Int = 2) extends Plugin[VexRiscv]{ - object BranchCtrlEnum extends SpinalEnum(binarySequential){ - val INC,B,JAL,JALR = newElement() - } + historyWidth : Int = 2) extends Plugin[VexRiscv] with PredictionInterface{ + + + lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory - object BRANCH_CTRL extends Stageable(BranchCtrlEnum()) object BRANCH_CALC extends Stageable(UInt(32 bits)) object BRANCH_DO extends Stageable(Bool) object BRANCH_COND_RESULT extends Stageable(Bool) +// object PREDICTION_HAD_BRANCHED extends Stageable(Bool) var jumpInterface : Flow[UInt] = null var predictionJumpInterface : Flow[UInt] = null var predictionExceptionPort : Flow[ExceptionCause] = null var branchExceptionPort : Flow[ExceptionCause] = null + + var decodePrediction : DecodePredictionBus = null + + + override def askFetchPrediction() = ??? + override def askDecodePrediction() = { + decodePrediction = DecodePredictionBus(branchStage) + decodePrediction + } + override def setup(pipeline: VexRiscv): Unit = { import Riscv._ import pipeline.config._ @@ -66,7 +112,7 @@ class BranchPlugin(earlyBranch : Boolean, )) val pcManagerService = pipeline.service(classOf[JumpService]) - jumpInterface = pcManagerService.createJumpInterface(if(earlyBranch) pipeline.execute else pipeline.memory) + jumpInterface = pcManagerService.createJumpInterface(branchStage) prediction match { case NONE => @@ -76,7 +122,7 @@ class BranchPlugin(earlyBranch : Boolean, if (catchAddressMisaligned) { val exceptionService = pipeline.service(classOf[ExceptionService]) - branchExceptionPort = exceptionService.newExceptionPort(if (earlyBranch) pipeline.execute else pipeline.memory) + branchExceptionPort = exceptionService.newExceptionPort(branchStage) prediction match { case NONE => // case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode) @@ -85,9 +131,9 @@ class BranchPlugin(earlyBranch : Boolean, } } - override def build(pipeline: VexRiscv): Unit = prediction match { - case `NONE` => buildWithoutPrediction(pipeline) -// case `STATIC` => buildWithPrediction(pipeline) + override def build(pipeline: VexRiscv): Unit = (decodePrediction) match { + case null => buildWithoutPrediction(pipeline) + case _ => buildWithPrediction(pipeline) // case `DYNAMIC` => buildWithPrediction(pipeline) // case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline) } @@ -128,7 +174,6 @@ class BranchPlugin(earlyBranch : Boolean, } //Apply branchs (JAL,JALR, Bxx) - val branchStage = if(earlyBranch) execute else memory branchStage plug new Area { import branchStage._ jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO) @@ -147,21 +192,21 @@ class BranchPlugin(earlyBranch : Boolean, } -// def buildWithPrediction(pipeline: VexRiscv): Unit = { + def buildWithPrediction(pipeline: VexRiscv): Unit = { // case class BranchPredictorLine() extends Bundle{ // val history = SInt(historyWidth bits) // } -// -// object PREDICTION_HAD_BRANCHED extends Stageable(Bool) + + object PREDICTION_HAD_BRANCHED extends Stageable(Bool) // object HISTORY_LINE extends Stageable(BranchPredictorLine()) -// -// import pipeline._ -// import pipeline.config._ -// + + import pipeline._ + import pipeline.config._ + // val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null // val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null -// -// //Read historyCache + + //Read historyCache // if(prediction == DYNAMIC) fetch plug new Area{ // val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits) // fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers) @@ -172,8 +217,8 @@ class BranchPlugin(earlyBranch : Boolean, //// fetch.insert(HISTORY_LINE) := writePortReg.data //// } // } -// -// //Branch JAL, predict Bxx and branch it + + //Branch JAL, predict Bxx and branch it // decode plug new Area{ // import decode._ // val imm = IMM(input(INSTRUCTION)) @@ -196,65 +241,71 @@ class BranchPlugin(earlyBranch : Boolean, // predictionExceptionPort.badAddr := predictionJumpInterface.payload // } // } -// -// //Do real branch calculation -// execute plug new Area { -// import execute._ -// -// val less = input(SRC_LESS) -// val eq = input(SRC1) === input(SRC2) -// -// insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux( -// BranchCtrlEnum.INC -> False, -// BranchCtrlEnum.JAL -> True, -// BranchCtrlEnum.JALR -> True, -// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( -// B"000" -> eq , -// B"001" -> !eq , -// M"1-1" -> !less, -// default -> less -// ) -// ) -// -// insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT) -// -// //Calculation of the branch target / correction -// val imm = IMM(input(INSTRUCTION)) -// val branch_src1,branch_src2 = UInt(32 bits) -// switch(input(BRANCH_CTRL)){ -// is(BranchCtrlEnum.JALR){ -// branch_src1 := input(RS1).asUInt -// branch_src2 := imm.i_sext.asUInt -// } -// default{ -// branch_src1 := input(PC) -// branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt -// } -// } -// val branchAdder = branch_src1 + branch_src2 -// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0)) -// } -// -// -// // branch JALR or JAL/Bxx prediction miss corrections -// val branchStage = if(earlyBranch) execute else memory -// branchStage plug new Area { -// import branchStage._ -// jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring -// jumpInterface.payload := input(BRANCH_CALC) -// -// when(jumpInterface.valid) { -// stages(indexOf(branchStage) - 1).arbitration.flushAll := True -// } -// -// if(catchAddressMisaligned) { -// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0 -// branchExceptionPort.code := 0 -// branchExceptionPort.badAddr := jumpInterface.payload -// } -// } -// -// //Update historyCache + + decode plug new Area { + import decode._ + insert(PREDICTION_HAD_BRANCHED) := decodePrediction.cmd.hadBranch + } + + //Do real branch calculation + execute plug new Area { + import execute._ + + val less = input(SRC_LESS) + val eq = input(SRC1) === input(SRC2) + + insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux( + BranchCtrlEnum.INC -> False, + BranchCtrlEnum.JAL -> True, + BranchCtrlEnum.JALR -> True, + BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( + B"000" -> eq , + B"001" -> !eq , + M"1-1" -> !less, + default -> less + ) + ) + + insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT) + + //Calculation of the branch target / correction + val imm = IMM(input(INSTRUCTION)) + val branch_src1,branch_src2 = UInt(32 bits) + switch(input(BRANCH_CTRL)){ + is(BranchCtrlEnum.JALR){ + branch_src1 := input(RS1).asUInt + branch_src2 := imm.i_sext.asUInt + } + default{ + branch_src1 := input(PC) + branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt + } + } + val branchAdder = branch_src1 + branch_src2 + insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0)) + } + + + // branch JALR or JAL/Bxx prediction miss corrections + val branchStage = if(earlyBranch) execute else memory + branchStage plug new Area { + import branchStage._ + jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring + jumpInterface.payload := input(BRANCH_CALC) + + when(jumpInterface.valid) { + stages(indexOf(branchStage) - 1).arbitration.flushAll := True + } + + if(catchAddressMisaligned) { + branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0 + branchExceptionPort.code := 0 + branchExceptionPort.badAddr := jumpInterface.payload + } + } + + //Update historyCache + decodePrediction.rsp.wasWrong := jumpInterface.valid // if(prediction == DYNAMIC) branchStage plug new Area { // import branchStage._ // val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1)) @@ -264,7 +315,7 @@ class BranchPlugin(earlyBranch : Boolean, // historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits) // historyCacheWrite.data.history := newHistory.resized // } -// } + } diff --git a/src/main/scala/vexriscv/plugin/IBusSimplePlugin.scala b/src/main/scala/vexriscv/plugin/IBusSimplePlugin.scala index fc78d14..d9d4eea 100644 --- a/src/main/scala/vexriscv/plugin/IBusSimplePlugin.scala +++ b/src/main/scala/vexriscv/plugin/IBusSimplePlugin.scala @@ -5,6 +5,7 @@ import spinal.core._ import spinal.lib._ import spinal.lib.bus.amba4.axi._ import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig} +import vexriscv.Riscv.IMM import scala.collection.mutable.ArrayBuffer @@ -109,12 +110,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean, var prefetchExceptionPort : Flow[ExceptionCause] = null def resetVector = BigInt(0x80000000l) def keepPcPlus4 = false - def decodePcGen = true - def compressedGen = true + def decodePcGen = false + def compressedGen = false def cmdToRspStageCount = 1 def rspStageGen = false - def injectorReadyCutGen = true - def relaxedPcCalculation = true + def injectorReadyCutGen = false + def relaxedPcCalculation = false + def prediction : BranchPrediction = STATIC + var decodePrediction : DecodePredictionBus = null assert(cmdToRspStageCount >= 1) assert(!(compressedGen && !decodePcGen)) lazy val fetcherHalt = False @@ -122,6 +125,8 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean, lazy val decodeNextPc = UInt(32 bits) def nextPc() = (decodeNextPcValid, decodeNextPc) + var predictionJumpInterface : Flow[UInt] = null + override def haltIt(): Unit = fetcherHalt := True case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int) @@ -142,6 +147,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean, } pipeline(RVC_GEN) = compressedGen + + prediction match { + case NONE => + case STATIC | DYNAMIC => { + predictionJumpInterface = createJumpInterface(pipeline.decode) + decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction() + } + } } override def build(pipeline: VexRiscv): Unit = { @@ -365,6 +378,29 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean, decodeExceptionPort.code := 1 decodeExceptionPort.badAddr := decode.input(PC) } + + prediction match { + case `NONE` => + case `STATIC` => { + val imm = IMM(decode.input(INSTRUCTION)) + + val conditionalBranchPrediction = (prediction match { + case `STATIC` => imm.b_sext.msb + //case `DYNAMIC` => input(HISTORY_LINE).history.msb + }) + decodePrediction.cmd.hadBranch := decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL || (decode.input(BRANCH_CTRL) === BranchCtrlEnum.B && conditionalBranchPrediction) + + predictionJumpInterface.valid := decodePrediction.cmd.hadBranch && decode.arbitration.isFiring //TODO OH Doublon de priorité + predictionJumpInterface.payload := decode.input(PC) + ((decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt + + +// if(catchAddressMisaligned) { +// predictionExceptionPort.valid := input(INSTRUCTION_READY) && input(PREDICTION_HAD_BRANCHED) && arbitration.isValid && predictionJumpInterface.payload(1 downto 0) =/= 0 +// predictionExceptionPort.code := 0 +// predictionExceptionPort.badAddr := predictionJumpInterface.payload +// } + } + } } } }