diff --git a/src/main/scala/vexriscv/TestsWorkspace.scala b/src/main/scala/vexriscv/TestsWorkspace.scala index 51b0084..d7bab05 100644 --- a/src/main/scala/vexriscv/TestsWorkspace.scala +++ b/src/main/scala/vexriscv/TestsWorkspace.scala @@ -41,7 +41,8 @@ object TestsWorkspace { // ), new IBusCachedPlugin( resetVector = 0x80000000l, - compressedGen = true, + compressedGen = false, + prediction = DYNAMIC_TARGET, config = InstructionCacheConfig( cacheSize = 1024*16, bytePerLine = 32, diff --git a/src/main/scala/vexriscv/demo/Briey.scala b/src/main/scala/vexriscv/demo/Briey.scala index bb0422b..1fd872d 100644 --- a/src/main/scala/vexriscv/demo/Briey.scala +++ b/src/main/scala/vexriscv/demo/Briey.scala @@ -53,6 +53,8 @@ object BrieyConfig{ // catchAccessFault = true // ), new IBusCachedPlugin( + resetVector = 0x80000000l, + prediction = STATIC, config = InstructionCacheConfig( cacheSize = 4096, bytePerLine =32, @@ -64,7 +66,8 @@ object BrieyConfig{ catchAccessFault = true, catchMemoryTranslationMiss = true, asyncTagMemory = false, - twoCycleRam = true + twoCycleRam = true, + twoCycleCache = true ) // askMemoryTranslation = true, // memoryTranslatorPortConfig = MemoryTranslatorPortConfig( @@ -123,7 +126,7 @@ object BrieyConfig{ new BranchPlugin( earlyBranch = false, catchAddressMisaligned = true, - prediction = STATIC + prediction = NONE ), new CsrPlugin( config = CsrPluginConfig( diff --git a/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/src/main/scala/vexriscv/plugin/BranchPlugin.scala index 060bc2d..3b84bc1 100644 --- a/src/main/scala/vexriscv/plugin/BranchPlugin.scala +++ b/src/main/scala/vexriscv/plugin/BranchPlugin.scala @@ -32,13 +32,13 @@ case class FetchPredictionCmd() extends Bundle{ val hadBranch = Bool val targetPc = UInt(32 bits) } -case class FetchPredictionRsp(stage : Stage) extends Bundle{ +case class FetchPredictionRsp() extends Bundle{ val wasRight = Bool - val targetPc = UInt(32 bits) + val finalPc = UInt(32 bits) } case class FetchPredictionBus(stage : Stage) extends Bundle { val cmd = FetchPredictionCmd() - val rsp = FetchPredictionRsp(stage) + val rsp = FetchPredictionRsp() } @@ -68,9 +68,14 @@ class BranchPlugin(earlyBranch : Boolean, var decodePrediction : DecodePredictionBus = null + var fetchPrediction : FetchPredictionBus = null - override def askFetchPrediction() = ??? + override def askFetchPrediction() = { + fetchPrediction = FetchPredictionBus(branchStage) + fetchPrediction + } + override def askDecodePrediction() = { decodePrediction = DecodePredictionBus(branchStage) decodePrediction @@ -131,9 +136,10 @@ class BranchPlugin(earlyBranch : Boolean, } } - override def build(pipeline: VexRiscv): Unit = (decodePrediction) match { - case null => buildWithoutPrediction(pipeline) - case _ => buildWithPrediction(pipeline) + override def build(pipeline: VexRiscv): Unit = (fetchPrediction,decodePrediction) match { + case (null, null) => buildWithoutPrediction(pipeline) + case (_ , null) => buildFetchPrediction(pipeline) + case (null, _) => buildDecodePrediction(pipeline) // case `DYNAMIC` => buildWithPrediction(pipeline) // case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline) } @@ -192,7 +198,7 @@ class BranchPlugin(earlyBranch : Boolean, } - def buildWithPrediction(pipeline: VexRiscv): Unit = { + def buildDecodePrediction(pipeline: VexRiscv): Unit = { // case class BranchPredictorLine() extends Bundle{ // val history = SInt(historyWidth bits) // } @@ -298,9 +304,9 @@ class BranchPlugin(earlyBranch : Boolean, } if(catchAddressMisaligned) { - branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) jumpInterface.payload(0 downto 0) =/= 0 else jumpInterface.payload(1 downto 0) =/= 0) + branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) input(BRANCH_CALC)(0 downto 0) =/= 0 else input(BRANCH_CALC)(1 downto 0) =/= 0) branchExceptionPort.code := 0 - branchExceptionPort.badAddr := jumpInterface.payload + branchExceptionPort.badAddr := input(BRANCH_CALC) } } @@ -321,139 +327,64 @@ class BranchPlugin(earlyBranch : Boolean, -// def buildDynamicTargetPrediction(pipeline: VexRiscv): Unit = { -// import pipeline._ -// import pipeline.config._ -// -// case class BranchPredictorLine() extends Bundle{ -// val source = Bits(31 - historyRamSizeLog2 bits) -// val confidence = UInt(2 bits) -// val target = UInt(32 bits) -// } -// -// object PREDICTION_WRITE_HAZARD extends Stageable(Bool) -// object PREDICTION extends Stageable(BranchPredictorLine()) -// object PREDICTION_HIT extends Stageable(Bool) -// -// val history = Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) -// val historyWrite = history.writePort -// -// -// fetch plug new Area{ -// import fetch._ -// val line = history.readSync((prefetch.output(PC) >> 2).resized, prefetch.arbitration.isFiring) -//// val line = history.readAsync((fetch.output(PC) >> 2).resized) -// val hit = line.source === (input(PC).asBits >> 1 + historyRamSizeLog2) -// -// //Avoid write to read hazard -// val historyWriteLast = RegNext(historyWrite) -// val hazard = historyWriteLast.valid && historyWriteLast.address === (output(PC) >> 2).resized -// insert(PREDICTION_WRITE_HAZARD) := hazard -// -// predictionJumpInterface.valid := line.confidence.msb && hit && arbitration.isFiring && !hazard -// predictionJumpInterface.payload := line.target -// -// insert(PREDICTION) := line -// insert(PREDICTION_HIT) := hit -// } -// -// -// -// //Do branch calculations (conditions + target PC) -// execute plug new Area { -// import execute._ -// -// val less = input(SRC_LESS) -// val eq = input(SRC1) === input(SRC2) -// -// insert(BRANCH_DO) := input(BRANCH_CTRL).mux( -// BranchCtrlEnum.INC -> False, -// BranchCtrlEnum.JAL -> True, -// BranchCtrlEnum.JALR -> True, -// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( -// B"000" -> eq , -// B"001" -> !eq , -// M"1-1" -> !less, -// default -> less -// ) -// ) -// -// val imm = IMM(input(INSTRUCTION)) -// val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC) -// val branch_src2 = input(BRANCH_CTRL).mux( -// BranchCtrlEnum.JAL -> imm.j_sext, -// BranchCtrlEnum.JALR -> imm.i_sext, -// default -> imm.b_sext -// ).asUInt -// -// val branchAdder = branch_src1 + branch_src2 -// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0)) -// } -// -// //Apply branchs (JAL,JALR, Bxx) -// val branchStage = if(earlyBranch) execute else memory -// branchStage plug new Area { -// import branchStage._ -// -// val predictionMissmatch = input(PREDICTION).confidence.msb =/= input(BRANCH_DO) || (input(BRANCH_DO) && input(PREDICTION).target =/= input(BRANCH_CALC)) -// -// historyWrite.valid := False -// historyWrite.address := (branchStage.output(PC) >> 2).resized -// historyWrite.data.source := input(PC).asBits >> 1 + historyRamSizeLog2 -// historyWrite.data.target := input(BRANCH_CALC) -// -// jumpInterface.valid := False -// jumpInterface.payload := input(BRANCH_CALC) -// -// -// when(!input(BRANCH_DO)){ -// historyWrite.valid := arbitration.isFiring && input(PREDICTION_HIT) -// historyWrite.data.confidence := input(PREDICTION).confidence - (input(PREDICTION).confidence =/= 0).asUInt -// historyWrite.data.target := input(BRANCH_CALC) -// -// -// jumpInterface.valid := input(PREDICTION_HIT) && input(PREDICTION).confidence.msb && !input(PREDICTION_WRITE_HAZARD) && arbitration.isFiring -// jumpInterface.payload := input(PC) + 4 -// } otherwise{ -// when(!input(PREDICTION_HIT) || input(PREDICTION_WRITE_HAZARD)){ -// jumpInterface.valid := arbitration.isFiring -// historyWrite.valid := arbitration.isFiring -// historyWrite.data.confidence := "10" -// } otherwise { -// historyWrite.valid := arbitration.isFiring -// historyWrite.data.confidence := input(PREDICTION).confidence + (input(PREDICTION).confidence =/= 3).asUInt -// when(!input(PREDICTION).confidence.msb || input(PREDICTION).target =/= input(BRANCH_CALC)){ -// jumpInterface.valid := arbitration.isFiring -// } -// } -// } -// -// //Prevent rewriting an history which already had hazard -// historyWrite.valid clearWhen(input(PREDICTION_WRITE_HAZARD)) -// -// -// -// when(jumpInterface.valid) { -// stages(indexOf(branchStage) - 1).arbitration.flushAll := True -// } -// -// if(catchAddressMisaligned) { -// branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0 -// branchExceptionPort.code := 0 -// branchExceptionPort.badAddr := jumpInterface.payload -// } -// } -// -// //Init History -// val historyInit = pipeline plug new Area{ -// val counter = Reg(UInt(historyRamSizeLog2 + 1 bits)) init(0) -// when(!counter.msb){ -// prefetch.arbitration.haltByOther := True -// historyWrite.valid := True -// historyWrite.address := counter.resized -// historyWrite.data.confidence := 0 -// counter := counter + 1 -// } -// } -// } + def buildFetchPrediction(pipeline: VexRiscv): Unit = { + import pipeline._ + import pipeline.config._ + + + //Do branch calculations (conditions + target PC) + execute plug new Area { + import execute._ + + val less = input(SRC_LESS) + val eq = input(SRC1) === input(SRC2) + + insert(BRANCH_DO) := input(BRANCH_CTRL).mux( + BranchCtrlEnum.INC -> False, + BranchCtrlEnum.JAL -> True, + BranchCtrlEnum.JALR -> True, + BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( + B"000" -> eq , + B"001" -> !eq , + M"1-1" -> !less, + default -> less + ) + ) + + val imm = IMM(input(INSTRUCTION)) + val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC) + val branch_src2 = input(BRANCH_CTRL).mux( + BranchCtrlEnum.JAL -> imm.j_sext, + BranchCtrlEnum.JALR -> imm.i_sext, + default -> imm.b_sext + ).asUInt + + val branchAdder = branch_src1 + branch_src2 + insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0)) + } + + //Apply branchs (JAL,JALR, Bxx) + val branchStage = if(earlyBranch) execute else memory + branchStage plug new Area { + import branchStage._ + + val predictionMissmatch = fetchPrediction.cmd.hadBranch =/= input(BRANCH_DO) || (input(BRANCH_DO) && fetchPrediction.cmd.targetPc =/= input(BRANCH_CALC)) + fetchPrediction.rsp.wasRight := ! predictionMissmatch + fetchPrediction.rsp.finalPc := input(BRANCH_CALC) + + jumpInterface.valid := arbitration.isFiring && predictionMissmatch //Probably just isValid instead of isFiring is better + jumpInterface.payload := (input(BRANCH_DO) ? input(BRANCH_CALC) | input(PC) + (if(pipeline(RVC_GEN)) ((input(IS_RVC)) ? U(2) | U(4)) else 4)) + + + when(jumpInterface.valid) { + stages(indexOf(branchStage) - 1).arbitration.flushAll := True + } + + if(catchAddressMisaligned) { + branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) input(BRANCH_CALC)(0 downto 0) =/= 0 else input(BRANCH_CALC)(1 downto 0) =/= 0) + branchExceptionPort.code := 0 + branchExceptionPort.badAddr := input(BRANCH_CALC) + } + } + } } \ No newline at end of file diff --git a/src/main/scala/vexriscv/plugin/Fetcher.scala b/src/main/scala/vexriscv/plugin/Fetcher.scala index 89c4f6e..dcbd3b3 100644 --- a/src/main/scala/vexriscv/plugin/Fetcher.scala +++ b/src/main/scala/vexriscv/plugin/Fetcher.scala @@ -22,6 +22,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, var prefetchExceptionPort : Flow[ExceptionCause] = null var decodePrediction : DecodePredictionBus = null + var fetchPrediction : FetchPredictionBus = null assert(cmdToRspStageCount >= 1) assert(!(cmdToRspStageCount == 1 && !injectorStage)) assert(!(compressedGen && !decodePcGen)) @@ -68,6 +69,9 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, predictionJumpInterface = createJumpInterface(pipeline.decode) decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction() } + case DYNAMIC_TARGET => { + fetchPrediction = pipeline.service(classOf[PredictionInterface]).askFetchPrediction() + } } } @@ -98,6 +102,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, class PcFetch extends Area{ val preOutput = Stream(UInt(32 bits)) val output = preOutput.haltWhen(fetcherHalt) + val predictionPcLoad = ifGen(prediction == DYNAMIC_TARGET) (Flow(UInt(32 bits))) } val fetchPc = if(relaxedPcCalculation) new PcFetch { @@ -117,6 +122,11 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, } //application of the selected jump request + if(predictionPcLoad != null) { + when(predictionPcLoad.valid) { + pcReg := predictionPcLoad.payload + } + } when(jump.pcLoad.valid) { pcReg := jump.pcLoad.payload } @@ -131,6 +141,13 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, val pc = pcReg + (inc ## B"00").asUInt val samplePcNext = False + if(predictionPcLoad != null) { + when(predictionPcLoad.valid) { + inc := False + samplePcNext := True + pc := predictionPcLoad.payload + } + } when(jump.pcLoad.valid) { inc := False samplePcNext := True @@ -252,26 +269,26 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, def condApply[T](that : T, cond : Boolean)(func : (T) => T) = if(cond)func(that) else that val injector = new Area { - val inputBeforeHalt = condApply(if (decodePcGen) decompressor.output else iBusRsp.output, injectorReadyCutGen)(_.s2mPipe(flush)) + val inputBeforeStage = condApply(if (decodePcGen) decompressor.output else iBusRsp.output, injectorReadyCutGen)(_.s2mPipe(flush)) if (injectorReadyCutGen) { - iBusRsp.readyForError.clearWhen(inputBeforeHalt.valid) - incomingInstruction setWhen (inputBeforeHalt.valid) + iBusRsp.readyForError.clearWhen(inputBeforeStage.valid) + incomingInstruction setWhen (inputBeforeStage.valid) } val decodeInput = (if (injectorStage) { - val decodeInput = inputBeforeHalt.m2sPipeWithFlush(killLastStage, collapsBubble = false) - decode.insert(INSTRUCTION_ANTICIPATED) := Mux(decode.arbitration.isStuck, decode.input(INSTRUCTION), inputBeforeHalt.rsp.inst) + val decodeInput = inputBeforeStage.m2sPipeWithFlush(killLastStage, collapsBubble = false) + decode.insert(INSTRUCTION_ANTICIPATED) := Mux(decode.arbitration.isStuck, decode.input(INSTRUCTION), inputBeforeStage.rsp.inst) iBusRsp.readyForError.clearWhen(decodeInput.valid) incomingInstruction setWhen (decodeInput.valid) decodeInput } else { - inputBeforeHalt + inputBeforeStage }) if (decodePcGen) { decodeNextPcValid := True decodeNextPc := decodePc.pcReg } else { - val lastStageStream = if (injectorStage) inputBeforeHalt + val lastStageStream = if (injectorStage) inputBeforeStage else if (cmdToRspStageCount > 1) iBusRsp.inputPipeline(cmdToRspStageCount - 2) else throw new Exception("Fetch should at least have two stages") @@ -354,7 +371,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, } } - prediction match { + val predictor = prediction match { case NONE => case STATIC | DYNAMIC => { def historyWidth = 2 @@ -393,6 +410,67 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean, // predictionExceptionPort.badAddr := predictionJumpInterface.payload } } + case DYNAMIC_TARGET => new Area{ + val historyRamSizeLog2 : Int = 10 + case class BranchPredictorLine() extends Bundle{ + val source = Bits(31 - historyRamSizeLog2 bits) + val branchWish = UInt(2 bits) + val target = UInt(32 bits) + } + + val history = Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) + val historyWrite = history.writePort + + val line = history.readSync((fetchPc.output.payload >> 2).resized, iBusRsp.inputPipeline(0).ready) + val hit = line.source === (iBusRsp.inputPipeline(0).payload.asBits >> 1 + historyRamSizeLog2) + + //Avoid write to read hazard + val historyWriteLast = RegNextWhen(historyWrite, iBusRsp.inputPipeline(0).ready) + val hazard = historyWriteLast.valid && historyWriteLast.address === (iBusRsp.inputPipeline(0).payload >> 2).resized + fetchPc.predictionPcLoad.valid := line.branchWish.msb && hit && iBusRsp.inputPipeline(0).fire && !flush && !hazard + fetchPc.predictionPcLoad.payload := line.target + + case class PredictionResult() extends Bundle{ + val hazard = Bool + val hit = Bool + val line = BranchPredictorLine() + } + + val fetchContext = PredictionResult() + fetchContext.hazard := hazard + fetchContext.hit := hit + fetchContext.line := line //RegNextWhen(e._1, e._2.) + val iBusRspContext = iBusRsp.inputPipeline.tail.foldLeft(fetchContext)((data,stream) => RegNextWhen(data, stream.ready)) + val injectorContext = Delay(iBusRspContext,cycleCount=if(injectorStage) 1 else 0, when=injector.decodeInput.ready) + + object PREDICTION_CONTEXT extends Stageable(PredictionResult()) + pipeline.decode.insert(PREDICTION_CONTEXT) := injectorContext + val branchStage = fetchPrediction.stage + val branchContext = branchStage.input(PREDICTION_CONTEXT) + fetchPrediction.cmd.hadBranch := branchContext.hit && !branchContext.hazard && branchContext.line.branchWish.msb + fetchPrediction.cmd.targetPc := branchContext.line.target + + + historyWrite.valid := False + historyWrite.address := (branchStage.input(PC) >> 2).resized + historyWrite.data.source := branchStage.input(PC).asBits >> 1 + historyRamSizeLog2 + historyWrite.data.target := fetchPrediction.rsp.finalPc + + when(fetchPrediction.rsp.wasRight) { + historyWrite.valid := branchContext.hit + historyWrite.data.branchWish := branchContext.line.branchWish + (branchContext.line.branchWish === 2).asUInt - (branchContext.line.branchWish === 1).asUInt + } otherwise { + when(branchContext.hit) { + historyWrite.valid := True + historyWrite.data.branchWish := branchContext.line.branchWish - (branchContext.line.branchWish.msb).asUInt + (!branchContext.line.branchWish.msb).asUInt + } otherwise { + historyWrite.valid := True + historyWrite.data.branchWish := "10" + } + } + + historyWrite.valid clearWhen(branchContext.hazard || !branchStage.arbitration.isFiring) + } } } } \ No newline at end of file