diff --git a/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala b/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala index 1ed1d83..1b650e3 100644 --- a/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala +++ b/src/main/scala/vexriscv/plugin/HazardSimplePlugin.scala @@ -31,29 +31,62 @@ class HazardSimplePlugin(bypassExecute : Boolean = false, override def build(pipeline: VexRiscv): Unit = { import pipeline._ import pipeline.config._ - val src0Hazard = False - val src1Hazard = False - val readStage = service(classOf[RegFileService]).readStage() + pipeline plug new Area { + val src0Hazard = False + val src1Hazard = False - def trackHazardWithStage(stage : Stage,bypassable : Boolean, runtimeBypassable : Stageable[Bool]): Unit ={ - val runtimeBypassableValue = if(runtimeBypassable != null) stage.input(runtimeBypassable) else True - val addr0Match = if(pessimisticAddressMatch) True else stage.input(INSTRUCTION)(rdRange) === readStage.input(INSTRUCTION)(rs1Range) - val addr1Match = if(pessimisticAddressMatch) True else stage.input(INSTRUCTION)(rdRange) === readStage.input(INSTRUCTION)(rs2Range) - when(stage.arbitration.isValid && stage.input(REGFILE_WRITE_VALID)) { - if (bypassable) { - when(runtimeBypassableValue) { + val readStage = service(classOf[RegFileService]).readStage() + + def trackHazardWithStage(stage: Stage, bypassable: Boolean, runtimeBypassable: Stageable[Bool]): Unit = { + val runtimeBypassableValue = if (runtimeBypassable != null) stage.input(runtimeBypassable) else True + val addr0Match = if (pessimisticAddressMatch) True else stage.input(INSTRUCTION)(rdRange) === readStage.input(INSTRUCTION)(rs1Range) + val addr1Match = if (pessimisticAddressMatch) True else stage.input(INSTRUCTION)(rdRange) === readStage.input(INSTRUCTION)(rs2Range) + when(stage.arbitration.isValid && stage.input(REGFILE_WRITE_VALID)) { + if (bypassable) { + when(runtimeBypassableValue) { + when(addr0Match) { + readStage.input(RS1) := stage.output(REGFILE_WRITE_DATA) + } + when(addr1Match) { + readStage.input(RS2) := stage.output(REGFILE_WRITE_DATA) + } + } + } + } + when(stage.arbitration.isValid && (if (pessimisticWriteRegFile) True else stage.input(REGFILE_WRITE_VALID))) { + when((Bool(!bypassable) || !runtimeBypassableValue)) { when(addr0Match) { - readStage.input(RS1) := stage.output(REGFILE_WRITE_DATA) + src0Hazard := True } when(addr1Match) { - readStage.input(RS2) := stage.output(REGFILE_WRITE_DATA) + src1Hazard := True } } } } - when(stage.arbitration.isValid && (if(pessimisticWriteRegFile) True else stage.input(REGFILE_WRITE_VALID))) { - when((Bool(!bypassable) || !runtimeBypassableValue)) { + + + val writeBackWrites = Flow(cloneable(new Bundle { + val address = Bits(5 bits) + val data = Bits(32 bits) + })) + writeBackWrites.valid := stages.last.output(REGFILE_WRITE_VALID) && stages.last.arbitration.isFiring + writeBackWrites.address := stages.last.output(INSTRUCTION)(rdRange) + writeBackWrites.data := stages.last.output(REGFILE_WRITE_DATA) + val writeBackBuffer = writeBackWrites.stage() + + val addr0Match = if (pessimisticAddressMatch) True else writeBackBuffer.address === readStage.input(INSTRUCTION)(rs1Range) + val addr1Match = if (pessimisticAddressMatch) True else writeBackBuffer.address === readStage.input(INSTRUCTION)(rs2Range) + when(writeBackBuffer.valid) { + if (bypassWriteBackBuffer) { + when(addr0Match) { + readStage.input(RS1) := writeBackBuffer.data + } + when(addr1Match) { + readStage.input(RS2) := writeBackBuffer.data + } + } else { when(addr0Match) { src0Hazard := True } @@ -62,54 +95,24 @@ class HazardSimplePlugin(bypassExecute : Boolean = false, } } } - } + + if (withWriteBackStage) trackHazardWithStage(writeBack, bypassWriteBack, null) + if (withMemoryStage) trackHazardWithStage(memory, bypassMemory, if (stages.last == memory) null else BYPASSABLE_MEMORY_STAGE) + if (readStage != execute) trackHazardWithStage(execute, bypassExecute, if (stages.last == execute) null else BYPASSABLE_EXECUTE_STAGE) - val writeBackWrites = Flow(cloneable(new Bundle{ - val address = Bits(5 bits) - val data = Bits(32 bits) - })) - writeBackWrites.valid := stages.last.output(REGFILE_WRITE_VALID) && stages.last.arbitration.isFiring - writeBackWrites.address := stages.last.output(INSTRUCTION)(rdRange) - writeBackWrites.data := stages.last.output(REGFILE_WRITE_DATA) - val writeBackBuffer = writeBackWrites.stage() - - val addr0Match = if(pessimisticAddressMatch) True else writeBackBuffer.address === readStage.input(INSTRUCTION)(rs1Range) - val addr1Match = if(pessimisticAddressMatch) True else writeBackBuffer.address === readStage.input(INSTRUCTION)(rs2Range) - when(writeBackBuffer.valid) { - if (bypassWriteBackBuffer) { - when(addr0Match) { - readStage.input(RS1) := writeBackBuffer.data + if (!pessimisticUseSrc) { + when(!readStage.input(RS1_USE)) { + src0Hazard := False } - when(addr1Match) { - readStage.input(RS2) := writeBackBuffer.data - } - } else { - when(addr0Match) { - src0Hazard := True - } - when(addr1Match) { - src1Hazard := True + when(!readStage.input(RS2_USE)) { + src1Hazard := False } } - } - if(withWriteBackStage) trackHazardWithStage(writeBack,bypassWriteBack,null) - if(withMemoryStage) trackHazardWithStage(memory ,bypassMemory, if(stages.last == memory) null else BYPASSABLE_MEMORY_STAGE) - if(readStage != execute) trackHazardWithStage(execute ,bypassExecute , if(stages.last == execute) null else BYPASSABLE_EXECUTE_STAGE) - - - if(!pessimisticUseSrc) { - when(!readStage.input(RS1_USE)) { - src0Hazard := False + when(readStage.arbitration.isValid && (src0Hazard || src1Hazard)) { + readStage.arbitration.haltByOther := True } - when(!readStage.input(RS2_USE)) { - src1Hazard := False - } - } - - when(readStage.arbitration.isValid && (src0Hazard || src1Hazard)){ - readStage.arbitration.haltByOther := True } } } diff --git a/src/main/scala/vexriscv/plugin/MulPlugin.scala b/src/main/scala/vexriscv/plugin/MulPlugin.scala index c1619d2..997857b 100644 --- a/src/main/scala/vexriscv/plugin/MulPlugin.scala +++ b/src/main/scala/vexriscv/plugin/MulPlugin.scala @@ -61,10 +61,9 @@ class MulPlugin(var inputBuffer : Boolean = false, when(arbitration.isValid && input(IS_MUL) && counter =/= delay){ arbitration.haltItself := True } - when(!arbitration.isStuckByOthers){ - counter := counter + 1 - } - when(!arbitration.isStuck){ + + counter := counter + 1 + when(!arbitration.isStuck || arbitration.isStuckByOthers){ counter := 0 } }