diff --git a/src/main/scala/vexriscv/TestsWorkspace.scala b/src/main/scala/vexriscv/TestsWorkspace.scala index 8d50ad9..ac19f96 100644 --- a/src/main/scala/vexriscv/TestsWorkspace.scala +++ b/src/main/scala/vexriscv/TestsWorkspace.scala @@ -129,7 +129,7 @@ object TestsWorkspace { divUnrollFactor = 1 ), // new DivPlugin, - new CsrPlugin(CsrPluginConfig.all(0x80000020l).copy(deterministicInteruptionEntry = false)), + new CsrPlugin(CsrPluginConfig.all2(0x80000020l).copy(deterministicInteruptionEntry = false)), new DebugPlugin(ClockDomain.current.clone(reset = Bool().setName("debugReset"))), new BranchPlugin( earlyBranch = true, diff --git a/src/main/scala/vexriscv/VexRiscv.scala b/src/main/scala/vexriscv/VexRiscv.scala index afd29bb..bd1ef1c 100644 --- a/src/main/scala/vexriscv/VexRiscv.scala +++ b/src/main/scala/vexriscv/VexRiscv.scala @@ -44,6 +44,10 @@ case class VexRiscvConfig(){ object SRC_USE_SUB_LESS extends Stageable(Bool) object SRC_LESS_UNSIGNED extends Stageable(Bool) + + object DISRUPT_IN_MEMORY_STAGE extends Stageable(Bool) + object DISRUPT_IN_WRITEBACK_STAGE extends Stageable(Bool) + //Formal verification purposes object FORMAL_HALT extends Stageable(Bool) object FORMAL_PC_NEXT extends Stageable(UInt(32 bits)) @@ -56,7 +60,7 @@ case class VexRiscvConfig(){ object Src1CtrlEnum extends SpinalEnum(binarySequential){ - val RS, IMU, PC_INCREMENT = newElement() //IMU, IMZ IMJB + val RS, IMU, PC_INCREMENT, URS1 = newElement() //IMU, IMZ IMJB } object Src2CtrlEnum extends SpinalEnum(binarySequential){ diff --git a/src/main/scala/vexriscv/plugin/BranchPlugin.scala b/src/main/scala/vexriscv/plugin/BranchPlugin.scala index eff2f09..55e69ad 100644 --- a/src/main/scala/vexriscv/plugin/BranchPlugin.scala +++ b/src/main/scala/vexriscv/plugin/BranchPlugin.scala @@ -178,7 +178,7 @@ class BranchPlugin(earlyBranch : Boolean, stages(indexOf(branchStage) - 1).arbitration.flushAll := True } - if(catchAddressMisaligned) { //TODO conflict with instruction cache two stage + if(catchAddressMisaligned) { branchExceptionPort.valid := arbitration.isValid && input(BRANCH_DO) && (if(pipeline(RVC_GEN)) jumpInterface.payload(0 downto 0) =/= 0 else jumpInterface.payload(1 downto 0) =/= 0) branchExceptionPort.code := 0 branchExceptionPort.badAddr := jumpInterface.payload diff --git a/src/main/scala/vexriscv/plugin/CsrPlugin.scala b/src/main/scala/vexriscv/plugin/CsrPlugin.scala index de80f76..5ca5eae 100644 --- a/src/main/scala/vexriscv/plugin/CsrPlugin.scala +++ b/src/main/scala/vexriscv/plugin/CsrPlugin.scala @@ -4,6 +4,7 @@ import spinal.core._ import spinal.lib._ import vexriscv._ import vexriscv.Riscv._ +import vexriscv.plugin.IntAluPlugin.{ALU_BITWISE_CTRL, ALU_CTRL, AluBitwiseCtrlEnum, AluCtrlEnum} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable @@ -50,6 +51,7 @@ case class CsrPluginConfig( ucycleAccess : CsrAccess, wfiGen : Boolean, ecallGen : Boolean, + supervisorGen : Boolean = false, sscratchGen : Boolean = false, stvecAccess : CsrAccess = CsrAccess.NONE, sepcAccess : CsrAccess = CsrAccess.NONE, @@ -92,6 +94,36 @@ object CsrPluginConfig{ ucycleAccess = CsrAccess.READ_ONLY ) + def all2(mtvecInit : BigInt) : CsrPluginConfig = CsrPluginConfig( + catchIllegalAccess = true, + mvendorid = 11, + marchid = 22, + mimpid = 33, + mhartid = 0, + misaExtensionsInit = 66, + misaAccess = CsrAccess.READ_WRITE, + mtvecAccess = CsrAccess.READ_WRITE, + mtvecInit = mtvecInit, + mepcAccess = CsrAccess.READ_WRITE, + mscratchGen = true, + mcauseAccess = CsrAccess.READ_WRITE, + mbadaddrAccess = CsrAccess.READ_WRITE, + mcycleAccess = CsrAccess.READ_WRITE, + minstretAccess = CsrAccess.READ_WRITE, + ecallGen = true, + wfiGen = true, + ucycleAccess = CsrAccess.READ_ONLY, + supervisorGen = true, + sscratchGen = true, + stvecAccess = CsrAccess.READ_WRITE, + sepcAccess = CsrAccess.READ_WRITE, + scauseAccess = CsrAccess.READ_WRITE, + sbadaddrAccess = CsrAccess.READ_WRITE, + scycleAccess = CsrAccess.READ_WRITE, + sinstretAccess = CsrAccess.READ_WRITE, + satpAccess = CsrAccess.READ_WRITE + ) + def small(mtvecInit : BigInt) = CsrPluginConfig( catchIllegalAccess = false, mvendorid = null, @@ -202,7 +234,6 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio } var jumpInterface : Flow[UInt] = null - var pluginExceptionPort : Flow[ExceptionCause] = null var timerInterrupt, externalInterrupt : Bool = null var timerInterruptS, externalInterruptS : Bool = null var privilege : UInt = null @@ -241,8 +272,8 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio val defaultCsrActions = List[(Stageable[_ <: BaseType],Any)]( IS_CSR -> True, REGFILE_WRITE_VALID -> True, - BYPASSABLE_EXECUTE_STAGE -> True, - BYPASSABLE_MEMORY_STAGE -> True + ALU_BITWISE_CTRL -> AluBitwiseCtrlEnum.SRC1, + ALU_CTRL -> AluCtrlEnum.BITWISE ) val nonImmediatActions = defaultCsrActions ++ List( @@ -250,7 +281,9 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio RS1_USE -> True ) - val immediatActions = defaultCsrActions + val immediatActions = defaultCsrActions ++ List( + SRC1_CTRL -> Src1CtrlEnum.URS1 + ) val decoderService = pipeline.service(classOf[DecoderService]) @@ -274,20 +307,19 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio jumpInterface.valid := False jumpInterface.payload.assignDontCare() - if(ecallGen) { - pluginExceptionPort = newExceptionPort(pipeline.execute) - pluginExceptionPort.valid := False - pluginExceptionPort.payload.assignDontCare() - } timerInterrupt = in Bool() setName("timerInterrupt") externalInterrupt = in Bool() setName("externalInterrupt") + if(supervisorGen){ + timerInterruptS = in Bool() setName("timerInterruptS") + externalInterruptS = in Bool() setName("externalInterruptS") + } contextSwitching = Bool().setName("contextSwitching") - privilege = RegInit(U"11") + privilege = RegInit(U"11").setName("CsrPlugin_privilege") - if(catchIllegalAccess) - selfException = newExceptionPort(pipeline.execute) + if(catchIllegalAccess || ecallGen) + selfException = newExceptionPort(pipeline.writeBack) allowInterrupts = True allowException = True @@ -404,8 +436,9 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio minstretAccess(CSR.MINSTRETH, minstret(63 downto 32)) //Supervisor CSR + WRITE_ONLY(CSR.SSTATUS,8 -> sstatus.SPP, 5 -> sstatus.SPIE, 1 -> sstatus.SIE) for(offset <- List(0, 0x200)) { - READ_WRITE(CSR.SSTATUS,8 -> sstatus.SPP, 5 -> sstatus.SPIE, 1 -> sstatus.SIE) + READ_ONLY(CSR.SSTATUS,8 -> sstatus.SPP, 5 -> sstatus.SPIE, 1 -> sstatus.SIE) } READ_ONLY(CSR.SIP, 9 -> sip.SEIP, 5 -> sip.STIP) READ_WRITE(CSR.SIP, 1 -> sip.SSIP) @@ -446,9 +479,9 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio )) case class DelegatorModel(value : Bits, source : Int, target : Int) - def solveDelegators(delegators : Seq[DelegatorModel], id : Int, upTo : Int): UInt = { - val filtredDelegators = delegators.filter(_.target <= upTo) - val ret = U(filtredDelegators.last.target, 2 bits) + def solveDelegators(delegators : Seq[DelegatorModel], id : Int, lowerBound : Int): UInt = { + val filtredDelegators = delegators.filter(_.target >= lowerBound) + val ret = U(lowerBound, 2 bits) for(d <- filtredDelegators){ when(!d.value(id)){ ret := d.source @@ -457,10 +490,10 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio ret } - def solveDelegators(delegators : Seq[DelegatorModel], id : UInt, upTo : UInt): UInt = { + def solveDelegators(delegators : Seq[DelegatorModel], id : UInt, lowerBound : UInt): UInt = { val ret = U(delegators.last.target, 2 bits) for(d <- delegators){ - when(!d.value(id) || d.target > upTo){ + when(!d.value(id) || d.target < lowerBound){ ret := d.source } } @@ -468,10 +501,10 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio } val interruptDelegators = ArrayBuffer[DelegatorModel]() - interruptDelegators += DelegatorModel(mideleg,0, 2) + interruptDelegators += DelegatorModel(mideleg,3, 1) val exceptionDelegators = ArrayBuffer[DelegatorModel]() - exceptionDelegators += DelegatorModel(medeleg,0, 2) + exceptionDelegators += DelegatorModel(medeleg,3, 1) val mepcCaptureStage = if(exceptionPortsInfos.nonEmpty) writeBack else decode @@ -578,7 +611,8 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio interruptJump := interrupt && pipelineLiberator.done val hadException = RegNext(exception) init(False) - writeBack.arbitration.haltItself setWhen(exception && !hadException) + exception clearWhen(hadException) + writeBack.arbitration.haltItself setWhen(exception) val targetPrivilege = CombInit(interruptTargetPrivilege) @@ -586,6 +620,11 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio targetPrivilege := exceptionPortCtrl.exceptionTargetPrivilege } + val trapCause = CombInit(interruptCode) + if(exceptionPortCtrl != null) when( hadException){ + trapCause := exceptionPortCtrl.exceptionContext.code + } + when(hadException || (interruptJump && !exception)){ jumpInterface.valid := True jumpInterface.payload := mtvec @@ -596,30 +635,30 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio is(1){ sstatus.SIE := False sstatus.SPIE := sstatus.SIE - sstatus.SPP := privilege + sstatus.SPP := privilege(0 downto 0) + scause.interrupt := !hadException + scause.exceptionCode := trapCause + sepc := mepcCaptureStage.input(PC) if(exceptionPortCtrl != null) { stval := exceptionPortCtrl.exceptionContext.badAddr - scause.exceptionCode := exceptionPortCtrl.exceptionContext.code } } is(3){ mstatus.MIE := False mstatus.MPIE := mstatus.MIE mstatus.MPP := privilege + mcause.interrupt := !hadException + mcause.exceptionCode := trapCause + mepc := mepcCaptureStage.input(PC) if(exceptionPortCtrl != null) { mtval := exceptionPortCtrl.exceptionContext.badAddr - mcause.exceptionCode := exceptionPortCtrl.exceptionContext.code } } } - mepc := mepcCaptureStage.input(PC) - mcause.interrupt := interruptJump - mcause.exceptionCode := interruptCode } - contextSwitching := jumpInterface.valid //CSR read/write instructions management @@ -632,79 +671,79 @@ class CsrPlugin(config : CsrPluginConfig) extends Plugin[VexRiscv] with Exceptio || (input(INSTRUCTION)(14 downto 13) === "11" && imm.z === 0) ) insert(CSR_READ_OPCODE) := input(INSTRUCTION)(13 downto 7) =/= B"0100000" - //Assure that the CSR access are in the execute stage when there is nothing left in memory/writeback stages to avoid exception hazard - arbitration.haltItself setWhen(arbitration.isValid && input(IS_CSR) && (execute.arbitration.isValid || memory.arbitration.isValid)) } - execute plug new Area { + + + execute plug new Area{ import execute._ + //Manage WFI instructions + if(wfiGen) when(arbitration.isValid && input(ENV_CTRL) === EnvCtrlEnum.WFI){ + when(!interrupt){ + arbitration.haltItself := True + } + } + } + + writeBack plug new Area { + import writeBack._ + def previousStage = memory val illegalAccess = arbitration.isValid && input(IS_CSR) val illegalInstruction = False - if(catchIllegalAccess) { + if(selfException != null) { selfException.valid := illegalAccess || illegalInstruction selfException.code := 2 selfException.badAddr.assignDontCare() } - //TODO jump interface logic change to avoid combinatorial path on the valid ? - //Manage MRET / SRET instructions - when(execute.arbitration.isValid && execute.input(ENV_CTRL) === EnvCtrlEnum.XRET) { - illegalInstruction setWhen(execute.input(INSTRUCTION)(29 downto 28).asUInt =/= privilege) + when(arbitration.isValid && input(ENV_CTRL) === EnvCtrlEnum.XRET) { jumpInterface.payload := mepc - when(memory.arbitration.isValid || writeBack.arbitration.isValid){ - execute.arbitration.haltItself := True - } elsewhen (execute.arbitration.isFiring) { + when(input(INSTRUCTION)(29 downto 28).asUInt =/= privilege) { + illegalInstruction := True + } otherwise{ jumpInterface.valid := True - decode.arbitration.flushAll := True - switch(execute.input(INSTRUCTION)(29 downto 28)){ + previousStage.arbitration.flushAll := True + switch(input(INSTRUCTION)(29 downto 28)){ is(3){ mstatus.MIE := mstatus.MPIE mstatus.MPP := U"00" mstatus.MPIE := True - privilege := mstatus.MPP + privilege := mstatus.MPP //TODO check MPP value } is(1){ sstatus.SIE := sstatus.SPIE - sstatus.SPP := U"00" + sstatus.SPP := U"0" sstatus.SPIE := True - privilege := sstatus.SPP + privilege := U"0" @@ sstatus.SPP } } } } //Manage ECALL instructions - if(ecallGen) when(execute.arbitration.isValid && execute.input(ENV_CTRL) === EnvCtrlEnum.ECALL){ - pluginExceptionPort.valid := True - pluginExceptionPort.code := 11 + if(ecallGen) when(arbitration.isValid && input(ENV_CTRL) === EnvCtrlEnum.ECALL){ + selfException.valid := True + selfException.code := 11 } - //Manage WFI instructions - if(wfiGen) when(execute.arbitration.isValid && execute.input(ENV_CTRL) === EnvCtrlEnum.WFI){ - when(!interrupt){ - execute.arbitration.haltItself := True - } - } - - val imm = IMM(input(INSTRUCTION)) - val writeSrc = input(INSTRUCTION)(14) ? imm.z.asBits.resized | input(SRC1) + val writeSrc = input(REGFILE_WRITE_DATA) val readData = B(0, 32 bits) - def readDataReg = memory.input(REGFILE_WRITE_DATA) //PIPE OPT - val readDataRegValid = Reg(Bool) setWhen(arbitration.isValid) clearWhen(!arbitration.isStuck) +// def readDataReg = memory.input(REGFILE_WRITE_DATA) //PIPE OPT +// val readDataRegValid = Reg(Bool) setWhen(arbitration.isValid) clearWhen(!arbitration.isStuck) val writeData = input(INSTRUCTION)(13).mux( False -> writeSrc, - True -> Mux(input(INSTRUCTION)(12), readDataReg & ~writeSrc, readDataReg | writeSrc) + True -> Mux(input(INSTRUCTION)(12), readData & ~writeSrc, readData | writeSrc) ) val writeInstruction = arbitration.isValid && input(IS_CSR) && input(CSR_WRITE_OPCODE) val readInstruction = arbitration.isValid && input(IS_CSR) && input(CSR_READ_OPCODE) - arbitration.haltItself setWhen(writeInstruction && !readDataRegValid) - val writeEnable = writeInstruction && readDataRegValid - val readEnable = readInstruction && !readDataRegValid +// arbitration.haltItself setWhen(writeInstruction && !readDataRegValid) + val writeEnable = writeInstruction// && readDataRegValid + val readEnable = readInstruction// && !readDataRegValid when(arbitration.isValid && input(IS_CSR)) { output(REGFILE_WRITE_DATA) := readData diff --git a/src/main/scala/vexriscv/plugin/SrcPlugin.scala b/src/main/scala/vexriscv/plugin/SrcPlugin.scala index da9b151..d319b9b 100644 --- a/src/main/scala/vexriscv/plugin/SrcPlugin.scala +++ b/src/main/scala/vexriscv/plugin/SrcPlugin.scala @@ -16,7 +16,8 @@ class SrcPlugin(separatedAddSub : Boolean = false, executeInsertion : Boolean = insert(SRC1) := input(SRC1_CTRL).mux( Src1CtrlEnum.RS -> output(RS1), Src1CtrlEnum.PC_INCREMENT -> (if(pipeline(RVC_GEN)) Mux(input(IS_RVC), B(2), B(4)) else B(4)).resized, - Src1CtrlEnum.IMU -> imm.u.resized + Src1CtrlEnum.IMU -> imm.u.resized, + Src1CtrlEnum.URS1 -> input(INSTRUCTION)(Riscv.rs1Range).resized ) insert(SRC2) := input(SRC2_CTRL).mux( Src2CtrlEnum.RS -> output(RS2),