diff --git a/src/main/scala/vexriscv/ip/fpu/FpuCore.scala b/src/main/scala/vexriscv/ip/fpu/FpuCore.scala index b93dd0f..8f7657a 100644 --- a/src/main/scala/vexriscv/ip/fpu/FpuCore.scala +++ b/src/main/scala/vexriscv/ip/fpu/FpuCore.scala @@ -6,6 +6,10 @@ import spinal.lib.eda.bench.{Bench, Rtl, XilinxStdTargets} import scala.collection.mutable.ArrayBuffer +object FpuDivSqrtIterationState extends SpinalEnum{ + val IDLE, YY, XYY, Y2_XYY, DIV, Y_15_XYY2, Y_15_XYY2_RESULT, SQRT = newElement() +} + case class FpuCore(p : FpuParameter) extends Component{ val io = new Bundle { val port = slave(FpuPort(p)) @@ -51,9 +55,19 @@ case class FpuCore(p : FpuParameter) extends Component{ val rd = p.rfAddress() val lockId = lockIdType() val add = Bool() + val divSqrt = Bool() + val msb1, msb2 = Bool() //allow usage of msb bits of mul val minus = Bool() } + case class DivSqrtInput() extends Bundle{ + val source = p.source() + val rs1, rs2 = p.internalFloating() + val rd = p.rfAddress() + val lockId = lockIdType() + val div = Bool() + } + case class AddInput() extends Bundle{ val source = p.source() val rs1, rs2 = p.internalFloating() @@ -103,6 +117,11 @@ case class FpuCore(p : FpuParameter) extends Component{ useRs2 := True useRd := True } + is(p.Opcode.DIV_SQRT){ + useRs1 := True + useRs2 := True //TODO + useRd := True + } is(p.Opcode.FMA){ useRs1 := True useRs2 := True @@ -155,20 +174,35 @@ case class FpuCore(p : FpuParameter) extends Component{ store.source := read.output.source store.rs2 := read.output.rs2 + val divSqrtHit = input.opcode === p.Opcode.DIV_SQRT + val divSqrt = Stream(DivSqrtInput()) + input.ready setWhen(divSqrtHit && divSqrt.ready) + divSqrt.valid := input.valid && divSqrtHit + divSqrt.source := read.output.source + divSqrt.rs1 := read.output.rs1 + divSqrt.rs2 := read.output.rs2 + divSqrt.rd := read.output.rd + divSqrt.lockId := read.output.lockId + divSqrt.div := True //TODO val fmaHit = input.opcode === p.Opcode.FMA val mulHit = input.opcode === p.Opcode.MUL || fmaHit val mul = Stream(MulInput()) - input.ready setWhen(mulHit && mul.ready) - mul.valid := input.valid && mulHit - mul.source := read.output.source - mul.rs1 := read.output.rs1 - mul.rs2 := read.output.rs2 - mul.rs3 := read.output.rs3 - mul.rd := read.output.rd - mul.lockId := read.output.lockId - mul.add := fmaHit - mul.minus := False //TODO + val divSqrtToMul = Stream(MulInput()) + + input.ready setWhen(mulHit && mul.ready && !divSqrtToMul.valid) + mul.valid := input.valid && mulHit || divSqrtToMul.valid + + divSqrtToMul.ready := mul.ready + mul.payload := divSqrtToMul.payload + when(!divSqrtToMul.valid) { + mul.payload.assignSomeByName(read.output.payload) + mul.add := fmaHit + mul.divSqrt := False + mul.msb1 := True + mul.msb2 := True + mul.minus := False //TODO + } val addHit = input.opcode === p.Opcode.ADD val add = Stream(AddInput()) @@ -183,8 +217,6 @@ case class FpuCore(p : FpuParameter) extends Component{ when(!mulToAdd.valid) { add.payload.assignSomeByName(read.output.payload) } - - } val load = new Area{ @@ -209,8 +241,8 @@ case class FpuCore(p : FpuParameter) extends Component{ val input = decode.mul.stage() val math = new Area { - val mulA = U"1" @@ input.rs1.mantissa - val mulB = U"1" @@ input.rs2.mantissa + val mulA = U(input.msb1) @@ input.rs1.mantissa + val mulB = U(input.msb2) @@ input.rs2.mantissa val mulC = mulA * mulB val exp = input.rs1.exponent +^ input.rs2.exponent - ((1 << p.internalExponentSize - 1) - 1) } @@ -226,8 +258,14 @@ case class FpuCore(p : FpuParameter) extends Component{ output.mantissa := man } + val notMul = new Area{ + val output = Flow(UInt(p.internalMantissaSize + 1 bits)) + output.valid := input.valid && input.divSqrt + output.payload := math.mulC(p.internalMantissaSize, p.internalMantissaSize+1 bits) + } + val output = Stream(WriteInput()) - output.valid := input.valid && !input.add + output.valid := input.valid && !input.add && !input.divSqrt output.source := input.source output.lockId := input.lockId output.rd := input.rd @@ -242,7 +280,164 @@ case class FpuCore(p : FpuParameter) extends Component{ decode.mulToAdd.rd := input.rd decode.mulToAdd.lockId := input.lockId - input.ready := (input.add ? decode.mulToAdd.ready | output.ready) + input.ready := (input.add ? decode.mulToAdd.ready | output.ready) || input.divSqrt + } + + val divSqrt = new Area { + val input = decode.divSqrt.stage() + + val aproxWidth = 8 + val aproxDepth = 64 + val divIterationCount = 3 + val sqrtIterationCount = 3 + + val mulWidth = p.internalMantissaSize + 1 + + import FpuDivSqrtIterationState._ + val state = RegInit(FpuDivSqrtIterationState.IDLE()) + val iteration = Reg(UInt(log2Up(divIterationCount max sqrtIterationCount) bits)) + + decode.divSqrtToMul.valid := False + decode.divSqrtToMul.source := input.source + decode.divSqrtToMul.rs1.assignDontCare() + decode.divSqrtToMul.rs2.assignDontCare() + decode.divSqrtToMul.rs3.assignDontCare() + decode.divSqrtToMul.rd := input.rd + decode.divSqrtToMul.lockId := input.lockId + decode.divSqrtToMul.add := False + decode.divSqrtToMul.divSqrt := True + decode.divSqrtToMul.msb1 := True + decode.divSqrtToMul.msb2 := True + decode.divSqrtToMul.minus := False + + + val aprox = new Area { + val rom = Mem(UInt(aproxWidth bits), aproxDepth * 2) + val divTable, sqrtTable = ArrayBuffer[Double]() + for(i <- 0 until aproxDepth){ + val mantissa = 1+(i+0.5)/aproxDepth + divTable += 1/mantissa + sqrtTable += 1/Math.sqrt(mantissa) + } + val romElaboration = (sqrtTable ++ divTable).map(v => BigInt(((v-0.5)*2*(1 << aproxWidth)).round)) + + rom.initBigInt(romElaboration) + val address = U(input.div ## (input.div ? input.rs2.mantissa | input.rs1.mantissa).takeHigh(log2Up(aproxDepth))) + val raw = rom.readAsync(address) + val result = U"01" @@ (raw << (mulWidth-aproxWidth-2)) + } + + val divExp = new Area{ + val value = (1 << p.internalExponentSize) - 3 - input.rs2.exponent + } + val sqrtExp = new Area{ + val value = ((1 << p.internalExponentSize-1) + (1 << p.internalExponentSize-2) - 2) - (input.rs2.exponent >> 1) + input.rs2.exponent.lsb.asUInt + } + + def mulArg(rs1 : UInt, rs2 : UInt): Unit ={ + decode.divSqrtToMul.rs1.mantissa := rs1.resized + decode.divSqrtToMul.rs2.mantissa := rs2.resized + decode.divSqrtToMul.msb1 := rs1.msb + decode.divSqrtToMul.msb2 := rs2.msb + } + + val mulBuffer = mul.notMul.output.toStream.stage + mulBuffer.ready := False + + val iterationValue = Reg(UInt(mulWidth bits)) + //val squareInput = (iteration === 0) ? aprox.result | iterationValue + + input.ready := False + switch(state){ + is(IDLE){ + iterationValue := aprox.result + iteration := 0 + when(input.valid) { + state := YY + } + } + is(YY){ + decode.divSqrtToMul.valid := True + mulArg(iterationValue, iterationValue) + when(decode.divSqrtToMul.ready) { + state := XYY + } + } + is(XYY){ + decode.divSqrtToMul.valid := mulBuffer.valid + mulArg(U"1" @@ (input.div ? input.rs2.mantissa | input.rs1.mantissa), mulBuffer.payload) + when(mulBuffer.valid && decode.divSqrtToMul.ready) { + state := (input.div ? Y2_XYY | Y_15_XYY2) + mulBuffer.ready := input.div + when(!input.div){ + mulBuffer.payload.getDrivingReg := (U"11" << mulWidth-2) - (mulBuffer.payload >> 1) + } + } + } + is(Y2_XYY){ + mulBuffer.ready := True + when(mulBuffer.valid) { + iterationValue := ((iterationValue << 1) - mulBuffer.payload).resized + mulBuffer.ready := True + iteration := iteration + 1 + when(iteration =/= divIterationCount-1){ //TODO + state := YY + } otherwise { + state := DIV + } + } + } + is(DIV){ + decode.divSqrtToMul.valid := True + decode.divSqrtToMul.divSqrt := False + decode.divSqrtToMul.rs1 := input.rs1 + decode.divSqrtToMul.rs2.sign := input.rs2.sign + decode.divSqrtToMul.rs2.exponent := divExp.value + iterationValue.msb.asUInt + decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized + when(decode.divSqrtToMul.ready) { + state := IDLE + input.ready := True + } + } + is(Y_15_XYY2){ + decode.divSqrtToMul.valid := True + mulArg(U"1" @@ input.rs1.mantissa, mulBuffer.payload) + when(decode.divSqrtToMul.ready) { + mulBuffer.ready := True + state := SQRT + } + } + is(Y_15_XYY2_RESULT){ + when(iteration =/= sqrtIterationCount-1 && !input.rs1.exponent.lsb) { + iterationValue := mulBuffer.payload + } otherwise { + val v = 1.0/Math.sqrt(2.0) + val scaled = v* (BigInt(1) << mulWidth-1).toDouble + val bigInt = BigDecimal(scaled).toBigInt() + iterationValue := mulBuffer.payload + U(bigInt) + } + mulBuffer.ready := True + when(mulBuffer.valid) { + when(iteration =/= sqrtIterationCount-1){ + state := YY + } otherwise { + state := SQRT + } + } + } + is(SQRT){ + decode.divSqrtToMul.valid := True + decode.divSqrtToMul.divSqrt := False + decode.divSqrtToMul.rs1 := input.rs1 + decode.divSqrtToMul.rs2.sign := False + decode.divSqrtToMul.rs2.exponent := sqrtExp.value + iterationValue.msb.asUInt + decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized + when(decode.divSqrtToMul.ready) { + state := IDLE + input.ready := True + } + } + } } val add = new Area{ @@ -349,7 +544,7 @@ case class FpuCore(p : FpuParameter) extends Component{ -object StreamFifoMultiChannelBench extends App{ +object FpuSynthesisBench extends App{ val payloadType = HardType(Bits(8 bits)) class Fpu(name : String, p : FpuParameter) extends Rtl{ override def getName(): String = "Fpu_" + name diff --git a/src/main/scala/vexriscv/ip/fpu/Interface.scala b/src/main/scala/vexriscv/ip/fpu/Interface.scala index a996390..b07c18c 100644 --- a/src/main/scala/vexriscv/ip/fpu/Interface.scala +++ b/src/main/scala/vexriscv/ip/fpu/Interface.scala @@ -23,7 +23,7 @@ case class FpuFloat(exponentSize: Int, } case class FpuOpcode(p : FpuParameter) extends SpinalEnum{ - val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP = newElement() + val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP, DIV_SQRT = newElement() } case class FpuParameter( internalMantissaSize : Int, diff --git a/src/test/scala/vexriscv/ip/fpu/FpuTest.scala b/src/test/scala/vexriscv/ip/fpu/FpuTest.scala index a2cd6d1..68c3e46 100644 --- a/src/test/scala/vexriscv/ip/fpu/FpuTest.scala +++ b/src/test/scala/vexriscv/ip/fpu/FpuTest.scala @@ -96,6 +96,18 @@ class FpuTest extends FunSuite{ } } + def div(rd : Int, rs1 : Int, rs2 : Int): Unit ={ + cmdQueue += {cmd => + cmd.source #= id + cmd.opcode #= cmd.opcode.spinalEnum.DIV_SQRT + cmd.value.randomize() + cmd.rs1 #= rs1 + cmd.rs2 #= rs2 + cmd.rs3.randomize() + cmd.rd #= rd + } + } + def fma(rd : Int, rs1 : Int, rs2 : Int, rs3 : Int): Unit ={ cmdQueue += {cmd => cmd.source #= id @@ -213,9 +225,57 @@ class FpuTest extends FunSuite{ } } -// testAdd(0.1f, 1.6f) -// testMul(0.1f, 1.6f) + + def testDiv(a : Float, b : Float): Unit ={ + val rs = new RegAllocator() + val rs1, rs2, rs3 = rs.allocate() + val rd = Random.nextInt(32) + load(rs1, a) + load(rs2, b) + + div(rd,rs1,rs2) + storeFloat(rd){v => + val ref = a/b + val error = Math.abs(ref-v)/ref + println(f"$a / $b = $v, $ref $error") + assert(checkFloat(ref, v)) + } + } + + def testSqrt(a : Float): Unit ={ + val rs = new RegAllocator() + val rs1, rs2, rs3 = rs.allocate() + val rd = Random.nextInt(32) + load(rs1, a) + + div(rd,rs1,rs2) + storeFloat(rd){v => + val ref = Math.sqrt(a).toFloat + val error = Math.abs(ref-v)/ref + println(f"sqrt($a) = $v, $ref $error") + assert(checkFloat(ref, v)) + } + } + + val b2f = lang.Float.intBitsToFloat(_) + + +// testSqrt(2.25f) +// dut.clockDomain.waitSampling(100) +// simFailure() + + testAdd(0.1f, 1.6f) + testMul(0.1f, 1.6f) testFma(1.1f, 2.2f, 3.0f) + testDiv(1.0f, 1.1f) + testDiv(1.0f, 1.5f) + testDiv(1.0f, 1.9f) + testDiv(1.1f, 1.9f) + testDiv(1.0f, b2f(0x3f7ffffe)) + testDiv(1.0f, b2f(0x3f7fffff)) + testDiv(1.0f, b2f(0x3f800000)) + testDiv(1.0f, b2f(0x3f800001)) + testDiv(1.0f, b2f(0x3f800002)) for(i <- 0 until 1000){ testAdd(randomFloat(), randomFloat()) @@ -226,14 +286,18 @@ class FpuTest extends FunSuite{ for(i <- 0 until 1000){ testFma(randomFloat(), randomFloat(), randomFloat()) } + + for(i <- 0 until 1000){ + testDiv(randomFloat(), randomFloat()) + } for(i <- 0 until 1000){ val tests = ArrayBuffer[() => Unit]() tests += (() =>{testAdd(randomFloat(), randomFloat())}) tests += (() =>{testMul(randomFloat(), randomFloat())}) tests += (() =>{testFma(randomFloat(), randomFloat(), randomFloat())}) + tests += (() =>{testDiv(randomFloat(), randomFloat())}) tests.randomPick().apply() } - waitUntil(cpu.rspQueue.isEmpty) } diff --git a/src/test/scala/vexriscv/ip/fpu/Playground.scala b/src/test/scala/vexriscv/ip/fpu/Playground.scala index 2fb99ad..71b500d 100644 --- a/src/test/scala/vexriscv/ip/fpu/Playground.scala +++ b/src/test/scala/vexriscv/ip/fpu/Playground.scala @@ -3,8 +3,6 @@ package vexriscv.ip.fpu object MiaouDiv extends App{ val input = 2.5 var output = 1/(input*0.95) -// def x = output -// def y = input def y = output def x = input @@ -29,7 +27,7 @@ object MiaouSqrt extends App{ def x = input for(i <- 0 until 10) { - output = y*(1.5-x*y*y/2) + output = y * (1.5 - x * y * y / 2) println(output) }