Static prediction is fully functionnal
This commit is contained in:
parent
0919308a8f
commit
76ca852478
|
@ -11,25 +11,71 @@ object STATIC extends BranchPrediction
|
||||||
object DYNAMIC extends BranchPrediction
|
object DYNAMIC extends BranchPrediction
|
||||||
object DYNAMIC_TARGET extends BranchPrediction
|
object DYNAMIC_TARGET extends BranchPrediction
|
||||||
|
|
||||||
|
object BranchCtrlEnum extends SpinalEnum(binarySequential){
|
||||||
|
val INC,B,JAL,JALR = newElement()
|
||||||
|
}
|
||||||
|
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
|
||||||
|
|
||||||
|
|
||||||
|
case class DecodePredictionCmd() extends Bundle {
|
||||||
|
val hadBranch = Bool
|
||||||
|
}
|
||||||
|
case class DecodePredictionRsp(stage : Stage) extends Bundle {
|
||||||
|
val wasWrong = Bool
|
||||||
|
}
|
||||||
|
case class DecodePredictionBus(stage : Stage) extends Bundle {
|
||||||
|
val cmd = DecodePredictionCmd()
|
||||||
|
val rsp = DecodePredictionRsp(stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
case class FetchPredictionCmd() extends Bundle{
|
||||||
|
val hadBranch = Bool
|
||||||
|
val targetPc = UInt(32 bits)
|
||||||
|
}
|
||||||
|
case class FetchPredictionRsp(stage : Stage) extends Bundle{
|
||||||
|
val wasRight = Bool
|
||||||
|
val targetPc = UInt(32 bits)
|
||||||
|
}
|
||||||
|
case class FetchPredictionBus(stage : Stage) extends Bundle {
|
||||||
|
val cmd = FetchPredictionCmd()
|
||||||
|
val rsp = FetchPredictionRsp(stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
trait PredictionInterface{
|
||||||
|
def askFetchPrediction() : FetchPredictionBus
|
||||||
|
def askDecodePrediction() : DecodePredictionBus
|
||||||
|
}
|
||||||
|
|
||||||
class BranchPlugin(earlyBranch : Boolean,
|
class BranchPlugin(earlyBranch : Boolean,
|
||||||
catchAddressMisaligned : Boolean,
|
catchAddressMisaligned : Boolean,
|
||||||
prediction : BranchPrediction,
|
prediction : BranchPrediction,
|
||||||
historyRamSizeLog2 : Int = 10,
|
historyRamSizeLog2 : Int = 10,
|
||||||
historyWidth : Int = 2) extends Plugin[VexRiscv]{
|
historyWidth : Int = 2) extends Plugin[VexRiscv] with PredictionInterface{
|
||||||
object BranchCtrlEnum extends SpinalEnum(binarySequential){
|
|
||||||
val INC,B,JAL,JALR = newElement()
|
|
||||||
}
|
lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory
|
||||||
|
|
||||||
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
|
|
||||||
object BRANCH_CALC extends Stageable(UInt(32 bits))
|
object BRANCH_CALC extends Stageable(UInt(32 bits))
|
||||||
object BRANCH_DO extends Stageable(Bool)
|
object BRANCH_DO extends Stageable(Bool)
|
||||||
object BRANCH_COND_RESULT extends Stageable(Bool)
|
object BRANCH_COND_RESULT extends Stageable(Bool)
|
||||||
|
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
||||||
|
|
||||||
var jumpInterface : Flow[UInt] = null
|
var jumpInterface : Flow[UInt] = null
|
||||||
var predictionJumpInterface : Flow[UInt] = null
|
var predictionJumpInterface : Flow[UInt] = null
|
||||||
var predictionExceptionPort : Flow[ExceptionCause] = null
|
var predictionExceptionPort : Flow[ExceptionCause] = null
|
||||||
var branchExceptionPort : Flow[ExceptionCause] = null
|
var branchExceptionPort : Flow[ExceptionCause] = null
|
||||||
|
|
||||||
|
|
||||||
|
var decodePrediction : DecodePredictionBus = null
|
||||||
|
|
||||||
|
|
||||||
|
override def askFetchPrediction() = ???
|
||||||
|
override def askDecodePrediction() = {
|
||||||
|
decodePrediction = DecodePredictionBus(branchStage)
|
||||||
|
decodePrediction
|
||||||
|
}
|
||||||
|
|
||||||
override def setup(pipeline: VexRiscv): Unit = {
|
override def setup(pipeline: VexRiscv): Unit = {
|
||||||
import Riscv._
|
import Riscv._
|
||||||
import pipeline.config._
|
import pipeline.config._
|
||||||
|
@ -66,7 +112,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
))
|
))
|
||||||
|
|
||||||
val pcManagerService = pipeline.service(classOf[JumpService])
|
val pcManagerService = pipeline.service(classOf[JumpService])
|
||||||
jumpInterface = pcManagerService.createJumpInterface(if(earlyBranch) pipeline.execute else pipeline.memory)
|
jumpInterface = pcManagerService.createJumpInterface(branchStage)
|
||||||
|
|
||||||
prediction match {
|
prediction match {
|
||||||
case NONE =>
|
case NONE =>
|
||||||
|
@ -76,7 +122,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
|
|
||||||
if (catchAddressMisaligned) {
|
if (catchAddressMisaligned) {
|
||||||
val exceptionService = pipeline.service(classOf[ExceptionService])
|
val exceptionService = pipeline.service(classOf[ExceptionService])
|
||||||
branchExceptionPort = exceptionService.newExceptionPort(if (earlyBranch) pipeline.execute else pipeline.memory)
|
branchExceptionPort = exceptionService.newExceptionPort(branchStage)
|
||||||
prediction match {
|
prediction match {
|
||||||
case NONE =>
|
case NONE =>
|
||||||
// case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode)
|
// case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode)
|
||||||
|
@ -85,9 +131,9 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def build(pipeline: VexRiscv): Unit = prediction match {
|
override def build(pipeline: VexRiscv): Unit = (decodePrediction) match {
|
||||||
case `NONE` => buildWithoutPrediction(pipeline)
|
case null => buildWithoutPrediction(pipeline)
|
||||||
// case `STATIC` => buildWithPrediction(pipeline)
|
case _ => buildWithPrediction(pipeline)
|
||||||
// case `DYNAMIC` => buildWithPrediction(pipeline)
|
// case `DYNAMIC` => buildWithPrediction(pipeline)
|
||||||
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
|
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
|
||||||
}
|
}
|
||||||
|
@ -128,7 +174,6 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
//Apply branchs (JAL,JALR, Bxx)
|
//Apply branchs (JAL,JALR, Bxx)
|
||||||
val branchStage = if(earlyBranch) execute else memory
|
|
||||||
branchStage plug new Area {
|
branchStage plug new Area {
|
||||||
import branchStage._
|
import branchStage._
|
||||||
jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO)
|
jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO)
|
||||||
|
@ -147,21 +192,21 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// def buildWithPrediction(pipeline: VexRiscv): Unit = {
|
def buildWithPrediction(pipeline: VexRiscv): Unit = {
|
||||||
// case class BranchPredictorLine() extends Bundle{
|
// case class BranchPredictorLine() extends Bundle{
|
||||||
// val history = SInt(historyWidth bits)
|
// val history = SInt(historyWidth bits)
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
|
||||||
// object HISTORY_LINE extends Stageable(BranchPredictorLine())
|
// object HISTORY_LINE extends Stageable(BranchPredictorLine())
|
||||||
//
|
|
||||||
// import pipeline._
|
import pipeline._
|
||||||
// import pipeline.config._
|
import pipeline.config._
|
||||||
//
|
|
||||||
// val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null
|
// val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null
|
||||||
// val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null
|
// val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null
|
||||||
//
|
|
||||||
// //Read historyCache
|
//Read historyCache
|
||||||
// if(prediction == DYNAMIC) fetch plug new Area{
|
// if(prediction == DYNAMIC) fetch plug new Area{
|
||||||
// val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits)
|
// val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits)
|
||||||
// fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers)
|
// fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers)
|
||||||
|
@ -172,8 +217,8 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
//// fetch.insert(HISTORY_LINE) := writePortReg.data
|
//// fetch.insert(HISTORY_LINE) := writePortReg.data
|
||||||
//// }
|
//// }
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// //Branch JAL, predict Bxx and branch it
|
//Branch JAL, predict Bxx and branch it
|
||||||
// decode plug new Area{
|
// decode plug new Area{
|
||||||
// import decode._
|
// import decode._
|
||||||
// val imm = IMM(input(INSTRUCTION))
|
// val imm = IMM(input(INSTRUCTION))
|
||||||
|
@ -196,65 +241,71 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
|
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
//
|
|
||||||
// //Do real branch calculation
|
decode plug new Area {
|
||||||
// execute plug new Area {
|
import decode._
|
||||||
// import execute._
|
insert(PREDICTION_HAD_BRANCHED) := decodePrediction.cmd.hadBranch
|
||||||
//
|
}
|
||||||
// val less = input(SRC_LESS)
|
|
||||||
// val eq = input(SRC1) === input(SRC2)
|
//Do real branch calculation
|
||||||
//
|
execute plug new Area {
|
||||||
// insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
|
import execute._
|
||||||
// BranchCtrlEnum.INC -> False,
|
|
||||||
// BranchCtrlEnum.JAL -> True,
|
val less = input(SRC_LESS)
|
||||||
// BranchCtrlEnum.JALR -> True,
|
val eq = input(SRC1) === input(SRC2)
|
||||||
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
|
|
||||||
// B"000" -> eq ,
|
insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
|
||||||
// B"001" -> !eq ,
|
BranchCtrlEnum.INC -> False,
|
||||||
// M"1-1" -> !less,
|
BranchCtrlEnum.JAL -> True,
|
||||||
// default -> less
|
BranchCtrlEnum.JALR -> True,
|
||||||
// )
|
BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
|
||||||
// )
|
B"000" -> eq ,
|
||||||
//
|
B"001" -> !eq ,
|
||||||
// insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
|
M"1-1" -> !less,
|
||||||
//
|
default -> less
|
||||||
// //Calculation of the branch target / correction
|
)
|
||||||
// val imm = IMM(input(INSTRUCTION))
|
)
|
||||||
// val branch_src1,branch_src2 = UInt(32 bits)
|
|
||||||
// switch(input(BRANCH_CTRL)){
|
insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
|
||||||
// is(BranchCtrlEnum.JALR){
|
|
||||||
// branch_src1 := input(RS1).asUInt
|
//Calculation of the branch target / correction
|
||||||
// branch_src2 := imm.i_sext.asUInt
|
val imm = IMM(input(INSTRUCTION))
|
||||||
// }
|
val branch_src1,branch_src2 = UInt(32 bits)
|
||||||
// default{
|
switch(input(BRANCH_CTRL)){
|
||||||
// branch_src1 := input(PC)
|
is(BranchCtrlEnum.JALR){
|
||||||
// branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
|
branch_src1 := input(RS1).asUInt
|
||||||
// }
|
branch_src2 := imm.i_sext.asUInt
|
||||||
// }
|
}
|
||||||
// val branchAdder = branch_src1 + branch_src2
|
default{
|
||||||
// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
|
branch_src1 := input(PC)
|
||||||
// }
|
branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
|
||||||
//
|
}
|
||||||
//
|
}
|
||||||
// // branch JALR or JAL/Bxx prediction miss corrections
|
val branchAdder = branch_src1 + branch_src2
|
||||||
// val branchStage = if(earlyBranch) execute else memory
|
insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
|
||||||
// branchStage plug new Area {
|
}
|
||||||
// import branchStage._
|
|
||||||
// jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
|
|
||||||
// jumpInterface.payload := input(BRANCH_CALC)
|
// branch JALR or JAL/Bxx prediction miss corrections
|
||||||
//
|
val branchStage = if(earlyBranch) execute else memory
|
||||||
// when(jumpInterface.valid) {
|
branchStage plug new Area {
|
||||||
// stages(indexOf(branchStage) - 1).arbitration.flushAll := True
|
import branchStage._
|
||||||
// }
|
jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
|
||||||
//
|
jumpInterface.payload := input(BRANCH_CALC)
|
||||||
// if(catchAddressMisaligned) {
|
|
||||||
// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
|
when(jumpInterface.valid) {
|
||||||
// branchExceptionPort.code := 0
|
stages(indexOf(branchStage) - 1).arbitration.flushAll := True
|
||||||
// branchExceptionPort.badAddr := jumpInterface.payload
|
}
|
||||||
// }
|
|
||||||
// }
|
if(catchAddressMisaligned) {
|
||||||
//
|
branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
|
||||||
// //Update historyCache
|
branchExceptionPort.code := 0
|
||||||
|
branchExceptionPort.badAddr := jumpInterface.payload
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Update historyCache
|
||||||
|
decodePrediction.rsp.wasWrong := jumpInterface.valid
|
||||||
// if(prediction == DYNAMIC) branchStage plug new Area {
|
// if(prediction == DYNAMIC) branchStage plug new Area {
|
||||||
// import branchStage._
|
// import branchStage._
|
||||||
// val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1))
|
// val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1))
|
||||||
|
@ -264,7 +315,7 @@ class BranchPlugin(earlyBranch : Boolean,
|
||||||
// historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits)
|
// historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits)
|
||||||
// historyCacheWrite.data.history := newHistory.resized
|
// historyCacheWrite.data.history := newHistory.resized
|
||||||
// }
|
// }
|
||||||
// }
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import spinal.core._
|
||||||
import spinal.lib._
|
import spinal.lib._
|
||||||
import spinal.lib.bus.amba4.axi._
|
import spinal.lib.bus.amba4.axi._
|
||||||
import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig}
|
import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig}
|
||||||
|
import vexriscv.Riscv.IMM
|
||||||
|
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
@ -109,12 +110,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
||||||
var prefetchExceptionPort : Flow[ExceptionCause] = null
|
var prefetchExceptionPort : Flow[ExceptionCause] = null
|
||||||
def resetVector = BigInt(0x80000000l)
|
def resetVector = BigInt(0x80000000l)
|
||||||
def keepPcPlus4 = false
|
def keepPcPlus4 = false
|
||||||
def decodePcGen = true
|
def decodePcGen = false
|
||||||
def compressedGen = true
|
def compressedGen = false
|
||||||
def cmdToRspStageCount = 1
|
def cmdToRspStageCount = 1
|
||||||
def rspStageGen = false
|
def rspStageGen = false
|
||||||
def injectorReadyCutGen = true
|
def injectorReadyCutGen = false
|
||||||
def relaxedPcCalculation = true
|
def relaxedPcCalculation = false
|
||||||
|
def prediction : BranchPrediction = STATIC
|
||||||
|
var decodePrediction : DecodePredictionBus = null
|
||||||
assert(cmdToRspStageCount >= 1)
|
assert(cmdToRspStageCount >= 1)
|
||||||
assert(!(compressedGen && !decodePcGen))
|
assert(!(compressedGen && !decodePcGen))
|
||||||
lazy val fetcherHalt = False
|
lazy val fetcherHalt = False
|
||||||
|
@ -122,6 +125,8 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
||||||
lazy val decodeNextPc = UInt(32 bits)
|
lazy val decodeNextPc = UInt(32 bits)
|
||||||
def nextPc() = (decodeNextPcValid, decodeNextPc)
|
def nextPc() = (decodeNextPcValid, decodeNextPc)
|
||||||
|
|
||||||
|
var predictionJumpInterface : Flow[UInt] = null
|
||||||
|
|
||||||
override def haltIt(): Unit = fetcherHalt := True
|
override def haltIt(): Unit = fetcherHalt := True
|
||||||
|
|
||||||
case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int)
|
case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int)
|
||||||
|
@ -142,6 +147,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
||||||
}
|
}
|
||||||
|
|
||||||
pipeline(RVC_GEN) = compressedGen
|
pipeline(RVC_GEN) = compressedGen
|
||||||
|
|
||||||
|
prediction match {
|
||||||
|
case NONE =>
|
||||||
|
case STATIC | DYNAMIC => {
|
||||||
|
predictionJumpInterface = createJumpInterface(pipeline.decode)
|
||||||
|
decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def build(pipeline: VexRiscv): Unit = {
|
override def build(pipeline: VexRiscv): Unit = {
|
||||||
|
@ -365,6 +378,29 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
|
||||||
decodeExceptionPort.code := 1
|
decodeExceptionPort.code := 1
|
||||||
decodeExceptionPort.badAddr := decode.input(PC)
|
decodeExceptionPort.badAddr := decode.input(PC)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prediction match {
|
||||||
|
case `NONE` =>
|
||||||
|
case `STATIC` => {
|
||||||
|
val imm = IMM(decode.input(INSTRUCTION))
|
||||||
|
|
||||||
|
val conditionalBranchPrediction = (prediction match {
|
||||||
|
case `STATIC` => imm.b_sext.msb
|
||||||
|
//case `DYNAMIC` => input(HISTORY_LINE).history.msb
|
||||||
|
})
|
||||||
|
decodePrediction.cmd.hadBranch := decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL || (decode.input(BRANCH_CTRL) === BranchCtrlEnum.B && conditionalBranchPrediction)
|
||||||
|
|
||||||
|
predictionJumpInterface.valid := decodePrediction.cmd.hadBranch && decode.arbitration.isFiring //TODO OH Doublon de priorité
|
||||||
|
predictionJumpInterface.payload := decode.input(PC) + ((decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt
|
||||||
|
|
||||||
|
|
||||||
|
// if(catchAddressMisaligned) {
|
||||||
|
// predictionExceptionPort.valid := input(INSTRUCTION_READY) && input(PREDICTION_HAD_BRANCHED) && arbitration.isValid && predictionJumpInterface.payload(1 downto 0) =/= 0
|
||||||
|
// predictionExceptionPort.code := 0
|
||||||
|
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue