fpu div functional, sqrt wip

This commit is contained in:
Dolu1990 2021-01-14 15:56:56 +01:00
parent 8761d0d9ee
commit 85dd5dbf8e
4 changed files with 281 additions and 24 deletions

View File

@ -6,6 +6,10 @@ import spinal.lib.eda.bench.{Bench, Rtl, XilinxStdTargets}
import scala.collection.mutable.ArrayBuffer
// State encoding for the register named `state` in FpuCore's divSqrt area.
// IDLE is that register's reset value; the remaining elements are the steps
// the sequencer switches on.
// NOTE(review): the step names appear to abbreviate the expression handled in
// each step (e.g. Y2_XYY ~ 2*y - x*y*y) - confirm against the state machine.
object FpuDivSqrtIterationState extends SpinalEnum{
val IDLE, YY, XYY, Y2_XYY, DIV, Y_15_XYY2, Y_15_XYY2_RESULT, SQRT = newElement()
}
case class FpuCore(p : FpuParameter) extends Component{
val io = new Bundle {
val port = slave(FpuPort(p))
@ -51,9 +55,19 @@ case class FpuCore(p : FpuParameter) extends Component{
val rd = p.rfAddress()
val lockId = lockIdType()
val add = Bool()
val divSqrt = Bool()
val msb1, msb2 = Bool() //allow usage of msb bits of mul
val minus = Bool()
}
// Payload of the stream that carries a decoded DIV_SQRT command to the
// div/sqrt unit.
case class DivSqrtInput() extends Bundle{
// Copied from the register-file read stage output (source, rs1, rs2, rd, lockId).
val source = p.source()
// Operand values in the core's internal floating-point format.
val rs1, rs2 = p.internalFloating()
val rd = p.rfAddress()
val lockId = lockIdType()
// True selects the division sequence; False selects the square-root sequence.
// NOTE(review): the decoder currently ties this to True (marked TODO there).
val div = Bool()
}
case class AddInput() extends Bundle{
val source = p.source()
val rs1, rs2 = p.internalFloating()
@ -103,6 +117,11 @@ case class FpuCore(p : FpuParameter) extends Component{
useRs2 := True
useRd := True
}
is(p.Opcode.DIV_SQRT){
useRs1 := True
useRs2 := True //TODO
useRd := True
}
is(p.Opcode.FMA){
useRs1 := True
useRs2 := True
@ -155,20 +174,35 @@ case class FpuCore(p : FpuParameter) extends Component{
store.source := read.output.source
store.rs2 := read.output.rs2
val divSqrtHit = input.opcode === p.Opcode.DIV_SQRT
val divSqrt = Stream(DivSqrtInput())
input.ready setWhen(divSqrtHit && divSqrt.ready)
divSqrt.valid := input.valid && divSqrtHit
divSqrt.source := read.output.source
divSqrt.rs1 := read.output.rs1
divSqrt.rs2 := read.output.rs2
divSqrt.rd := read.output.rd
divSqrt.lockId := read.output.lockId
divSqrt.div := True //TODO
val fmaHit = input.opcode === p.Opcode.FMA
val mulHit = input.opcode === p.Opcode.MUL || fmaHit
val mul = Stream(MulInput())
input.ready setWhen(mulHit && mul.ready)
mul.valid := input.valid && mulHit
mul.source := read.output.source
mul.rs1 := read.output.rs1
mul.rs2 := read.output.rs2
mul.rs3 := read.output.rs3
mul.rd := read.output.rd
mul.lockId := read.output.lockId
mul.add := fmaHit
mul.minus := False //TODO
val divSqrtToMul = Stream(MulInput())
input.ready setWhen(mulHit && mul.ready && !divSqrtToMul.valid)
mul.valid := input.valid && mulHit || divSqrtToMul.valid
divSqrtToMul.ready := mul.ready
mul.payload := divSqrtToMul.payload
when(!divSqrtToMul.valid) {
mul.payload.assignSomeByName(read.output.payload)
mul.add := fmaHit
mul.divSqrt := False
mul.msb1 := True
mul.msb2 := True
mul.minus := False //TODO
}
val addHit = input.opcode === p.Opcode.ADD
val add = Stream(AddInput())
@ -183,8 +217,6 @@ case class FpuCore(p : FpuParameter) extends Component{
when(!mulToAdd.valid) {
add.payload.assignSomeByName(read.output.payload)
}
}
val load = new Area{
@ -209,8 +241,8 @@ case class FpuCore(p : FpuParameter) extends Component{
val input = decode.mul.stage()
val math = new Area {
val mulA = U"1" @@ input.rs1.mantissa
val mulB = U"1" @@ input.rs2.mantissa
val mulA = U(input.msb1) @@ input.rs1.mantissa
val mulB = U(input.msb2) @@ input.rs2.mantissa
val mulC = mulA * mulB
val exp = input.rs1.exponent +^ input.rs2.exponent - ((1 << p.internalExponentSize - 1) - 1)
}
@ -226,8 +258,14 @@ case class FpuCore(p : FpuParameter) extends Component{
output.mantissa := man
}
val notMul = new Area{
val output = Flow(UInt(p.internalMantissaSize + 1 bits))
output.valid := input.valid && input.divSqrt
output.payload := math.mulC(p.internalMantissaSize, p.internalMantissaSize+1 bits)
}
val output = Stream(WriteInput())
output.valid := input.valid && !input.add
output.valid := input.valid && !input.add && !input.divSqrt
output.source := input.source
output.lockId := input.lockId
output.rd := input.rd
@ -242,7 +280,164 @@ case class FpuCore(p : FpuParameter) extends Component{
decode.mulToAdd.rd := input.rd
decode.mulToAdd.lockId := input.lockId
input.ready := (input.add ? decode.mulToAdd.ready | output.ready)
input.ready := (input.add ? decode.mulToAdd.ready | output.ready) || input.divSqrt
}
// Iterative divide / square-root sequencer.
// It does not own a multiplier: every product is requested through
// decode.divSqrtToMul and the intermediate result is read back from
// mul.notMul.output. The final step of each sequence is issued as a normal
// multiplication (divSqrt flag cleared) of rs1 by the refined value.
// NOTE(review): the square-root path looks unfinished (see notes below) - confirm.
val divSqrt = new Area {
val input = decode.divSqrt.stage()
// Geometry of the initial-approximation ROM: aproxDepth entries per table,
// aproxWidth bits per entry.
val aproxWidth = 8
val aproxDepth = 64
// Number of refinement passes for each operation.
val divIterationCount = 3
val sqrtIterationCount = 3
// Width of the fixed-point values exchanged with the multiplier.
val mulWidth = p.internalMantissaSize + 1
import FpuDivSqrtIterationState._
val state = RegInit(FpuDivSqrtIterationState.IDLE())
val iteration = Reg(UInt(log2Up(divIterationCount max sqrtIterationCount) bits))
// Default drive of the multiplier request; the active state overrides the
// fields it needs (later assignments win).
decode.divSqrtToMul.valid := False
decode.divSqrtToMul.source := input.source
decode.divSqrtToMul.rs1.assignDontCare()
decode.divSqrtToMul.rs2.assignDontCare()
decode.divSqrtToMul.rs3.assignDontCare()
decode.divSqrtToMul.rd := input.rd
decode.divSqrtToMul.lockId := input.lockId
decode.divSqrtToMul.add := False
decode.divSqrtToMul.divSqrt := True
decode.divSqrtToMul.msb1 := True
decode.divSqrtToMul.msb2 := True
decode.divSqrtToMul.minus := False
// Initial approximation lookup.
val aprox = new Area {
val rom = Mem(UInt(aproxWidth bits), aproxDepth * 2)
val divTable, sqrtTable = ArrayBuffer[Double]()
// Each entry is evaluated at the midpoint of its mantissa interval in [1, 2):
// 1/m for the divide table, 1/sqrt(m) for the square-root table.
for(i <- 0 until aproxDepth){
val mantissa = 1+(i+0.5)/aproxDepth
divTable += 1/mantissa
sqrtTable += 1/Math.sqrt(mantissa)
}
// ROM layout: square-root table in the lower half, divide table in the upper half.
// Entries are stored as (v - 0.5) * 2 scaled by 2^aproxWidth and rounded.
val romElaboration = (sqrtTable ++ divTable).map(v => BigInt(((v-0.5)*2*(1 << aproxWidth)).round))
rom.initBigInt(romElaboration)
// Address = {div, top mantissa bits of the operand}: rs2 when dividing, rs1 otherwise.
val address = U(input.div ## (input.div ? input.rs2.mantissa | input.rs1.mantissa).takeHigh(log2Up(aproxDepth)))
val raw = rom.readAsync(address)
// Prefix "01" and left-align the ROM entry so the result is mulWidth bits wide.
val result = U"01" @@ (raw << (mulWidth-aproxWidth-2))
}
// Exponent base handed to the final multiplication of a division
// (the DIV state adds iterationValue's msb to it).
val divExp = new Area{
val value = (1 << p.internalExponentSize) - 3 - input.rs2.exponent
}
// Exponent base handed to the final multiplication of a square root.
// NOTE(review): this reads rs2.exponent, whereas the approximation lookup and
// the Y_15_XYY2_RESULT state read rs1 for the sqrt operand - confirm which is intended.
val sqrtExp = new Area{
val value = ((1 << p.internalExponentSize-1) + (1 << p.internalExponentSize-2) - 2) - (input.rs2.exponent >> 1) + input.rs2.exponent.lsb.asUInt
}
// Drives the two multiplier operands from raw fixed-point values: the top bit
// of each argument goes to msb1/msb2, the value is resized into the mantissa field.
def mulArg(rs1 : UInt, rs2 : UInt): Unit ={
decode.divSqrtToMul.rs1.mantissa := rs1.resized
decode.divSqrtToMul.rs2.mantissa := rs2.resized
decode.divSqrtToMul.msb1 := rs1.msb
decode.divSqrtToMul.msb2 := rs2.msb
}
// Registered copy of the multiplier's intermediate result; it is held until a
// state raises mulBuffer.ready.
val mulBuffer = mul.notMul.output.toStream.stage
mulBuffer.ready := False
// Current estimate y being refined.
val iterationValue = Reg(UInt(mulWidth bits))
//val squareInput = (iteration === 0) ? aprox.result | iterationValue
// The command is only consumed when the final multiplication is accepted.
input.ready := False
switch(state){
// IDLE: keep loading the ROM approximation and start when a command is present.
is(IDLE){
iterationValue := aprox.result
iteration := 0
when(input.valid) {
state := YY
}
}
// YY: request y * y.
is(YY){
decode.divSqrtToMul.valid := True
mulArg(iterationValue, iterationValue)
when(decode.divSqrtToMul.ready) {
state := XYY
}
}
// XYY: once y*y is buffered, request x * (y*y), x being the operand mantissa
// with an explicit leading one.
is(XYY){
decode.divSqrtToMul.valid := mulBuffer.valid
mulArg(U"1" @@ (input.div ? input.rs2.mantissa | input.rs1.mantissa), mulBuffer.payload)
when(mulBuffer.valid && decode.divSqrtToMul.ready) {
state := (input.div ? Y2_XYY | Y_15_XYY2)
// Division drops the buffered y*y; square root keeps the buffer occupied.
mulBuffer.ready := input.div
when(!input.div){
// Overwrite the buffer register in place with ("11" << mulWidth-2) - payload/2.
// NOTE(review): presumably this is 1.5 - (buffered value)/2 in the fixed-point scale - confirm.
mulBuffer.payload.getDrivingReg := (U"11" << mulWidth-2) - (mulBuffer.payload >> 1)
}
}
}
// Y2_XYY: division refinement step, y := 2*y - x*y*y, then loop or finish.
is(Y2_XYY){
mulBuffer.ready := True
when(mulBuffer.valid) {
iterationValue := ((iterationValue << 1) - mulBuffer.payload).resized
// NOTE(review): repeats the assignment made at the top of this state.
mulBuffer.ready := True
iteration := iteration + 1
when(iteration =/= divIterationCount-1){ //TODO
state := YY
} otherwise {
state := DIV
}
}
}
// DIV: final step, issued as a regular multiplication (divSqrt cleared) of rs1
// by the refined value, which carries rs2's sign and the divExp-based exponent.
is(DIV){
decode.divSqrtToMul.valid := True
decode.divSqrtToMul.divSqrt := False
decode.divSqrtToMul.rs1 := input.rs1
decode.divSqrtToMul.rs2.sign := input.rs2.sign
decode.divSqrtToMul.rs2.exponent := divExp.value + iterationValue.msb.asUInt
decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized
when(decode.divSqrtToMul.ready) {
state := IDLE
input.ready := True
}
}
// Y_15_XYY2: square-root step, request (rs1 mantissa with leading one) * buffered value.
// NOTE(review): this jumps straight to SQRT; no visible transition ever selects
// Y_15_XYY2_RESULT, so that state looks unreachable for now - confirm.
is(Y_15_XYY2){
decode.divSqrtToMul.valid := True
mulArg(U"1" @@ input.rs1.mantissa, mulBuffer.payload)
when(decode.divSqrtToMul.ready) {
mulBuffer.ready := True
state := SQRT
}
}
// Y_15_XYY2_RESULT: capture the product as the new estimate, or the product plus
// an elaboration-time constant (1/sqrt(2) scaled by 2^(mulWidth-1)).
// `iteration` is not incremented in this state.
is(Y_15_XYY2_RESULT){
when(iteration =/= sqrtIterationCount-1 && !input.rs1.exponent.lsb) {
iterationValue := mulBuffer.payload
} otherwise {
val v = 1.0/Math.sqrt(2.0)
val scaled = v* (BigInt(1) << mulWidth-1).toDouble
val bigInt = BigDecimal(scaled).toBigInt()
iterationValue := mulBuffer.payload + U(bigInt)
}
mulBuffer.ready := True
when(mulBuffer.valid) {
when(iteration =/= sqrtIterationCount-1){
state := YY
} otherwise {
state := SQRT
}
}
}
// SQRT: final step, issued as a regular multiplication (divSqrt cleared) of rs1
// by the refined value, with a positive sign and the sqrtExp-based exponent.
is(SQRT){
decode.divSqrtToMul.valid := True
decode.divSqrtToMul.divSqrt := False
decode.divSqrtToMul.rs1 := input.rs1
decode.divSqrtToMul.rs2.sign := False
decode.divSqrtToMul.rs2.exponent := sqrtExp.value + iterationValue.msb.asUInt
decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized
when(decode.divSqrtToMul.ready) {
state := IDLE
input.ready := True
}
}
}
}
val add = new Area{
@ -349,7 +544,7 @@ case class FpuCore(p : FpuParameter) extends Component{
object StreamFifoMultiChannelBench extends App{
object FpuSynthesisBench extends App{
val payloadType = HardType(Bits(8 bits))
class Fpu(name : String, p : FpuParameter) extends Rtl{
override def getName(): String = "Fpu_" + name

View File

@ -23,7 +23,7 @@ case class FpuFloat(exponentSize: Int,
}
case class FpuOpcode(p : FpuParameter) extends SpinalEnum{
val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP = newElement()
val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP, DIV_SQRT = newElement()
}
case class FpuParameter( internalMantissaSize : Int,

View File

@ -96,6 +96,18 @@ class FpuTest extends FunSuite{
}
}
/** Queues a DIV_SQRT command that reads registers `rs1` / `rs2` and writes register `rd`. */
def div(rd : Int, rs1 : Int, rs2 : Int): Unit = {
  cmdQueue += { command =>
    command.source #= id
    command.opcode #= command.opcode.spinalEnum.DIV_SQRT
    // The value and rs3 fields are driven with random data.
    command.value.randomize()
    command.rs1 #= rs1
    command.rs2 #= rs2
    command.rs3.randomize()
    command.rd #= rd
  }
}
def fma(rd : Int, rs1 : Int, rs2 : Int, rs3 : Int): Unit ={
cmdQueue += {cmd =>
cmd.source #= id
@ -213,9 +225,57 @@ class FpuTest extends FunSuite{
}
}
// testAdd(0.1f, 1.6f)
// testMul(0.1f, 1.6f)
/**
 * Loads `a` and `b` into freshly allocated registers, queues a divide into a
 * randomly chosen destination register, then checks the stored result
 * against the software quotient `a/b`.
 */
def testDiv(a : Float, b : Float): Unit = {
  val allocator = new RegAllocator()
  // Three registers are allocated; the third one is not used by this test.
  val rs1 = allocator.allocate()
  val rs2 = allocator.allocate()
  val rs3 = allocator.allocate()
  val rd = Random.nextInt(32)

  load(rs1, a)
  load(rs2, b)
  div(rd, rs1, rs2)

  storeFloat(rd) { result =>
    val ref = a / b
    // Relative error, printed for information only; the pass/fail decision is checkFloat's.
    val error = Math.abs(ref - result) / ref
    println(f"$a / $b = $result, $ref $error")
    assert(checkFloat(ref, result))
  }
}
/**
 * Loads `a` into a freshly allocated register, queues a DIV_SQRT command into a
 * randomly chosen destination register, then checks the stored result
 * against `Math.sqrt(a)`.
 *
 * NOTE(review): rs2 is allocated but never loaded before being passed to `div`,
 * and `div` issues the same DIV_SQRT command as the division test - presumably
 * this helper is not usable until the sqrt selection is wired; confirm.
 */
def testSqrt(a : Float): Unit = {
  val allocator = new RegAllocator()
  // Three registers are allocated; only rs1 receives a value.
  val rs1 = allocator.allocate()
  val rs2 = allocator.allocate()
  val rs3 = allocator.allocate()
  val rd = Random.nextInt(32)

  load(rs1, a)
  div(rd, rs1, rs2)

  storeFloat(rd) { result =>
    val ref = Math.sqrt(a).toFloat
    // Relative error, printed for information only; the pass/fail decision is checkFloat's.
    val error = Math.abs(ref - result) / ref
    println(f"sqrt($a) = $result, $ref $error")
    assert(checkFloat(ref, result))
  }
}
val b2f = lang.Float.intBitsToFloat(_)
// testSqrt(2.25f)
// dut.clockDomain.waitSampling(100)
// simFailure()
testAdd(0.1f, 1.6f)
testMul(0.1f, 1.6f)
testFma(1.1f, 2.2f, 3.0f)
testDiv(1.0f, 1.1f)
testDiv(1.0f, 1.5f)
testDiv(1.0f, 1.9f)
testDiv(1.1f, 1.9f)
testDiv(1.0f, b2f(0x3f7ffffe))
testDiv(1.0f, b2f(0x3f7fffff))
testDiv(1.0f, b2f(0x3f800000))
testDiv(1.0f, b2f(0x3f800001))
testDiv(1.0f, b2f(0x3f800002))
for(i <- 0 until 1000){
testAdd(randomFloat(), randomFloat())
@ -226,14 +286,18 @@ class FpuTest extends FunSuite{
for(i <- 0 until 1000){
testFma(randomFloat(), randomFloat(), randomFloat())
}
for(i <- 0 until 1000){
testDiv(randomFloat(), randomFloat())
}
for(i <- 0 until 1000){
val tests = ArrayBuffer[() => Unit]()
tests += (() =>{testAdd(randomFloat(), randomFloat())})
tests += (() =>{testMul(randomFloat(), randomFloat())})
tests += (() =>{testFma(randomFloat(), randomFloat(), randomFloat())})
tests += (() =>{testDiv(randomFloat(), randomFloat())})
tests.randomPick().apply()
}
waitUntil(cpu.rspQueue.isEmpty)
}

View File

@ -3,8 +3,6 @@ package vexriscv.ip.fpu
object MiaouDiv extends App{
val input = 2.5
var output = 1/(input*0.95)
// def x = output
// def y = input
def y = output
def x = input
@ -29,7 +27,7 @@ object MiaouSqrt extends App{
def x = input
for(i <- 0 until 10) {
output = y*(1.5-x*y*y/2)
output = y * (1.5 - x * y * y / 2)
println(output)
}