Static prediction is fully functionnal

This commit is contained in:
Dolu1990 2018-04-02 17:43:06 +02:00
parent 0919308a8f
commit 76ca852478
2 changed files with 173 additions and 86 deletions

View File

@ -11,25 +11,71 @@ object STATIC extends BranchPrediction
object DYNAMIC extends BranchPrediction object DYNAMIC extends BranchPrediction
object DYNAMIC_TARGET extends BranchPrediction object DYNAMIC_TARGET extends BranchPrediction
object BranchCtrlEnum extends SpinalEnum(binarySequential){
val INC,B,JAL,JALR = newElement()
}
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
case class DecodePredictionCmd() extends Bundle {
val hadBranch = Bool
}
case class DecodePredictionRsp(stage : Stage) extends Bundle {
val wasWrong = Bool
}
case class DecodePredictionBus(stage : Stage) extends Bundle {
val cmd = DecodePredictionCmd()
val rsp = DecodePredictionRsp(stage)
}
case class FetchPredictionCmd() extends Bundle{
val hadBranch = Bool
val targetPc = UInt(32 bits)
}
case class FetchPredictionRsp(stage : Stage) extends Bundle{
val wasRight = Bool
val targetPc = UInt(32 bits)
}
case class FetchPredictionBus(stage : Stage) extends Bundle {
val cmd = FetchPredictionCmd()
val rsp = FetchPredictionRsp(stage)
}
trait PredictionInterface{
def askFetchPrediction() : FetchPredictionBus
def askDecodePrediction() : DecodePredictionBus
}
class BranchPlugin(earlyBranch : Boolean, class BranchPlugin(earlyBranch : Boolean,
catchAddressMisaligned : Boolean, catchAddressMisaligned : Boolean,
prediction : BranchPrediction, prediction : BranchPrediction,
historyRamSizeLog2 : Int = 10, historyRamSizeLog2 : Int = 10,
historyWidth : Int = 2) extends Plugin[VexRiscv]{ historyWidth : Int = 2) extends Plugin[VexRiscv] with PredictionInterface{
object BranchCtrlEnum extends SpinalEnum(binarySequential){
val INC,B,JAL,JALR = newElement()
} lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
object BRANCH_CALC extends Stageable(UInt(32 bits)) object BRANCH_CALC extends Stageable(UInt(32 bits))
object BRANCH_DO extends Stageable(Bool) object BRANCH_DO extends Stageable(Bool)
object BRANCH_COND_RESULT extends Stageable(Bool) object BRANCH_COND_RESULT extends Stageable(Bool)
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
var jumpInterface : Flow[UInt] = null var jumpInterface : Flow[UInt] = null
var predictionJumpInterface : Flow[UInt] = null var predictionJumpInterface : Flow[UInt] = null
var predictionExceptionPort : Flow[ExceptionCause] = null var predictionExceptionPort : Flow[ExceptionCause] = null
var branchExceptionPort : Flow[ExceptionCause] = null var branchExceptionPort : Flow[ExceptionCause] = null
var decodePrediction : DecodePredictionBus = null
override def askFetchPrediction() = ???
override def askDecodePrediction() = {
decodePrediction = DecodePredictionBus(branchStage)
decodePrediction
}
override def setup(pipeline: VexRiscv): Unit = { override def setup(pipeline: VexRiscv): Unit = {
import Riscv._ import Riscv._
import pipeline.config._ import pipeline.config._
@ -66,7 +112,7 @@ class BranchPlugin(earlyBranch : Boolean,
)) ))
val pcManagerService = pipeline.service(classOf[JumpService]) val pcManagerService = pipeline.service(classOf[JumpService])
jumpInterface = pcManagerService.createJumpInterface(if(earlyBranch) pipeline.execute else pipeline.memory) jumpInterface = pcManagerService.createJumpInterface(branchStage)
prediction match { prediction match {
case NONE => case NONE =>
@ -76,7 +122,7 @@ class BranchPlugin(earlyBranch : Boolean,
if (catchAddressMisaligned) { if (catchAddressMisaligned) {
val exceptionService = pipeline.service(classOf[ExceptionService]) val exceptionService = pipeline.service(classOf[ExceptionService])
branchExceptionPort = exceptionService.newExceptionPort(if (earlyBranch) pipeline.execute else pipeline.memory) branchExceptionPort = exceptionService.newExceptionPort(branchStage)
prediction match { prediction match {
case NONE => case NONE =>
// case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode) // case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode)
@ -85,9 +131,9 @@ class BranchPlugin(earlyBranch : Boolean,
} }
} }
override def build(pipeline: VexRiscv): Unit = prediction match { override def build(pipeline: VexRiscv): Unit = (decodePrediction) match {
case `NONE` => buildWithoutPrediction(pipeline) case null => buildWithoutPrediction(pipeline)
// case `STATIC` => buildWithPrediction(pipeline) case _ => buildWithPrediction(pipeline)
// case `DYNAMIC` => buildWithPrediction(pipeline) // case `DYNAMIC` => buildWithPrediction(pipeline)
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline) // case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
} }
@ -128,7 +174,6 @@ class BranchPlugin(earlyBranch : Boolean,
} }
//Apply branchs (JAL,JALR, Bxx) //Apply branchs (JAL,JALR, Bxx)
val branchStage = if(earlyBranch) execute else memory
branchStage plug new Area { branchStage plug new Area {
import branchStage._ import branchStage._
jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO) jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO)
@ -147,21 +192,21 @@ class BranchPlugin(earlyBranch : Boolean,
} }
// def buildWithPrediction(pipeline: VexRiscv): Unit = { def buildWithPrediction(pipeline: VexRiscv): Unit = {
// case class BranchPredictorLine() extends Bundle{ // case class BranchPredictorLine() extends Bundle{
// val history = SInt(historyWidth bits) // val history = SInt(historyWidth bits)
// } // }
//
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool) object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
// object HISTORY_LINE extends Stageable(BranchPredictorLine()) // object HISTORY_LINE extends Stageable(BranchPredictorLine())
//
// import pipeline._ import pipeline._
// import pipeline.config._ import pipeline.config._
//
// val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null // val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null
// val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null // val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null
//
// //Read historyCache //Read historyCache
// if(prediction == DYNAMIC) fetch plug new Area{ // if(prediction == DYNAMIC) fetch plug new Area{
// val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits) // val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits)
// fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers) // fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers)
@ -172,8 +217,8 @@ class BranchPlugin(earlyBranch : Boolean,
//// fetch.insert(HISTORY_LINE) := writePortReg.data //// fetch.insert(HISTORY_LINE) := writePortReg.data
//// } //// }
// } // }
//
// //Branch JAL, predict Bxx and branch it //Branch JAL, predict Bxx and branch it
// decode plug new Area{ // decode plug new Area{
// import decode._ // import decode._
// val imm = IMM(input(INSTRUCTION)) // val imm = IMM(input(INSTRUCTION))
@ -196,65 +241,71 @@ class BranchPlugin(earlyBranch : Boolean,
// predictionExceptionPort.badAddr := predictionJumpInterface.payload // predictionExceptionPort.badAddr := predictionJumpInterface.payload
// } // }
// } // }
//
// //Do real branch calculation decode plug new Area {
// execute plug new Area { import decode._
// import execute._ insert(PREDICTION_HAD_BRANCHED) := decodePrediction.cmd.hadBranch
// }
// val less = input(SRC_LESS)
// val eq = input(SRC1) === input(SRC2) //Do real branch calculation
// execute plug new Area {
// insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux( import execute._
// BranchCtrlEnum.INC -> False,
// BranchCtrlEnum.JAL -> True, val less = input(SRC_LESS)
// BranchCtrlEnum.JALR -> True, val eq = input(SRC1) === input(SRC2)
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
// B"000" -> eq , insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
// B"001" -> !eq , BranchCtrlEnum.INC -> False,
// M"1-1" -> !less, BranchCtrlEnum.JAL -> True,
// default -> less BranchCtrlEnum.JALR -> True,
// ) BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
// ) B"000" -> eq ,
// B"001" -> !eq ,
// insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT) M"1-1" -> !less,
// default -> less
// //Calculation of the branch target / correction )
// val imm = IMM(input(INSTRUCTION)) )
// val branch_src1,branch_src2 = UInt(32 bits)
// switch(input(BRANCH_CTRL)){ insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
// is(BranchCtrlEnum.JALR){
// branch_src1 := input(RS1).asUInt //Calculation of the branch target / correction
// branch_src2 := imm.i_sext.asUInt val imm = IMM(input(INSTRUCTION))
// } val branch_src1,branch_src2 = UInt(32 bits)
// default{ switch(input(BRANCH_CTRL)){
// branch_src1 := input(PC) is(BranchCtrlEnum.JALR){
// branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt branch_src1 := input(RS1).asUInt
// } branch_src2 := imm.i_sext.asUInt
// } }
// val branchAdder = branch_src1 + branch_src2 default{
// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0)) branch_src1 := input(PC)
// } branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
// }
// }
// // branch JALR or JAL/Bxx prediction miss corrections val branchAdder = branch_src1 + branch_src2
// val branchStage = if(earlyBranch) execute else memory insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
// branchStage plug new Area { }
// import branchStage._
// jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
// jumpInterface.payload := input(BRANCH_CALC) // branch JALR or JAL/Bxx prediction miss corrections
// val branchStage = if(earlyBranch) execute else memory
// when(jumpInterface.valid) { branchStage plug new Area {
// stages(indexOf(branchStage) - 1).arbitration.flushAll := True import branchStage._
// } jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
// jumpInterface.payload := input(BRANCH_CALC)
// if(catchAddressMisaligned) {
// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0 when(jumpInterface.valid) {
// branchExceptionPort.code := 0 stages(indexOf(branchStage) - 1).arbitration.flushAll := True
// branchExceptionPort.badAddr := jumpInterface.payload }
// }
// } if(catchAddressMisaligned) {
// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
// //Update historyCache branchExceptionPort.code := 0
branchExceptionPort.badAddr := jumpInterface.payload
}
}
//Update historyCache
decodePrediction.rsp.wasWrong := jumpInterface.valid
// if(prediction == DYNAMIC) branchStage plug new Area { // if(prediction == DYNAMIC) branchStage plug new Area {
// import branchStage._ // import branchStage._
// val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1)) // val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1))
@ -264,7 +315,7 @@ class BranchPlugin(earlyBranch : Boolean,
// historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits) // historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits)
// historyCacheWrite.data.history := newHistory.resized // historyCacheWrite.data.history := newHistory.resized
// } // }
// } }

View File

@ -5,6 +5,7 @@ import spinal.core._
import spinal.lib._ import spinal.lib._
import spinal.lib.bus.amba4.axi._ import spinal.lib.bus.amba4.axi._
import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig} import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig}
import vexriscv.Riscv.IMM
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
@ -109,12 +110,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
var prefetchExceptionPort : Flow[ExceptionCause] = null var prefetchExceptionPort : Flow[ExceptionCause] = null
def resetVector = BigInt(0x80000000l) def resetVector = BigInt(0x80000000l)
def keepPcPlus4 = false def keepPcPlus4 = false
def decodePcGen = true def decodePcGen = false
def compressedGen = true def compressedGen = false
def cmdToRspStageCount = 1 def cmdToRspStageCount = 1
def rspStageGen = false def rspStageGen = false
def injectorReadyCutGen = true def injectorReadyCutGen = false
def relaxedPcCalculation = true def relaxedPcCalculation = false
def prediction : BranchPrediction = STATIC
var decodePrediction : DecodePredictionBus = null
assert(cmdToRspStageCount >= 1) assert(cmdToRspStageCount >= 1)
assert(!(compressedGen && !decodePcGen)) assert(!(compressedGen && !decodePcGen))
lazy val fetcherHalt = False lazy val fetcherHalt = False
@ -122,6 +125,8 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
lazy val decodeNextPc = UInt(32 bits) lazy val decodeNextPc = UInt(32 bits)
def nextPc() = (decodeNextPcValid, decodeNextPc) def nextPc() = (decodeNextPcValid, decodeNextPc)
var predictionJumpInterface : Flow[UInt] = null
override def haltIt(): Unit = fetcherHalt := True override def haltIt(): Unit = fetcherHalt := True
case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int) case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int)
@ -142,6 +147,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
} }
pipeline(RVC_GEN) = compressedGen pipeline(RVC_GEN) = compressedGen
prediction match {
case NONE =>
case STATIC | DYNAMIC => {
predictionJumpInterface = createJumpInterface(pipeline.decode)
decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction()
}
}
} }
override def build(pipeline: VexRiscv): Unit = { override def build(pipeline: VexRiscv): Unit = {
@ -365,6 +378,29 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
decodeExceptionPort.code := 1 decodeExceptionPort.code := 1
decodeExceptionPort.badAddr := decode.input(PC) decodeExceptionPort.badAddr := decode.input(PC)
} }
prediction match {
case `NONE` =>
case `STATIC` => {
val imm = IMM(decode.input(INSTRUCTION))
val conditionalBranchPrediction = (prediction match {
case `STATIC` => imm.b_sext.msb
//case `DYNAMIC` => input(HISTORY_LINE).history.msb
})
decodePrediction.cmd.hadBranch := decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL || (decode.input(BRANCH_CTRL) === BranchCtrlEnum.B && conditionalBranchPrediction)
predictionJumpInterface.valid := decodePrediction.cmd.hadBranch && decode.arbitration.isFiring //TODO OH Doublon de priorité
predictionJumpInterface.payload := decode.input(PC) + ((decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt
// if(catchAddressMisaligned) {
// predictionExceptionPort.valid := input(INSTRUCTION_READY) && input(PREDICTION_HAD_BRANCHED) && arbitration.isValid && predictionJumpInterface.payload(1 downto 0) =/= 0
// predictionExceptionPort.code := 0
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
// }
}
}
} }
} }
} }