DYNAMIC_TARGET branch prediction back for not compressed ISA (PASS)

This commit is contained in:
Dolu1990 2018-05-21 13:45:08 +02:00
parent 7ffbfab312
commit ff760a0bf0
4 changed files with 169 additions and 156 deletions

View File

@ -41,7 +41,8 @@ object TestsWorkspace {
// ), // ),
new IBusCachedPlugin( new IBusCachedPlugin(
resetVector = 0x80000000l, resetVector = 0x80000000l,
compressedGen = true, compressedGen = false,
prediction = DYNAMIC_TARGET,
config = InstructionCacheConfig( config = InstructionCacheConfig(
cacheSize = 1024*16, cacheSize = 1024*16,
bytePerLine = 32, bytePerLine = 32,

View File

@ -53,6 +53,8 @@ object BrieyConfig{
// catchAccessFault = true // catchAccessFault = true
// ), // ),
new IBusCachedPlugin( new IBusCachedPlugin(
resetVector = 0x80000000l,
prediction = STATIC,
config = InstructionCacheConfig( config = InstructionCacheConfig(
cacheSize = 4096, cacheSize = 4096,
bytePerLine =32, bytePerLine =32,
@ -64,7 +66,8 @@ object BrieyConfig{
catchAccessFault = true, catchAccessFault = true,
catchMemoryTranslationMiss = true, catchMemoryTranslationMiss = true,
asyncTagMemory = false, asyncTagMemory = false,
twoCycleRam = true twoCycleRam = true,
twoCycleCache = true
) )
// askMemoryTranslation = true, // askMemoryTranslation = true,
// memoryTranslatorPortConfig = MemoryTranslatorPortConfig( // memoryTranslatorPortConfig = MemoryTranslatorPortConfig(
@ -123,7 +126,7 @@ object BrieyConfig{
new BranchPlugin( new BranchPlugin(
earlyBranch = false, earlyBranch = false,
catchAddressMisaligned = true, catchAddressMisaligned = true,
prediction = STATIC prediction = NONE
), ),
new CsrPlugin( new CsrPlugin(
config = CsrPluginConfig( config = CsrPluginConfig(

View File

@ -32,13 +32,13 @@ case class FetchPredictionCmd() extends Bundle{
val hadBranch = Bool val hadBranch = Bool
val targetPc = UInt(32 bits) val targetPc = UInt(32 bits)
} }
case class FetchPredictionRsp(stage : Stage) extends Bundle{ case class FetchPredictionRsp() extends Bundle{
val wasRight = Bool val wasRight = Bool
val targetPc = UInt(32 bits) val finalPc = UInt(32 bits)
} }
case class FetchPredictionBus(stage : Stage) extends Bundle { case class FetchPredictionBus(stage : Stage) extends Bundle {
val cmd = FetchPredictionCmd() val cmd = FetchPredictionCmd()
val rsp = FetchPredictionRsp(stage) val rsp = FetchPredictionRsp()
} }
@ -68,9 +68,14 @@ class BranchPlugin(earlyBranch : Boolean,
var decodePrediction : DecodePredictionBus = null var decodePrediction : DecodePredictionBus = null
var fetchPrediction : FetchPredictionBus = null
override def askFetchPrediction() = ??? override def askFetchPrediction() = {
fetchPrediction = FetchPredictionBus(branchStage)
fetchPrediction
}
override def askDecodePrediction() = { override def askDecodePrediction() = {
decodePrediction = DecodePredictionBus(branchStage) decodePrediction = DecodePredictionBus(branchStage)
decodePrediction decodePrediction
@ -131,9 +136,10 @@ class BranchPlugin(earlyBranch : Boolean,
} }
} }
override def build(pipeline: VexRiscv): Unit = (decodePrediction) match { override def build(pipeline: VexRiscv): Unit = (fetchPrediction,decodePrediction) match {
case null => buildWithoutPrediction(pipeline) case (null, null) => buildWithoutPrediction(pipeline)
case _ => buildWithPrediction(pipeline) case (_ , null) => buildFetchPrediction(pipeline)
case (null, _) => buildDecodePrediction(pipeline)
// case `DYNAMIC` => buildWithPrediction(pipeline) // case `DYNAMIC` => buildWithPrediction(pipeline)
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline) // case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
} }
@ -192,7 +198,7 @@ class BranchPlugin(earlyBranch : Boolean,
} }
def buildWithPrediction(pipeline: VexRiscv): Unit = { def buildDecodePrediction(pipeline: VexRiscv): Unit = {
// case class BranchPredictorLine() extends Bundle{ // case class BranchPredictorLine() extends Bundle{
// val history = SInt(historyWidth bits) // val history = SInt(historyWidth bits)
// } // }
@ -298,9 +304,9 @@ class BranchPlugin(earlyBranch : Boolean,
} }
if(catchAddressMisaligned) { if(catchAddressMisaligned) {
branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) jumpInterface.payload(0 downto 0) =/= 0 else jumpInterface.payload(1 downto 0) =/= 0) branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) input(BRANCH_CALC)(0 downto 0) =/= 0 else input(BRANCH_CALC)(1 downto 0) =/= 0)
branchExceptionPort.code := 0 branchExceptionPort.code := 0
branchExceptionPort.badAddr := jumpInterface.payload branchExceptionPort.badAddr := input(BRANCH_CALC)
} }
} }
@ -321,139 +327,64 @@ class BranchPlugin(earlyBranch : Boolean,
// def buildDynamicTargetPrediction(pipeline: VexRiscv): Unit = { def buildFetchPrediction(pipeline: VexRiscv): Unit = {
// import pipeline._ import pipeline._
// import pipeline.config._ import pipeline.config._
//
// case class BranchPredictorLine() extends Bundle{
// val source = Bits(31 - historyRamSizeLog2 bits) //Do branch calculations (conditions + target PC)
// val confidence = UInt(2 bits) execute plug new Area {
// val target = UInt(32 bits) import execute._
// }
// val less = input(SRC_LESS)
// object PREDICTION_WRITE_HAZARD extends Stageable(Bool) val eq = input(SRC1) === input(SRC2)
// object PREDICTION extends Stageable(BranchPredictorLine())
// object PREDICTION_HIT extends Stageable(Bool) insert(BRANCH_DO) := input(BRANCH_CTRL).mux(
// BranchCtrlEnum.INC -> False,
// val history = Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) BranchCtrlEnum.JAL -> True,
// val historyWrite = history.writePort BranchCtrlEnum.JALR -> True,
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
// B"000" -> eq ,
// fetch plug new Area{ B"001" -> !eq ,
// import fetch._ M"1-1" -> !less,
// val line = history.readSync((prefetch.output(PC) >> 2).resized, prefetch.arbitration.isFiring) default -> less
//// val line = history.readAsync((fetch.output(PC) >> 2).resized) )
// val hit = line.source === (input(PC).asBits >> 1 + historyRamSizeLog2) )
//
// //Avoid write to read hazard val imm = IMM(input(INSTRUCTION))
// val historyWriteLast = RegNext(historyWrite) val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC)
// val hazard = historyWriteLast.valid && historyWriteLast.address === (output(PC) >> 2).resized val branch_src2 = input(BRANCH_CTRL).mux(
// insert(PREDICTION_WRITE_HAZARD) := hazard BranchCtrlEnum.JAL -> imm.j_sext,
// BranchCtrlEnum.JALR -> imm.i_sext,
// predictionJumpInterface.valid := line.confidence.msb && hit && arbitration.isFiring && !hazard default -> imm.b_sext
// predictionJumpInterface.payload := line.target ).asUInt
//
// insert(PREDICTION) := line val branchAdder = branch_src1 + branch_src2
// insert(PREDICTION_HIT) := hit insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
// } }
//
// //Apply branchs (JAL,JALR, Bxx)
// val branchStage = if(earlyBranch) execute else memory
// //Do branch calculations (conditions + target PC) branchStage plug new Area {
// execute plug new Area { import branchStage._
// import execute._
// val predictionMissmatch = fetchPrediction.cmd.hadBranch =/= input(BRANCH_DO) || (input(BRANCH_DO) && fetchPrediction.cmd.targetPc =/= input(BRANCH_CALC))
// val less = input(SRC_LESS) fetchPrediction.rsp.wasRight := ! predictionMissmatch
// val eq = input(SRC1) === input(SRC2) fetchPrediction.rsp.finalPc := input(BRANCH_CALC)
//
// insert(BRANCH_DO) := input(BRANCH_CTRL).mux( jumpInterface.valid := arbitration.isFiring && predictionMissmatch //Probably just isValid instead of isFiring is better
// BranchCtrlEnum.INC -> False, jumpInterface.payload := (input(BRANCH_DO) ? input(BRANCH_CALC) | input(PC) + (if(pipeline(RVC_GEN)) ((input(IS_RVC)) ? U(2) | U(4)) else 4))
// BranchCtrlEnum.JAL -> True,
// BranchCtrlEnum.JALR -> True,
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux( when(jumpInterface.valid) {
// B"000" -> eq , stages(indexOf(branchStage) - 1).arbitration.flushAll := True
// B"001" -> !eq , }
// M"1-1" -> !less,
// default -> less if(catchAddressMisaligned) {
// ) branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) input(BRANCH_CALC)(0 downto 0) =/= 0 else input(BRANCH_CALC)(1 downto 0) =/= 0)
// ) branchExceptionPort.code := 0
// branchExceptionPort.badAddr := input(BRANCH_CALC)
// val imm = IMM(input(INSTRUCTION)) }
// val branch_src1 = (input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? input(RS1).asUInt | input(PC) }
// val branch_src2 = input(BRANCH_CTRL).mux( }
// BranchCtrlEnum.JAL -> imm.j_sext,
// BranchCtrlEnum.JALR -> imm.i_sext,
// default -> imm.b_sext
// ).asUInt
//
// val branchAdder = branch_src1 + branch_src2
// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
// }
//
// //Apply branchs (JAL,JALR, Bxx)
// val branchStage = if(earlyBranch) execute else memory
// branchStage plug new Area {
// import branchStage._
//
// val predictionMissmatch = input(PREDICTION).confidence.msb =/= input(BRANCH_DO) || (input(BRANCH_DO) && input(PREDICTION).target =/= input(BRANCH_CALC))
//
// historyWrite.valid := False
// historyWrite.address := (branchStage.output(PC) >> 2).resized
// historyWrite.data.source := input(PC).asBits >> 1 + historyRamSizeLog2
// historyWrite.data.target := input(BRANCH_CALC)
//
// jumpInterface.valid := False
// jumpInterface.payload := input(BRANCH_CALC)
//
//
// when(!input(BRANCH_DO)){
// historyWrite.valid := arbitration.isFiring && input(PREDICTION_HIT)
// historyWrite.data.confidence := input(PREDICTION).confidence - (input(PREDICTION).confidence =/= 0).asUInt
// historyWrite.data.target := input(BRANCH_CALC)
//
//
// jumpInterface.valid := input(PREDICTION_HIT) && input(PREDICTION).confidence.msb && !input(PREDICTION_WRITE_HAZARD) && arbitration.isFiring
// jumpInterface.payload := input(PC) + 4
// } otherwise{
// when(!input(PREDICTION_HIT) || input(PREDICTION_WRITE_HAZARD)){
// jumpInterface.valid := arbitration.isFiring
// historyWrite.valid := arbitration.isFiring
// historyWrite.data.confidence := "10"
// } otherwise {
// historyWrite.valid := arbitration.isFiring
// historyWrite.data.confidence := input(PREDICTION).confidence + (input(PREDICTION).confidence =/= 3).asUInt
// when(!input(PREDICTION).confidence.msb || input(PREDICTION).target =/= input(BRANCH_CALC)){
// jumpInterface.valid := arbitration.isFiring
// }
// }
// }
//
// //Prevent rewriting an history which already had hazard
// historyWrite.valid clearWhen(input(PREDICTION_WRITE_HAZARD))
//
//
//
// when(jumpInterface.valid) {
// stages(indexOf(branchStage) - 1).arbitration.flushAll := True
// }
//
// if(catchAddressMisaligned) {
// branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
// branchExceptionPort.code := 0
// branchExceptionPort.badAddr := jumpInterface.payload
// }
// }
//
// //Init History
// val historyInit = pipeline plug new Area{
// val counter = Reg(UInt(historyRamSizeLog2 + 1 bits)) init(0)
// when(!counter.msb){
// prefetch.arbitration.haltByOther := True
// historyWrite.valid := True
// historyWrite.address := counter.resized
// historyWrite.data.confidence := 0
// counter := counter + 1
// }
// }
// }
} }

View File

@ -22,6 +22,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
var prefetchExceptionPort : Flow[ExceptionCause] = null var prefetchExceptionPort : Flow[ExceptionCause] = null
var decodePrediction : DecodePredictionBus = null var decodePrediction : DecodePredictionBus = null
var fetchPrediction : FetchPredictionBus = null
assert(cmdToRspStageCount >= 1) assert(cmdToRspStageCount >= 1)
assert(!(cmdToRspStageCount == 1 && !injectorStage)) assert(!(cmdToRspStageCount == 1 && !injectorStage))
assert(!(compressedGen && !decodePcGen)) assert(!(compressedGen && !decodePcGen))
@ -68,6 +69,9 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
predictionJumpInterface = createJumpInterface(pipeline.decode) predictionJumpInterface = createJumpInterface(pipeline.decode)
decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction() decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction()
} }
case DYNAMIC_TARGET => {
fetchPrediction = pipeline.service(classOf[PredictionInterface]).askFetchPrediction()
}
} }
} }
@ -98,6 +102,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
class PcFetch extends Area{ class PcFetch extends Area{
val preOutput = Stream(UInt(32 bits)) val preOutput = Stream(UInt(32 bits))
val output = preOutput.haltWhen(fetcherHalt) val output = preOutput.haltWhen(fetcherHalt)
val predictionPcLoad = ifGen(prediction == DYNAMIC_TARGET) (Flow(UInt(32 bits)))
} }
val fetchPc = if(relaxedPcCalculation) new PcFetch { val fetchPc = if(relaxedPcCalculation) new PcFetch {
@ -117,6 +122,11 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
} }
//application of the selected jump request //application of the selected jump request
if(predictionPcLoad != null) {
when(predictionPcLoad.valid) {
pcReg := predictionPcLoad.payload
}
}
when(jump.pcLoad.valid) { when(jump.pcLoad.valid) {
pcReg := jump.pcLoad.payload pcReg := jump.pcLoad.payload
} }
@ -131,6 +141,13 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
val pc = pcReg + (inc ## B"00").asUInt val pc = pcReg + (inc ## B"00").asUInt
val samplePcNext = False val samplePcNext = False
if(predictionPcLoad != null) {
when(predictionPcLoad.valid) {
inc := False
samplePcNext := True
pc := predictionPcLoad.payload
}
}
when(jump.pcLoad.valid) { when(jump.pcLoad.valid) {
inc := False inc := False
samplePcNext := True samplePcNext := True
@ -252,26 +269,26 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
def condApply[T](that : T, cond : Boolean)(func : (T) => T) = if(cond)func(that) else that def condApply[T](that : T, cond : Boolean)(func : (T) => T) = if(cond)func(that) else that
val injector = new Area { val injector = new Area {
val inputBeforeHalt = condApply(if (decodePcGen) decompressor.output else iBusRsp.output, injectorReadyCutGen)(_.s2mPipe(flush)) val inputBeforeStage = condApply(if (decodePcGen) decompressor.output else iBusRsp.output, injectorReadyCutGen)(_.s2mPipe(flush))
if (injectorReadyCutGen) { if (injectorReadyCutGen) {
iBusRsp.readyForError.clearWhen(inputBeforeHalt.valid) iBusRsp.readyForError.clearWhen(inputBeforeStage.valid)
incomingInstruction setWhen (inputBeforeHalt.valid) incomingInstruction setWhen (inputBeforeStage.valid)
} }
val decodeInput = (if (injectorStage) { val decodeInput = (if (injectorStage) {
val decodeInput = inputBeforeHalt.m2sPipeWithFlush(killLastStage, collapsBubble = false) val decodeInput = inputBeforeStage.m2sPipeWithFlush(killLastStage, collapsBubble = false)
decode.insert(INSTRUCTION_ANTICIPATED) := Mux(decode.arbitration.isStuck, decode.input(INSTRUCTION), inputBeforeHalt.rsp.inst) decode.insert(INSTRUCTION_ANTICIPATED) := Mux(decode.arbitration.isStuck, decode.input(INSTRUCTION), inputBeforeStage.rsp.inst)
iBusRsp.readyForError.clearWhen(decodeInput.valid) iBusRsp.readyForError.clearWhen(decodeInput.valid)
incomingInstruction setWhen (decodeInput.valid) incomingInstruction setWhen (decodeInput.valid)
decodeInput decodeInput
} else { } else {
inputBeforeHalt inputBeforeStage
}) })
if (decodePcGen) { if (decodePcGen) {
decodeNextPcValid := True decodeNextPcValid := True
decodeNextPc := decodePc.pcReg decodeNextPc := decodePc.pcReg
} else { } else {
val lastStageStream = if (injectorStage) inputBeforeHalt val lastStageStream = if (injectorStage) inputBeforeStage
else if (cmdToRspStageCount > 1) iBusRsp.inputPipeline(cmdToRspStageCount - 2) else if (cmdToRspStageCount > 1) iBusRsp.inputPipeline(cmdToRspStageCount - 2)
else throw new Exception("Fetch should at least have two stages") else throw new Exception("Fetch should at least have two stages")
@ -354,7 +371,7 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
} }
} }
prediction match { val predictor = prediction match {
case NONE => case NONE =>
case STATIC | DYNAMIC => { case STATIC | DYNAMIC => {
def historyWidth = 2 def historyWidth = 2
@ -393,6 +410,67 @@ abstract class IBusFetcherImpl(val catchAccessFault : Boolean,
// predictionExceptionPort.badAddr := predictionJumpInterface.payload // predictionExceptionPort.badAddr := predictionJumpInterface.payload
} }
} }
case DYNAMIC_TARGET => new Area{
val historyRamSizeLog2 : Int = 10
case class BranchPredictorLine() extends Bundle{
val source = Bits(31 - historyRamSizeLog2 bits)
val branchWish = UInt(2 bits)
val target = UInt(32 bits)
}
val history = Mem(BranchPredictorLine(), 1 << historyRamSizeLog2)
val historyWrite = history.writePort
val line = history.readSync((fetchPc.output.payload >> 2).resized, iBusRsp.inputPipeline(0).ready)
val hit = line.source === (iBusRsp.inputPipeline(0).payload.asBits >> 1 + historyRamSizeLog2)
//Avoid write to read hazard
val historyWriteLast = RegNextWhen(historyWrite, iBusRsp.inputPipeline(0).ready)
val hazard = historyWriteLast.valid && historyWriteLast.address === (iBusRsp.inputPipeline(0).payload >> 2).resized
fetchPc.predictionPcLoad.valid := line.branchWish.msb && hit && iBusRsp.inputPipeline(0).fire && !flush && !hazard
fetchPc.predictionPcLoad.payload := line.target
case class PredictionResult() extends Bundle{
val hazard = Bool
val hit = Bool
val line = BranchPredictorLine()
}
val fetchContext = PredictionResult()
fetchContext.hazard := hazard
fetchContext.hit := hit
fetchContext.line := line //RegNextWhen(e._1, e._2.)
val iBusRspContext = iBusRsp.inputPipeline.tail.foldLeft(fetchContext)((data,stream) => RegNextWhen(data, stream.ready))
val injectorContext = Delay(iBusRspContext,cycleCount=if(injectorStage) 1 else 0, when=injector.decodeInput.ready)
object PREDICTION_CONTEXT extends Stageable(PredictionResult())
pipeline.decode.insert(PREDICTION_CONTEXT) := injectorContext
val branchStage = fetchPrediction.stage
val branchContext = branchStage.input(PREDICTION_CONTEXT)
fetchPrediction.cmd.hadBranch := branchContext.hit && !branchContext.hazard && branchContext.line.branchWish.msb
fetchPrediction.cmd.targetPc := branchContext.line.target
historyWrite.valid := False
historyWrite.address := (branchStage.input(PC) >> 2).resized
historyWrite.data.source := branchStage.input(PC).asBits >> 1 + historyRamSizeLog2
historyWrite.data.target := fetchPrediction.rsp.finalPc
when(fetchPrediction.rsp.wasRight) {
historyWrite.valid := branchContext.hit
historyWrite.data.branchWish := branchContext.line.branchWish + (branchContext.line.branchWish === 2).asUInt - (branchContext.line.branchWish === 1).asUInt
} otherwise {
when(branchContext.hit) {
historyWrite.valid := True
historyWrite.data.branchWish := branchContext.line.branchWish - (branchContext.line.branchWish.msb).asUInt + (!branchContext.line.branchWish.msb).asUInt
} otherwise {
historyWrite.valid := True
historyWrite.data.branchWish := "10"
}
}
historyWrite.valid clearWhen(branchContext.hazard || !branchStage.arbitration.isFiring)
}
} }
} }
} }