fpu improve fmax

This commit is contained in:
Dolu1990 2021-02-17 16:35:52 +01:00
parent 1e647f799c
commit 8537d18b16
2 changed files with 203 additions and 122 deletions

View File

@ -83,7 +83,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val rs1Boxed, rs2Boxed = p.withDouble generate Bool()
}
case class MulInput() extends Bundle{
class MulInput() extends Bundle{
val source = Source()
val rs1, rs2, rs3 = p.internalFloating()
val rd = p.rfAddress()
@ -117,7 +117,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
case class MergeInput() extends Bundle{
class MergeInput() extends Bundle{
val source = Source()
val lockId = lockIdType()
val rd = p.rfAddress()
@ -174,7 +174,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val fork = new StreamFork(FpuCommit(p), 2)
fork.io.input << io.port(i).commit
fork.io.outputs(0) >> load(i)
fork.io.outputs(1) >> commit(i)
fork.io.outputs(1).pipelined(m2s = true, s2m = true) >> commit(i) //Pipelining here is light, as it only use the flags of the payload
}
}
@ -198,16 +198,16 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
}
//TODO nan boxing decoding
val read = new Area{
val arbiter = StreamArbiterFactory.noLock.roundRobin.build(FpuCmd(p), portCount)
arbiter.io.inputs <> Vec(io.port.map(_.cmd))
val s0 = Stream(RfReadInput())
s0.arbitrationFrom(arbiter.io.output)
s0.source := arbiter.io.chosen
s0.payload.assignSomeByName(arbiter.io.output.payload)
val arbiterOutput = Stream(RfReadInput())
arbiterOutput.arbitrationFrom(arbiter.io.output)
arbiterOutput.source := arbiter.io.chosen
arbiterOutput.payload.assignSomeByName(arbiter.io.output.payload)
val s0 = arbiterOutput.pipelined(m2s = true, s2m = true)
val useRs1, useRs2, useRs3, useRd = False
switch(s0.opcode){
is(p.Opcode.LOAD) { useRd := True }
@ -314,8 +314,8 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val fmaHit = input.opcode === p.Opcode.FMA
val mulHit = input.opcode === p.Opcode.MUL || fmaHit
val mul = Stream(MulInput())
val divSqrtToMul = Stream(MulInput())
val mul = Stream(new MulInput())
val divSqrtToMul = Stream(new MulInput())
if(p.withMul) {
input.ready setWhen (mulHit && mul.ready && !divSqrtToMul.valid)
@ -369,7 +369,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
val s0 = new Area{
val input = decode.load.stage()
val input = decode.load.pipelined(m2s = true, s2m = true)
val filtred = commitFork.load.map(port => port.takeWhen(port.sync))
def feed = filtred(input.source)
val hazard = !feed.valid
@ -390,7 +390,6 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
output.format := FpuFormat.FLOAT
}
}
}
val s1 = new Area{
@ -510,7 +509,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
when(isInfinity){recoded.setInfinity}
when(isNan){recoded.setNan}
val output = input.haltWhen(busy).swapPayload(MergeInput())
val output = input.haltWhen(busy).swapPayload(new MergeInput())
output.source := input.source
output.lockId := input.lockId
output.roundMode := input.roundMode
@ -540,7 +539,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val shortPip = new Area{
val input = decode.shortPip.stage()
val rfOutput = Stream(MergeInput())
val rfOutput = Stream(new MergeInput())
val result = p.storeLoadType().assignDontCare()
@ -820,20 +819,61 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
val mul = p.withMul generate new Area{
val input = decode.mul.stage()
val inWidthA = p.internalMantissaSize+1
val inWidthB = p.internalMantissaSize+1
val outWidth = p.internalMantissaSize*2+2
val math = new Area {
case class MulSplit(offsetA : Int, offsetB : Int, widthA : Int, widthB : Int, id : Int){
val offsetC = offsetA+offsetB
}
val splitsUnordered = for(offsetA <- 0 until inWidthA by p.mulWidthA;
offsetB <- 0 until inWidthB by p.mulWidthB;
widthA = (inWidthA - offsetA) min p.mulWidthA;
widthB = (inWidthB - offsetB) min p.mulWidthB) yield {
MulSplit(offsetA, offsetB, widthA, widthB, -1)
}
val splits = splitsUnordered.sortWith(_.offsetC < _.offsetC).zipWithIndex.map(e => e._1.copy(id=e._2))
class MathWithExp extends MulInput{
val exp = UInt(p.internalExponentSize+1 bits)
}
val preMul = new Area{
val input = decode.mul.stage()
val output = input.swapPayload(new MathWithExp())
output.payload.assignSomeByName(input.payload)
output.exp := input.rs1.exponent +^ input.rs2.exponent
}
class MathWithMul extends MathWithExp{
val muls = Vec(splits.map(e => UInt(e.widthA + e.widthB bits)))
}
val mul = new Area{
val input = preMul.output.stage()
val output = input.swapPayload(new MathWithMul())
val mulA = U(input.msb1) @@ input.rs1.mantissa
val mulB = U(input.msb2) @@ input.rs2.mantissa
val mulC = mulA * mulB
val exp = input.rs1.exponent +^ input.rs2.exponent
output.payload.assignSomeByName(input.payload)
splits.foreach(e => output.muls(e.id) := mulA(e.offsetA, e.widthA bits) * mulB(e.offsetB, e.widthB bits))
}
class MathOutput extends MathWithExp{
val mulC = UInt(p.internalMantissaSize*2+2 bits)
}
val math = new Area {
val input = mul.output.stage()
val sum = splits.map(e => (input.muls(e.id) << e.offsetC).resize(outWidth)).reduceBalancedTree(_ + _)
val output = input.swapPayload(new MathOutput())
output.payload.assignSomeByName(input.payload)
output.mulC := sum
}
val norm = new Area{
val (mulHigh, mulLow) = math.mulC.splitAt(p.internalMantissaSize-1)
val input = math.output.stage()
val (mulHigh, mulLow) = input.mulC.splitAt(p.internalMantissaSize-1)
val scrap = mulLow =/= 0
val needShift = mulHigh.msb
val exp = math.exp + U(needShift)
val exp = input.exp + U(needShift)
val man = needShift ? mulHigh(1, p.internalMantissaSize+1 bits) | mulHigh(0, p.internalMantissaSize+1 bits)
scrap setWhen(needShift && mulHigh(0))
val forceZero = input.rs1.isZero || input.rs2.isZero
@ -863,38 +903,40 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
} elsewhen(forceUnderflow) {
output.exponent := underflowExp.resized
}
}
val notMul = new Area{
val output = Flow(UInt(p.internalMantissaSize + 1 bits))
output.valid := input.valid && input.divSqrt
output.payload := math.mulC(p.internalMantissaSize, p.internalMantissaSize+1 bits)
val result = new Area {
def input = norm.input
val notMul = new Area {
val output = Flow(UInt(p.internalMantissaSize + 1 bits))
output.valid := input.valid && input.divSqrt
output.payload := input.mulC(p.internalMantissaSize, p.internalMantissaSize + 1 bits)
}
val output = Stream(new MergeInput())
output.valid := input.valid && !input.add && !input.divSqrt
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
if (p.withDouble) output.format := input.format
output.roundMode := input.roundMode
output.scrap := norm.scrap
output.value := norm.output
decode.mulToAdd.valid := input.valid && input.add
decode.mulToAdd.source := input.source
decode.mulToAdd.rs1.mantissa := norm.output.mantissa >> 1 //FMA Precision lost
decode.mulToAdd.rs1.exponent := norm.output.exponent
decode.mulToAdd.rs1.sign := norm.output.sign
decode.mulToAdd.rs1.special := False //TODO
decode.mulToAdd.rs2 := input.rs3
decode.mulToAdd.rd := input.rd
decode.mulToAdd.lockId := input.lockId
decode.mulToAdd.roundMode := input.roundMode
if (p.withDouble) decode.mulToAdd.format := input.format
input.ready := (input.add ? decode.mulToAdd.ready | output.ready) || input.divSqrt
}
val output = Stream(MergeInput())
output.valid := input.valid && !input.add && !input.divSqrt
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
if(p.withDouble) output.format := input.format
output.roundMode := input.roundMode
output.scrap := norm.scrap
output.value := norm.output
decode.mulToAdd.valid := input.valid && input.add
decode.mulToAdd.source := input.source
decode.mulToAdd.rs1.mantissa := norm.output.mantissa >> 1 //FMA Precision lost
decode.mulToAdd.rs1.exponent := norm.output.exponent
decode.mulToAdd.rs1.sign := norm.output.sign
decode.mulToAdd.rs1.special := False //TODO
decode.mulToAdd.rs2 := input.rs3
decode.mulToAdd.rd := input.rd
decode.mulToAdd.lockId := input.lockId
decode.mulToAdd.roundMode := input.roundMode
if(p.withDouble) decode.mulToAdd.format := input.format
input.ready := (input.add ? decode.mulToAdd.ready | output.ready) || input.divSqrt
}
val divSqrt = p.withDivSqrt generate new Area {
@ -965,7 +1007,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
decode.divSqrtToMul.msb2 := rs2.msb
}
val mulBuffer = mul.notMul.output.toStream.stage
val mulBuffer = mul.result.notMul.output.toStream.stage
mulBuffer.ready := False
val iterationValue = Reg(UInt(mulWidth bits))
@ -1081,9 +1123,20 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
val add = p.withAdd generate new Area{
val input = decode.add.stage()
class ShifterOutput extends AddInput{
val xSign, ySign = Bool()
val xMantissa, yMantissa = UInt(p.internalMantissaSize+3 bits)
val xyExponent = UInt(p.internalExponentSize bits)
val xySign = Bool()
val roundingScrap = Bool()
}
val shifter = new Area {
val input = decode.add.stage()
val output = input.swapPayload(new ShifterOutput)
output.payload.assignSomeByName(input.payload)
val exp21 = input.rs2.exponent -^ input.rs1.exponent
val rs1ExponentBigger = (exp21.msb || input.rs2.isZero) && !input.rs1.isZero
val rs1ExponentEqual = input.rs1.exponent === input.rs2.exponent
@ -1095,8 +1148,8 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
//Note that rs1ExponentBigger can be replaced by absRs1Bigger bellow to avoid xsigned two complement in math block at expense of combinatorial path
val xySign = absRs1Bigger ? input.rs1.sign | input.rs2.sign
val xSign = xySign ^ (rs1ExponentBigger ? input.rs1.sign | input.rs2.sign)
val ySign = xySign ^ (rs1ExponentBigger ? input.rs2.sign | input.rs1.sign)
output.xSign := xySign ^ (rs1ExponentBigger ? input.rs1.sign | input.rs2.sign)
output.ySign := xySign ^ (rs1ExponentBigger ? input.rs2.sign | input.rs1.sign)
val xMantissa = U"1" @@ (rs1ExponentBigger ? input.rs1.mantissa | input.rs2.mantissa) @@ U"00"
val yMantissaUnshifted = U"1" @@ (rs1ExponentBigger ? input.rs2.mantissa | input.rs1.mantissa) @@ U"00"
var yMantissa = CombInit(yMantissaUnshifted)
@ -1108,66 +1161,86 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
when(passThrough) { yMantissa := 0 }
when(shiftOverflow) { roundingScrap := True }
when(input.rs1.special || input.rs2.special){ roundingScrap := False }
val xyExponent = rs1ExponentBigger ? input.rs1.exponent | input.rs2.exponent
output.xyExponent := rs1ExponentBigger ? input.rs1.exponent | input.rs2.exponent
output.xMantissa := xMantissa
output.yMantissa := yMantissa
output.xySign := xySign
output.roundingScrap := roundingScrap
}
class MathOutput extends ShifterOutput{
val xyMantissa = UInt(p.internalMantissaSize+4 bits)
}
val math = new Area {
def xSign = shifter.xSign
def ySign = shifter.ySign
def xMantissa = shifter.xMantissa
def yMantissa = shifter.yMantissa
def xyExponent = shifter.xyExponent
def xySign = shifter.xySign
val input = shifter.output.stage()
val output = input.swapPayload(new MathOutput)
output.payload.assignSomeByName(input.payload)
import input.payload._
val xSigned = xMantissa.twoComplement(xSign) //TODO Is that necessary ?
val ySigned = ((ySign ## Mux(ySign, ~yMantissa, yMantissa)).asUInt + (ySign && !shifter.roundingScrap).asUInt).asSInt //rounding here
val xyMantissa = U(xSigned +^ ySigned).trim(1 bits)
val ySigned = ((ySign ## Mux(ySign, ~yMantissa, yMantissa)).asUInt + (ySign && !roundingScrap).asUInt).asSInt //rounding here
output.xyMantissa := U(xSigned +^ ySigned).trim(1 bits)
}
class NormOutput extends AddInput{
val mantissa = UInt(p.internalMantissaSize+4 bits)
val exponent = UInt(p.internalExponentSize+1 bits)
val infinityNan, forceNan, forceZero, forceInfinity = Bool()
val xySign, roundingScrap = Bool()
val xyMantissaZero = Bool()
}
val norm = new Area{
def xyExponent = math.xyExponent
def xyMantissa = math.xyMantissa
val xySign = CombInit(math.xySign)
val input = math.output.stage()
val output = input.swapPayload(new NormOutput)
output.payload.assignSomeByName(input.payload)
import input.payload._
val shiftOh = OHMasking.first(xyMantissa.asBools.reverse)
val shift = OHToUInt(shiftOh)
val mantissa = (xyMantissa |<< shift)
val exponent = xyExponent -^ shift + 1
val forceZero = xyMantissa === 0 || (input.rs1.isZero && input.rs2.isZero)
// val forceOverflow = exponent === exponentOne + 128 //Handled by writeback rounding
val forceInfinity = (input.rs1.isInfinity || input.rs2.isInfinity)
val infinityNan = (input.rs1.isInfinity && input.rs2.isInfinity && (input.rs1.sign ^ input.rs2.sign))
val forceNan = input.rs1.isNan || input.rs2.isNan || infinityNan
output.mantissa := (xyMantissa |<< shift)
output.exponent := xyExponent -^ shift + 1
output.forceInfinity := (input.rs1.isInfinity || input.rs2.isInfinity)
output.forceZero := xyMantissa === 0 || (input.rs1.isZero && input.rs2.isZero)
output.infinityNan := (input.rs1.isInfinity && input.rs2.isInfinity && (input.rs1.sign ^ input.rs2.sign))
output.forceNan := input.rs1.isNan || input.rs2.isNan || output.infinityNan
output.xyMantissaZero := xyMantissa === 0
}
val result = new Area {
val input = norm.output.stage()
val output = input.swapPayload(new MergeInput())
import input.payload._
val output = input.swapPayload(MergeInput())
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
output.value.sign := norm.xySign
output.value.mantissa := (norm.mantissa >> 2).resized
output.value.exponent := norm.exponent.resized
output.value.special := False
output.roundMode := input.roundMode
if(p.withDouble) output.format := input.format
output.scrap := (norm.mantissa(1) | norm.mantissa(0) | shifter.roundingScrap)
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
output.value.sign := xySign
output.value.mantissa := (mantissa >> 2).resized
output.value.exponent := exponent.resized
output.value.special := False
output.roundMode := input.roundMode
if (p.withDouble) output.format := input.format
output.scrap := (mantissa(1) | mantissa(0) | roundingScrap)
val flag = io.port(input.source).completion.flag
flag.NV setWhen(input.valid && (norm.infinityNan || input.rs1.isNanSignaling || input.rs2.isNanSignaling))
when(norm.forceNan) {
output.value.setNanQuiet
} elsewhen(norm.forceZero) {
output.value.setZero
when(norm.xyMantissa === 0 || input.rs1.isZero && input.rs2.isZero){
output.value.sign := input.rs1.sign && input.rs2.sign
val flag = io.port(input.source).completion.flag
flag.NV setWhen (input.valid && (infinityNan || input.rs1.isNanSignaling || input.rs2.isNanSignaling))
when(forceNan) {
output.value.setNanQuiet
} elsewhen (forceZero) {
output.value.setZero
when(xyMantissaZero || input.rs1.isZero && input.rs2.isZero) {
output.value.sign := input.rs1.sign && input.rs2.sign
}
when((input.rs1.sign || input.rs2.sign) && input.roundMode === FpuRoundMode.RDN) {
output.value.sign := True
}
} elsewhen (forceInfinity) {
output.value.setInfinity
}
when((input.rs1.sign || input.rs2.sign) && input.roundMode === FpuRoundMode.RDN){
output.value.sign := True
}
} elsewhen(norm.forceInfinity) {
output.value.setInfinity
}
}
@ -1175,37 +1248,55 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val merge = new Area {
//TODO maybe load can bypass merge and round.
val inputs = ArrayBuffer[Stream[MergeInput]]()
inputs += load.s1.output
if(p.withAdd) (inputs += add.output)
if(p.withMul) (inputs += mul.output)
inputs += load.s1.output.stage()
if(p.withAdd) (inputs += add.result.output)
if(p.withMul) (inputs += mul.result.output)
if(p.withShortPipMisc) (inputs += shortPip.rfOutput)
val arbitrated = StreamArbiterFactory.lowerFirst.noLock.on(inputs)
val isCommited = rf.lock.map(_.commited).read(arbitrated.lockId)
val commited = arbitrated.haltWhen(!isCommited).toFlow
}
val round = new Area{
val input = merge.commited.combStage
class RoundFront extends MergeInput{
val mantissaIncrement = Bool()
val roundAdjusted = Bits(2 bits)
val exactMask = UInt(p.internalMantissaSize + 2 bits)
}
val roundFront = new Area {
val input = merge.commited.stage()
val output = input.swapPayload(new RoundFront())
output.payload.assignSomeByName(input.payload)
val manAggregate = input.value.mantissa @@ input.scrap
val expBase = muxDouble[UInt](input.format)(exponentF64Subnormal+1)(exponentF32Subnormal+1)
val expBase = muxDouble[UInt](input.format)(exponentF64Subnormal + 1)(exponentF32Subnormal + 1)
val expDif = expBase -^ input.value.exponent
val expSubnormal = !expDif.msb
var discardCount = (expSubnormal ? expDif.resize(log2Up(p.internalMantissaSize) bits) | U(0))
if(p.withDouble) when(input.format === FpuFormat.FLOAT){
var discardCount = (expSubnormal ? expDif.resize(log2Up(p.internalMantissaSize) bits) | U(0))
if (p.withDouble) when(input.format === FpuFormat.FLOAT) {
discardCount \= discardCount + 29
}
val exactMask = (List(True) ++ (0 until p.internalMantissaSize+1).map(_ < discardCount)).asBits.asUInt
val roundAdjusted = (True ## (manAggregate>>1))(discardCount) ## ((manAggregate & exactMask) =/= 0)
val exactMask = (List(True) ++ (0 until p.internalMantissaSize + 1).map(_ < discardCount)).asBits.asUInt
val roundAdjusted = (True ## (manAggregate >> 1)) (discardCount) ## ((manAggregate & exactMask) =/= 0)
val mantissaIncrement = !input.value.special && input.roundMode.mux(
FpuRoundMode.RNE -> (roundAdjusted(1) && (roundAdjusted(0) || (U"01" ## (manAggregate>>2))(discardCount))),
FpuRoundMode.RNE -> (roundAdjusted(1) && (roundAdjusted(0) || (U"01" ## (manAggregate >> 2)) (discardCount))),
FpuRoundMode.RTZ -> False,
FpuRoundMode.RDN -> (roundAdjusted =/= 0 && input.value.sign),
FpuRoundMode.RDN -> (roundAdjusted =/= 0 && input.value.sign),
FpuRoundMode.RUP -> (roundAdjusted =/= 0 && !input.value.sign),
FpuRoundMode.RMM -> (roundAdjusted(1))
)
output.mantissaIncrement := mantissaIncrement
output.roundAdjusted := roundAdjusted
output.exactMask := exactMask
}
val roundBack = new Area{
val input = roundFront.output.stage()
val output = input.swapPayload(RoundOutput())
import input.payload._
val math = p.internalFloating()
val mantissaRange = p.internalMantissaSize downto 1
val adderMantissa = input.value.mantissa(mantissaRange) & (mantissaIncrement ? ~(exactMask.trim(1) >> 1) | input.value.mantissa(mantissaRange).maxValue)
@ -1218,12 +1309,6 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
val patched = CombInit(math)
val nx,of,uf = False
// val ufPatch = input.roundMode === FpuRoundMode.RUP && !input.value.sign && !input.scrap|| input.roundMode === FpuRoundMode.RDN && input.value.sign && !input.scrap
// when(!math.special && (input.value.exponent <= exponentOne-127 && (math.exponent =/= exponentOne-126 || !input.value.mantissa.lsb || ufPatch)) && roundAdjusted.asUInt =/= 0){
// uf := True
// }
val ufSubnormalThreshold = muxDouble[UInt](input.format)(exponentF64Subnormal)(exponentF32Subnormal)
val ufThreshold = muxDouble[UInt](input.format)(exponentF64Subnormal-52+1)(exponentF32Subnormal-23+1)
@ -1277,7 +1362,6 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
flag.OF setWhen(of)
flag.UF setWhen(uf)
}
val output = input.swapPayload(RoundOutput())
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
@ -1286,7 +1370,7 @@ case class FpuCore( portCount : Int, p : FpuParameter) extends Component{
}
val writeback = new Area{
val input = round.output.combStage
val input = roundBack.output.stage()
for(i <- 0 until portCount){
completion(i).increments += (RegNext(input.fire && input.source === i) init(False))
@ -1393,12 +1477,7 @@ object FpuSynthesisBench extends App{
})
}
//Fpu_32 ->
//Artix 7 -> 46 Mhz 1786 LUT 628 FF
//Artix 7 -> 47 Mhz 1901 LUT 628 FF
//Fpu_64 ->
//Artix 7 -> 37 Mhz 3407 LUT 1006 FF
//Artix 7 -> 36 Mhz 3564 LUT 1006 FF
val rtls = ArrayBuffer[Rtl]()
rtls += new Fpu(

View File

@ -109,6 +109,8 @@ object FpuRoundModeInstr extends SpinalEnum(){
case class FpuParameter( withDouble : Boolean,
mulWidthA : Int = 18,
mulWidthB : Int = 18,
sim : Boolean = false,
withAdd : Boolean = true,
withMul : Boolean = true,