mirror of
https://github.com/SpinalHDL/VexRiscv.git
synced 2025-01-03 03:43:39 -05:00
Static prediction is fully functionnal
This commit is contained in:
parent
0919308a8f
commit
76ca852478
2 changed files with 173 additions and 86 deletions
|
@ -11,25 +11,71 @@ object STATIC extends BranchPrediction
|
|||
object DYNAMIC extends BranchPrediction
|
||||
object DYNAMIC_TARGET extends BranchPrediction
|
||||
|
||||
object BranchCtrlEnum extends SpinalEnum(binarySequential){
|
||||
val INC,B,JAL,JALR = newElement()
|
||||
}
|
||||
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
|
||||
|
||||
|
||||
case class DecodePredictionCmd() extends Bundle {
|
||||
val hadBranch = Bool
|
||||
}
|
||||
case class DecodePredictionRsp(stage : Stage) extends Bundle {
|
||||
val wasWrong = Bool
|
||||
}
|
||||
case class DecodePredictionBus(stage : Stage) extends Bundle {
|
||||
val cmd = DecodePredictionCmd()
|
||||
val rsp = DecodePredictionRsp(stage)
|
||||
}
|
||||
|
||||
case class FetchPredictionCmd() extends Bundle{
|
||||
val hadBranch = Bool
|
||||
val targetPc = UInt(32 bits)
|
||||
}
|
||||
case class FetchPredictionRsp(stage : Stage) extends Bundle{
|
||||
val wasRight = Bool
|
||||
val targetPc = UInt(32 bits)
|
||||
}
|
||||
case class FetchPredictionBus(stage : Stage) extends Bundle {
|
||||
val cmd = FetchPredictionCmd()
|
||||
val rsp = FetchPredictionRsp(stage)
|
||||
}
|
||||
|
||||
|
||||
trait PredictionInterface{
|
||||
def askFetchPrediction() : FetchPredictionBus
|
||||
def askDecodePrediction() : DecodePredictionBus
|
||||
}
|
||||
|
||||
class BranchPlugin(earlyBranch : Boolean,
|
||||
catchAddressMisaligned : Boolean,
|
||||
prediction : BranchPrediction,
|
||||
historyRamSizeLog2 : Int = 10,
|
||||
historyWidth : Int = 2) extends Plugin[VexRiscv]{
|
||||
object BranchCtrlEnum extends SpinalEnum(binarySequential){
|
||||
val INC,B,JAL,JALR = newElement()
|
||||
}
|
||||
historyWidth : Int = 2) extends Plugin[VexRiscv] with PredictionInterface{
|
||||
|
||||
|
||||
lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory
|
||||
|
||||
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
|
||||
object BRANCH_CALC extends Stageable(UInt(32 bits))
|
||||
object BRANCH_DO extends Stageable(Bool)
|
||||
object BRANCH_COND_RESULT extends Stageable(Bool)
|
||||
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
||||
|
||||
var jumpInterface : Flow[UInt] = null
|
||||
var predictionJumpInterface : Flow[UInt] = null
|
||||
var predictionExceptionPort : Flow[ExceptionCause] = null
|
||||
var branchExceptionPort : Flow[ExceptionCause] = null
|
||||
|
||||
|
||||
var decodePrediction : DecodePredictionBus = null
|
||||
|
||||
|
||||
override def askFetchPrediction() = ???
|
||||
override def askDecodePrediction() = {
|
||||
decodePrediction = DecodePredictionBus(branchStage)
|
||||
decodePrediction
|
||||
}
|
||||
|
||||
override def setup(pipeline: VexRiscv): Unit = {
|
||||
import Riscv._
|
||||
import pipeline.config._
|
||||
|
@ -66,7 +112,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
))
|
||||
|
||||
val pcManagerService = pipeline.service(classOf[JumpService])
|
||||
jumpInterface = pcManagerService.createJumpInterface(if(earlyBranch) pipeline.execute else pipeline.memory)
|
||||
jumpInterface = pcManagerService.createJumpInterface(branchStage)
|
||||
|
||||
prediction match {
|
||||
case NONE =>
|
||||
|
@ -76,7 +122,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
|
||||
if (catchAddressMisaligned) {
|
||||
val exceptionService = pipeline.service(classOf[ExceptionService])
|
||||
branchExceptionPort = exceptionService.newExceptionPort(if (earlyBranch) pipeline.execute else pipeline.memory)
|
||||
branchExceptionPort = exceptionService.newExceptionPort(branchStage)
|
||||
prediction match {
|
||||
case NONE =>
|
||||
// case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode)
|
||||
|
@ -85,9 +131,9 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
}
|
||||
}
|
||||
|
||||
override def build(pipeline: VexRiscv): Unit = prediction match {
|
||||
case `NONE` => buildWithoutPrediction(pipeline)
|
||||
// case `STATIC` => buildWithPrediction(pipeline)
|
||||
override def build(pipeline: VexRiscv): Unit = (decodePrediction) match {
|
||||
case null => buildWithoutPrediction(pipeline)
|
||||
case _ => buildWithPrediction(pipeline)
|
||||
// case `DYNAMIC` => buildWithPrediction(pipeline)
|
||||
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
|
||||
}
|
||||
|
@ -128,7 +174,6 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
}
|
||||
|
||||
//Apply branchs (JAL,JALR, Bxx)
|
||||
val branchStage = if(earlyBranch) execute else memory
|
||||
branchStage plug new Area {
|
||||
import branchStage._
|
||||
jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO)
|
||||
|
@ -147,21 +192,21 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
}
|
||||
|
||||
|
||||
// def buildWithPrediction(pipeline: VexRiscv): Unit = {
|
||||
def buildWithPrediction(pipeline: VexRiscv): Unit = {
|
||||
// case class BranchPredictorLine() extends Bundle{
|
||||
// val history = SInt(historyWidth bits)
|
||||
// }
|
||||
//
|
||||
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
||||
|
||||
object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
||||
// object HISTORY_LINE extends Stageable(BranchPredictorLine())
|
||||
//
|
||||
// import pipeline._
|
||||
// import pipeline.config._
|
||||
//
|
||||
|
||||
import pipeline._
|
||||
import pipeline.config._
|
||||
|
||||
// val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null
|
||||
// val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null
|
||||
//
|
||||
// //Read historyCache
|
||||
|
||||
//Read historyCache
|
||||
// if(prediction == DYNAMIC) fetch plug new Area{
|
||||
// val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits)
|
||||
// fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers)
|
||||
|
@ -172,8 +217,8 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
//// fetch.insert(HISTORY_LINE) := writePortReg.data
|
||||
//// }
|
||||
// }
|
||||
//
|
||||
// //Branch JAL, predict Bxx and branch it
|
||||
|
||||
//Branch JAL, predict Bxx and branch it
|
||||
// decode plug new Area{
|
||||
// import decode._
|
||||
// val imm = IMM(input(INSTRUCTION))
|
||||
|
@ -196,65 +241,71 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// //Do real branch calculation
|
||||
// execute plug new Area {
|
||||
// import execute._
|
||||
//
|
||||
// val less = input(SRC_LESS)
|
||||
// val eq = input(SRC1) === input(SRC2)
|
||||
//
|
||||
// insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
|
||||
// BranchCtrlEnum.INC -> False,
|
||||
// BranchCtrlEnum.JAL -> True,
|
||||
// BranchCtrlEnum.JALR -> True,
|
||||
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
|
||||
// B"000" -> eq ,
|
||||
// B"001" -> !eq ,
|
||||
// M"1-1" -> !less,
|
||||
// default -> less
|
||||
// )
|
||||
// )
|
||||
//
|
||||
// insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
|
||||
//
|
||||
// //Calculation of the branch target / correction
|
||||
// val imm = IMM(input(INSTRUCTION))
|
||||
// val branch_src1,branch_src2 = UInt(32 bits)
|
||||
// switch(input(BRANCH_CTRL)){
|
||||
// is(BranchCtrlEnum.JALR){
|
||||
// branch_src1 := input(RS1).asUInt
|
||||
// branch_src2 := imm.i_sext.asUInt
|
||||
// }
|
||||
// default{
|
||||
// branch_src1 := input(PC)
|
||||
// branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
|
||||
// }
|
||||
// }
|
||||
// val branchAdder = branch_src1 + branch_src2
|
||||
// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
|
||||
// }
|
||||
//
|
||||
//
|
||||
// // branch JALR or JAL/Bxx prediction miss corrections
|
||||
// val branchStage = if(earlyBranch) execute else memory
|
||||
// branchStage plug new Area {
|
||||
// import branchStage._
|
||||
// jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
|
||||
// jumpInterface.payload := input(BRANCH_CALC)
|
||||
//
|
||||
// when(jumpInterface.valid) {
|
||||
// stages(indexOf(branchStage) - 1).arbitration.flushAll := True
|
||||
// }
|
||||
//
|
||||
// if(catchAddressMisaligned) {
|
||||
// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
|
||||
// branchExceptionPort.code := 0
|
||||
// branchExceptionPort.badAddr := jumpInterface.payload
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// //Update historyCache
|
||||
|
||||
decode plug new Area {
|
||||
import decode._
|
||||
insert(PREDICTION_HAD_BRANCHED) := decodePrediction.cmd.hadBranch
|
||||
}
|
||||
|
||||
//Do real branch calculation
|
||||
execute plug new Area {
|
||||
import execute._
|
||||
|
||||
val less = input(SRC_LESS)
|
||||
val eq = input(SRC1) === input(SRC2)
|
||||
|
||||
insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
|
||||
BranchCtrlEnum.INC -> False,
|
||||
BranchCtrlEnum.JAL -> True,
|
||||
BranchCtrlEnum.JALR -> True,
|
||||
BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
|
||||
B"000" -> eq ,
|
||||
B"001" -> !eq ,
|
||||
M"1-1" -> !less,
|
||||
default -> less
|
||||
)
|
||||
)
|
||||
|
||||
insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
|
||||
|
||||
//Calculation of the branch target / correction
|
||||
val imm = IMM(input(INSTRUCTION))
|
||||
val branch_src1,branch_src2 = UInt(32 bits)
|
||||
switch(input(BRANCH_CTRL)){
|
||||
is(BranchCtrlEnum.JALR){
|
||||
branch_src1 := input(RS1).asUInt
|
||||
branch_src2 := imm.i_sext.asUInt
|
||||
}
|
||||
default{
|
||||
branch_src1 := input(PC)
|
||||
branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
|
||||
}
|
||||
}
|
||||
val branchAdder = branch_src1 + branch_src2
|
||||
insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
|
||||
}
|
||||
|
||||
|
||||
// branch JALR or JAL/Bxx prediction miss corrections
|
||||
val branchStage = if(earlyBranch) execute else memory
|
||||
branchStage plug new Area {
|
||||
import branchStage._
|
||||
jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
|
||||
jumpInterface.payload := input(BRANCH_CALC)
|
||||
|
||||
when(jumpInterface.valid) {
|
||||
stages(indexOf(branchStage) - 1).arbitration.flushAll := True
|
||||
}
|
||||
|
||||
if(catchAddressMisaligned) {
|
||||
branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
|
||||
branchExceptionPort.code := 0
|
||||
branchExceptionPort.badAddr := jumpInterface.payload
|
||||
}
|
||||
}
|
||||
|
||||
//Update historyCache
|
||||
decodePrediction.rsp.wasWrong := jumpInterface.valid
|
||||
// if(prediction == DYNAMIC) branchStage plug new Area {
|
||||
// import branchStage._
|
||||
// val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1))
|
||||
|
@ -264,7 +315,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
|||
// historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits)
|
||||
// historyCacheWrite.data.history := newHistory.resized
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import spinal.core._
|
|||
import spinal.lib._
|
||||
import spinal.lib.bus.amba4.axi._
|
||||
import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig}
|
||||
import vexriscv.Riscv.IMM
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
|
@ -109,12 +110,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
|||
var prefetchExceptionPort : Flow[ExceptionCause] = null
|
||||
def resetVector = BigInt(0x80000000l)
|
||||
def keepPcPlus4 = false
|
||||
def decodePcGen = true
|
||||
def compressedGen = true
|
||||
def decodePcGen = false
|
||||
def compressedGen = false
|
||||
def cmdToRspStageCount = 1
|
||||
def rspStageGen = false
|
||||
def injectorReadyCutGen = true
|
||||
def relaxedPcCalculation = true
|
||||
def injectorReadyCutGen = false
|
||||
def relaxedPcCalculation = false
|
||||
def prediction : BranchPrediction = STATIC
|
||||
var decodePrediction : DecodePredictionBus = null
|
||||
assert(cmdToRspStageCount >= 1)
|
||||
assert(!(compressedGen && !decodePcGen))
|
||||
lazy val fetcherHalt = False
|
||||
|
@ -122,6 +125,8 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
|||
lazy val decodeNextPc = UInt(32 bits)
|
||||
def nextPc() = (decodeNextPcValid, decodeNextPc)
|
||||
|
||||
var predictionJumpInterface : Flow[UInt] = null
|
||||
|
||||
override def haltIt(): Unit = fetcherHalt := True
|
||||
|
||||
case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int)
|
||||
|
@ -142,6 +147,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
|||
}
|
||||
|
||||
pipeline(RVC_GEN) = compressedGen
|
||||
|
||||
prediction match {
|
||||
case NONE =>
|
||||
case STATIC | DYNAMIC => {
|
||||
predictionJumpInterface = createJumpInterface(pipeline.decode)
|
||||
decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def build(pipeline: VexRiscv): Unit = {
|
||||
|
@ -365,6 +378,29 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
|||
decodeExceptionPort.code := 1
|
||||
decodeExceptionPort.badAddr := decode.input(PC)
|
||||
}
|
||||
|
||||
prediction match {
|
||||
case `NONE` =>
|
||||
case `STATIC` => {
|
||||
val imm = IMM(decode.input(INSTRUCTION))
|
||||
|
||||
val conditionalBranchPrediction = (prediction match {
|
||||
case `STATIC` => imm.b_sext.msb
|
||||
//case `DYNAMIC` => input(HISTORY_LINE).history.msb
|
||||
})
|
||||
decodePrediction.cmd.hadBranch := decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL || (decode.input(BRANCH_CTRL) === BranchCtrlEnum.B && conditionalBranchPrediction)
|
||||
|
||||
predictionJumpInterface.valid := decodePrediction.cmd.hadBranch && decode.arbitration.isFiring //TODO OH Doublon de priorité
|
||||
predictionJumpInterface.payload := decode.input(PC) + ((decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt
|
||||
|
||||
|
||||
// if(catchAddressMisaligned) {
|
||||
// predictionExceptionPort.valid := input(INSTRUCTION_READY) && input(PREDICTION_HAD_BRANCHED) && arbitration.isValid && predictionJumpInterface.payload(1 downto 0) =/= 0
|
||||
// predictionExceptionPort.code := 0
|
||||
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue