Static prediction is fully functionnal

This commit is contained in:
Dolu1990 2018-04-02 17:43:06 +02:00
parent 0919308a8f
commit 76ca852478
2 changed files with 173 additions and 86 deletions

View file

@ -11,25 +11,71 @@ object STATIC extends BranchPrediction
object DYNAMIC extends BranchPrediction
object DYNAMIC_TARGET extends BranchPrediction
object BranchCtrlEnum extends SpinalEnum(binarySequential){
val INC,B,JAL,JALR = newElement()
}
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
case class DecodePredictionCmd() extends Bundle {
val hadBranch = Bool
}
case class DecodePredictionRsp(stage : Stage) extends Bundle {
val wasWrong = Bool
}
case class DecodePredictionBus(stage : Stage) extends Bundle {
val cmd = DecodePredictionCmd()
val rsp = DecodePredictionRsp(stage)
}
case class FetchPredictionCmd() extends Bundle{
val hadBranch = Bool
val targetPc = UInt(32 bits)
}
case class FetchPredictionRsp(stage : Stage) extends Bundle{
val wasRight = Bool
val targetPc = UInt(32 bits)
}
case class FetchPredictionBus(stage : Stage) extends Bundle {
val cmd = FetchPredictionCmd()
val rsp = FetchPredictionRsp(stage)
}
trait PredictionInterface{
def askFetchPrediction() : FetchPredictionBus
def askDecodePrediction() : DecodePredictionBus
}
class BranchPlugin(earlyBranch : Boolean,
catchAddressMisaligned : Boolean,
prediction : BranchPrediction,
historyRamSizeLog2 : Int = 10,
historyWidth : Int = 2) extends Plugin[VexRiscv]{
object BranchCtrlEnum extends SpinalEnum(binarySequential){
val INC,B,JAL,JALR = newElement()
}
historyWidth : Int = 2) extends Plugin[VexRiscv] with PredictionInterface{
lazy val branchStage = if(earlyBranch) pipeline.execute else pipeline.memory
object BRANCH_CTRL extends Stageable(BranchCtrlEnum())
object BRANCH_CALC extends Stageable(UInt(32 bits))
object BRANCH_DO extends Stageable(Bool)
object BRANCH_COND_RESULT extends Stageable(Bool)
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
var jumpInterface : Flow[UInt] = null
var predictionJumpInterface : Flow[UInt] = null
var predictionExceptionPort : Flow[ExceptionCause] = null
var branchExceptionPort : Flow[ExceptionCause] = null
var decodePrediction : DecodePredictionBus = null
override def askFetchPrediction() = ???
override def askDecodePrediction() = {
decodePrediction = DecodePredictionBus(branchStage)
decodePrediction
}
override def setup(pipeline: VexRiscv): Unit = {
import Riscv._
import pipeline.config._
@ -66,7 +112,7 @@ class BranchPlugin(earlyBranch : Boolean,
))
val pcManagerService = pipeline.service(classOf[JumpService])
jumpInterface = pcManagerService.createJumpInterface(if(earlyBranch) pipeline.execute else pipeline.memory)
jumpInterface = pcManagerService.createJumpInterface(branchStage)
prediction match {
case NONE =>
@ -76,7 +122,7 @@ class BranchPlugin(earlyBranch : Boolean,
if (catchAddressMisaligned) {
val exceptionService = pipeline.service(classOf[ExceptionService])
branchExceptionPort = exceptionService.newExceptionPort(if (earlyBranch) pipeline.execute else pipeline.memory)
branchExceptionPort = exceptionService.newExceptionPort(branchStage)
prediction match {
case NONE =>
// case STATIC | DYNAMIC => predictionExceptionPort = exceptionService.newExceptionPort(pipeline.decode)
@ -85,9 +131,9 @@ class BranchPlugin(earlyBranch : Boolean,
}
}
override def build(pipeline: VexRiscv): Unit = prediction match {
case `NONE` => buildWithoutPrediction(pipeline)
// case `STATIC` => buildWithPrediction(pipeline)
override def build(pipeline: VexRiscv): Unit = (decodePrediction) match {
case null => buildWithoutPrediction(pipeline)
case _ => buildWithPrediction(pipeline)
// case `DYNAMIC` => buildWithPrediction(pipeline)
// case `DYNAMIC_TARGET` => buildDynamicTargetPrediction(pipeline)
}
@ -128,7 +174,6 @@ class BranchPlugin(earlyBranch : Boolean,
}
//Apply branchs (JAL,JALR, Bxx)
val branchStage = if(earlyBranch) execute else memory
branchStage plug new Area {
import branchStage._
jumpInterface.valid := arbitration.isFiring && input(BRANCH_DO)
@ -147,21 +192,21 @@ class BranchPlugin(earlyBranch : Boolean,
}
// def buildWithPrediction(pipeline: VexRiscv): Unit = {
def buildWithPrediction(pipeline: VexRiscv): Unit = {
// case class BranchPredictorLine() extends Bundle{
// val history = SInt(historyWidth bits)
// }
//
// object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
object PREDICTION_HAD_BRANCHED extends Stageable(Bool)
// object HISTORY_LINE extends Stageable(BranchPredictorLine())
//
// import pipeline._
// import pipeline.config._
//
import pipeline._
import pipeline.config._
// val historyCache = if(prediction == DYNAMIC) Mem(BranchPredictorLine(), 1 << historyRamSizeLog2) setName("branchCache") else null
// val historyCacheWrite = if(prediction == DYNAMIC) historyCache.writePort else null
//
// //Read historyCache
//Read historyCache
// if(prediction == DYNAMIC) fetch plug new Area{
// val readAddress = prefetch.output(PC)(2, historyRamSizeLog2 bits)
// fetch.insert(HISTORY_LINE) := historyCache.readSync(readAddress,!prefetch.arbitration.isStuckByOthers)
@ -172,8 +217,8 @@ class BranchPlugin(earlyBranch : Boolean,
//// fetch.insert(HISTORY_LINE) := writePortReg.data
//// }
// }
//
// //Branch JAL, predict Bxx and branch it
//Branch JAL, predict Bxx and branch it
// decode plug new Area{
// import decode._
// val imm = IMM(input(INSTRUCTION))
@ -196,65 +241,71 @@ class BranchPlugin(earlyBranch : Boolean,
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
// }
// }
//
// //Do real branch calculation
// execute plug new Area {
// import execute._
//
// val less = input(SRC_LESS)
// val eq = input(SRC1) === input(SRC2)
//
// insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
// BranchCtrlEnum.INC -> False,
// BranchCtrlEnum.JAL -> True,
// BranchCtrlEnum.JALR -> True,
// BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
// B"000" -> eq ,
// B"001" -> !eq ,
// M"1-1" -> !less,
// default -> less
// )
// )
//
// insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
//
// //Calculation of the branch target / correction
// val imm = IMM(input(INSTRUCTION))
// val branch_src1,branch_src2 = UInt(32 bits)
// switch(input(BRANCH_CTRL)){
// is(BranchCtrlEnum.JALR){
// branch_src1 := input(RS1).asUInt
// branch_src2 := imm.i_sext.asUInt
// }
// default{
// branch_src1 := input(PC)
// branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
// }
// }
// val branchAdder = branch_src1 + branch_src2
// insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
// }
//
//
// // branch JALR or JAL/Bxx prediction miss corrections
// val branchStage = if(earlyBranch) execute else memory
// branchStage plug new Area {
// import branchStage._
// jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
// jumpInterface.payload := input(BRANCH_CALC)
//
// when(jumpInterface.valid) {
// stages(indexOf(branchStage) - 1).arbitration.flushAll := True
// }
//
// if(catchAddressMisaligned) {
// branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
// branchExceptionPort.code := 0
// branchExceptionPort.badAddr := jumpInterface.payload
// }
// }
//
// //Update historyCache
decode plug new Area {
import decode._
insert(PREDICTION_HAD_BRANCHED) := decodePrediction.cmd.hadBranch
}
//Do real branch calculation
execute plug new Area {
import execute._
val less = input(SRC_LESS)
val eq = input(SRC1) === input(SRC2)
insert(BRANCH_COND_RESULT) := input(BRANCH_CTRL).mux(
BranchCtrlEnum.INC -> False,
BranchCtrlEnum.JAL -> True,
BranchCtrlEnum.JALR -> True,
BranchCtrlEnum.B -> input(INSTRUCTION)(14 downto 12).mux(
B"000" -> eq ,
B"001" -> !eq ,
M"1-1" -> !less,
default -> less
)
)
insert(BRANCH_DO) := input(PREDICTION_HAD_BRANCHED) =/= insert(BRANCH_COND_RESULT)
//Calculation of the branch target / correction
val imm = IMM(input(INSTRUCTION))
val branch_src1,branch_src2 = UInt(32 bits)
switch(input(BRANCH_CTRL)){
is(BranchCtrlEnum.JALR){
branch_src1 := input(RS1).asUInt
branch_src2 := imm.i_sext.asUInt
}
default{
branch_src1 := input(PC)
branch_src2 := (input(PREDICTION_HAD_BRANCHED) ? B(4) | imm.b_sext).asUInt
}
}
val branchAdder = branch_src1 + branch_src2
insert(BRANCH_CALC) := branchAdder(31 downto 1) @@ ((input(BRANCH_CTRL) === BranchCtrlEnum.JALR) ? False | branchAdder(0))
}
// branch JALR or JAL/Bxx prediction miss corrections
val branchStage = if(earlyBranch) execute else memory
branchStage plug new Area {
import branchStage._
jumpInterface.valid := input(BRANCH_DO) && arbitration.isFiring
jumpInterface.payload := input(BRANCH_CALC)
when(jumpInterface.valid) {
stages(indexOf(branchStage) - 1).arbitration.flushAll := True
}
if(catchAddressMisaligned) {
branchExceptionPort.valid := input(INSTRUCTION_READY) && arbitration.isValid && input(BRANCH_DO) && jumpInterface.payload(1 downto 0) =/= 0
branchExceptionPort.code := 0
branchExceptionPort.badAddr := jumpInterface.payload
}
}
//Update historyCache
decodePrediction.rsp.wasWrong := jumpInterface.valid
// if(prediction == DYNAMIC) branchStage plug new Area {
// import branchStage._
// val newHistory = input(HISTORY_LINE).history.resize(historyWidth + 1) + Mux(input(BRANCH_COND_RESULT),S(-1),S(1))
@ -264,7 +315,7 @@ class BranchPlugin(earlyBranch : Boolean,
// historyCacheWrite.address := input(PC)(2, historyRamSizeLog2 bits)
// historyCacheWrite.data.history := newHistory.resized
// }
// }
}

View file

@ -5,6 +5,7 @@ import spinal.core._
import spinal.lib._
import spinal.lib.bus.amba4.axi._
import spinal.lib.bus.avalon.{AvalonMM, AvalonMMConfig}
import vexriscv.Riscv.IMM
import scala.collection.mutable.ArrayBuffer
@ -109,12 +110,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
var prefetchExceptionPort : Flow[ExceptionCause] = null
def resetVector = BigInt(0x80000000l)
def keepPcPlus4 = false
def decodePcGen = true
def compressedGen = true
def decodePcGen = false
def compressedGen = false
def cmdToRspStageCount = 1
def rspStageGen = false
def injectorReadyCutGen = true
def relaxedPcCalculation = true
def injectorReadyCutGen = false
def relaxedPcCalculation = false
def prediction : BranchPrediction = STATIC
var decodePrediction : DecodePredictionBus = null
assert(cmdToRspStageCount >= 1)
assert(!(compressedGen && !decodePcGen))
lazy val fetcherHalt = False
@ -122,6 +125,8 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
lazy val decodeNextPc = UInt(32 bits)
def nextPc() = (decodeNextPcValid, decodeNextPc)
var predictionJumpInterface : Flow[UInt] = null
override def haltIt(): Unit = fetcherHalt := True
case class JumpInfo(interface : Flow[UInt], stage: Stage, priority : Int)
@ -142,6 +147,14 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
}
pipeline(RVC_GEN) = compressedGen
prediction match {
case NONE =>
case STATIC | DYNAMIC => {
predictionJumpInterface = createJumpInterface(pipeline.decode)
decodePrediction = pipeline.service(classOf[PredictionInterface]).askDecodePrediction()
}
}
}
override def build(pipeline: VexRiscv): Unit = {
@ -365,6 +378,29 @@ class IBusSimplePlugin(interfaceKeepData : Boolean, catchAccessFault : Boolean,
decodeExceptionPort.code := 1
decodeExceptionPort.badAddr := decode.input(PC)
}
prediction match {
case `NONE` =>
case `STATIC` => {
val imm = IMM(decode.input(INSTRUCTION))
val conditionalBranchPrediction = (prediction match {
case `STATIC` => imm.b_sext.msb
//case `DYNAMIC` => input(HISTORY_LINE).history.msb
})
decodePrediction.cmd.hadBranch := decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL || (decode.input(BRANCH_CTRL) === BranchCtrlEnum.B && conditionalBranchPrediction)
predictionJumpInterface.valid := decodePrediction.cmd.hadBranch && decode.arbitration.isFiring //TODO OH Doublon de priorité
predictionJumpInterface.payload := decode.input(PC) + ((decode.input(BRANCH_CTRL) === BranchCtrlEnum.JAL) ? imm.j_sext | imm.b_sext).asUInt
// if(catchAddressMisaligned) {
// predictionExceptionPort.valid := input(INSTRUCTION_READY) && input(PREDICTION_HAD_BRANCHED) && arbitration.isValid && predictionJumpInterface.payload(1 downto 0) =/= 0
// predictionExceptionPort.code := 0
// predictionExceptionPort.badAddr := predictionJumpInterface.payload
// }
}
}
}
}
}