fpu div functional, sqrt wip
This commit is contained in:
parent
8761d0d9ee
commit
85dd5dbf8e
|
@ -6,6 +6,10 @@ import spinal.lib.eda.bench.{Bench, Rtl, XilinxStdTargets}
|
|||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
object FpuDivSqrtIterationState extends SpinalEnum{
|
||||
val IDLE, YY, XYY, Y2_XYY, DIV, Y_15_XYY2, Y_15_XYY2_RESULT, SQRT = newElement()
|
||||
}
|
||||
|
||||
case class FpuCore(p : FpuParameter) extends Component{
|
||||
val io = new Bundle {
|
||||
val port = slave(FpuPort(p))
|
||||
|
@ -51,9 +55,19 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
val rd = p.rfAddress()
|
||||
val lockId = lockIdType()
|
||||
val add = Bool()
|
||||
val divSqrt = Bool()
|
||||
val msb1, msb2 = Bool() //allow usage of msb bits of mul
|
||||
val minus = Bool()
|
||||
}
|
||||
|
||||
case class DivSqrtInput() extends Bundle{
|
||||
val source = p.source()
|
||||
val rs1, rs2 = p.internalFloating()
|
||||
val rd = p.rfAddress()
|
||||
val lockId = lockIdType()
|
||||
val div = Bool()
|
||||
}
|
||||
|
||||
case class AddInput() extends Bundle{
|
||||
val source = p.source()
|
||||
val rs1, rs2 = p.internalFloating()
|
||||
|
@ -103,6 +117,11 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
useRs2 := True
|
||||
useRd := True
|
||||
}
|
||||
is(p.Opcode.DIV_SQRT){
|
||||
useRs1 := True
|
||||
useRs2 := True //TODO
|
||||
useRd := True
|
||||
}
|
||||
is(p.Opcode.FMA){
|
||||
useRs1 := True
|
||||
useRs2 := True
|
||||
|
@ -155,20 +174,35 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
store.source := read.output.source
|
||||
store.rs2 := read.output.rs2
|
||||
|
||||
val divSqrtHit = input.opcode === p.Opcode.DIV_SQRT
|
||||
val divSqrt = Stream(DivSqrtInput())
|
||||
input.ready setWhen(divSqrtHit && divSqrt.ready)
|
||||
divSqrt.valid := input.valid && divSqrtHit
|
||||
divSqrt.source := read.output.source
|
||||
divSqrt.rs1 := read.output.rs1
|
||||
divSqrt.rs2 := read.output.rs2
|
||||
divSqrt.rd := read.output.rd
|
||||
divSqrt.lockId := read.output.lockId
|
||||
divSqrt.div := True //TODO
|
||||
|
||||
val fmaHit = input.opcode === p.Opcode.FMA
|
||||
val mulHit = input.opcode === p.Opcode.MUL || fmaHit
|
||||
val mul = Stream(MulInput())
|
||||
input.ready setWhen(mulHit && mul.ready)
|
||||
mul.valid := input.valid && mulHit
|
||||
mul.source := read.output.source
|
||||
mul.rs1 := read.output.rs1
|
||||
mul.rs2 := read.output.rs2
|
||||
mul.rs3 := read.output.rs3
|
||||
mul.rd := read.output.rd
|
||||
mul.lockId := read.output.lockId
|
||||
val divSqrtToMul = Stream(MulInput())
|
||||
|
||||
input.ready setWhen(mulHit && mul.ready && !divSqrtToMul.valid)
|
||||
mul.valid := input.valid && mulHit || divSqrtToMul.valid
|
||||
|
||||
divSqrtToMul.ready := mul.ready
|
||||
mul.payload := divSqrtToMul.payload
|
||||
when(!divSqrtToMul.valid) {
|
||||
mul.payload.assignSomeByName(read.output.payload)
|
||||
mul.add := fmaHit
|
||||
mul.divSqrt := False
|
||||
mul.msb1 := True
|
||||
mul.msb2 := True
|
||||
mul.minus := False //TODO
|
||||
}
|
||||
|
||||
val addHit = input.opcode === p.Opcode.ADD
|
||||
val add = Stream(AddInput())
|
||||
|
@ -183,8 +217,6 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
when(!mulToAdd.valid) {
|
||||
add.payload.assignSomeByName(read.output.payload)
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
val load = new Area{
|
||||
|
@ -209,8 +241,8 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
val input = decode.mul.stage()
|
||||
|
||||
val math = new Area {
|
||||
val mulA = U"1" @@ input.rs1.mantissa
|
||||
val mulB = U"1" @@ input.rs2.mantissa
|
||||
val mulA = U(input.msb1) @@ input.rs1.mantissa
|
||||
val mulB = U(input.msb2) @@ input.rs2.mantissa
|
||||
val mulC = mulA * mulB
|
||||
val exp = input.rs1.exponent +^ input.rs2.exponent - ((1 << p.internalExponentSize - 1) - 1)
|
||||
}
|
||||
|
@ -226,8 +258,14 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
output.mantissa := man
|
||||
}
|
||||
|
||||
val notMul = new Area{
|
||||
val output = Flow(UInt(p.internalMantissaSize + 1 bits))
|
||||
output.valid := input.valid && input.divSqrt
|
||||
output.payload := math.mulC(p.internalMantissaSize, p.internalMantissaSize+1 bits)
|
||||
}
|
||||
|
||||
val output = Stream(WriteInput())
|
||||
output.valid := input.valid && !input.add
|
||||
output.valid := input.valid && !input.add && !input.divSqrt
|
||||
output.source := input.source
|
||||
output.lockId := input.lockId
|
||||
output.rd := input.rd
|
||||
|
@ -242,7 +280,164 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
decode.mulToAdd.rd := input.rd
|
||||
decode.mulToAdd.lockId := input.lockId
|
||||
|
||||
input.ready := (input.add ? decode.mulToAdd.ready | output.ready)
|
||||
input.ready := (input.add ? decode.mulToAdd.ready | output.ready) || input.divSqrt
|
||||
}
|
||||
|
||||
val divSqrt = new Area {
|
||||
val input = decode.divSqrt.stage()
|
||||
|
||||
val aproxWidth = 8
|
||||
val aproxDepth = 64
|
||||
val divIterationCount = 3
|
||||
val sqrtIterationCount = 3
|
||||
|
||||
val mulWidth = p.internalMantissaSize + 1
|
||||
|
||||
import FpuDivSqrtIterationState._
|
||||
val state = RegInit(FpuDivSqrtIterationState.IDLE())
|
||||
val iteration = Reg(UInt(log2Up(divIterationCount max sqrtIterationCount) bits))
|
||||
|
||||
decode.divSqrtToMul.valid := False
|
||||
decode.divSqrtToMul.source := input.source
|
||||
decode.divSqrtToMul.rs1.assignDontCare()
|
||||
decode.divSqrtToMul.rs2.assignDontCare()
|
||||
decode.divSqrtToMul.rs3.assignDontCare()
|
||||
decode.divSqrtToMul.rd := input.rd
|
||||
decode.divSqrtToMul.lockId := input.lockId
|
||||
decode.divSqrtToMul.add := False
|
||||
decode.divSqrtToMul.divSqrt := True
|
||||
decode.divSqrtToMul.msb1 := True
|
||||
decode.divSqrtToMul.msb2 := True
|
||||
decode.divSqrtToMul.minus := False
|
||||
|
||||
|
||||
val aprox = new Area {
|
||||
val rom = Mem(UInt(aproxWidth bits), aproxDepth * 2)
|
||||
val divTable, sqrtTable = ArrayBuffer[Double]()
|
||||
for(i <- 0 until aproxDepth){
|
||||
val mantissa = 1+(i+0.5)/aproxDepth
|
||||
divTable += 1/mantissa
|
||||
sqrtTable += 1/Math.sqrt(mantissa)
|
||||
}
|
||||
val romElaboration = (sqrtTable ++ divTable).map(v => BigInt(((v-0.5)*2*(1 << aproxWidth)).round))
|
||||
|
||||
rom.initBigInt(romElaboration)
|
||||
val address = U(input.div ## (input.div ? input.rs2.mantissa | input.rs1.mantissa).takeHigh(log2Up(aproxDepth)))
|
||||
val raw = rom.readAsync(address)
|
||||
val result = U"01" @@ (raw << (mulWidth-aproxWidth-2))
|
||||
}
|
||||
|
||||
val divExp = new Area{
|
||||
val value = (1 << p.internalExponentSize) - 3 - input.rs2.exponent
|
||||
}
|
||||
val sqrtExp = new Area{
|
||||
val value = ((1 << p.internalExponentSize-1) + (1 << p.internalExponentSize-2) - 2) - (input.rs2.exponent >> 1) + input.rs2.exponent.lsb.asUInt
|
||||
}
|
||||
|
||||
def mulArg(rs1 : UInt, rs2 : UInt): Unit ={
|
||||
decode.divSqrtToMul.rs1.mantissa := rs1.resized
|
||||
decode.divSqrtToMul.rs2.mantissa := rs2.resized
|
||||
decode.divSqrtToMul.msb1 := rs1.msb
|
||||
decode.divSqrtToMul.msb2 := rs2.msb
|
||||
}
|
||||
|
||||
val mulBuffer = mul.notMul.output.toStream.stage
|
||||
mulBuffer.ready := False
|
||||
|
||||
val iterationValue = Reg(UInt(mulWidth bits))
|
||||
//val squareInput = (iteration === 0) ? aprox.result | iterationValue
|
||||
|
||||
input.ready := False
|
||||
switch(state){
|
||||
is(IDLE){
|
||||
iterationValue := aprox.result
|
||||
iteration := 0
|
||||
when(input.valid) {
|
||||
state := YY
|
||||
}
|
||||
}
|
||||
is(YY){
|
||||
decode.divSqrtToMul.valid := True
|
||||
mulArg(iterationValue, iterationValue)
|
||||
when(decode.divSqrtToMul.ready) {
|
||||
state := XYY
|
||||
}
|
||||
}
|
||||
is(XYY){
|
||||
decode.divSqrtToMul.valid := mulBuffer.valid
|
||||
mulArg(U"1" @@ (input.div ? input.rs2.mantissa | input.rs1.mantissa), mulBuffer.payload)
|
||||
when(mulBuffer.valid && decode.divSqrtToMul.ready) {
|
||||
state := (input.div ? Y2_XYY | Y_15_XYY2)
|
||||
mulBuffer.ready := input.div
|
||||
when(!input.div){
|
||||
mulBuffer.payload.getDrivingReg := (U"11" << mulWidth-2) - (mulBuffer.payload >> 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
is(Y2_XYY){
|
||||
mulBuffer.ready := True
|
||||
when(mulBuffer.valid) {
|
||||
iterationValue := ((iterationValue << 1) - mulBuffer.payload).resized
|
||||
mulBuffer.ready := True
|
||||
iteration := iteration + 1
|
||||
when(iteration =/= divIterationCount-1){ //TODO
|
||||
state := YY
|
||||
} otherwise {
|
||||
state := DIV
|
||||
}
|
||||
}
|
||||
}
|
||||
is(DIV){
|
||||
decode.divSqrtToMul.valid := True
|
||||
decode.divSqrtToMul.divSqrt := False
|
||||
decode.divSqrtToMul.rs1 := input.rs1
|
||||
decode.divSqrtToMul.rs2.sign := input.rs2.sign
|
||||
decode.divSqrtToMul.rs2.exponent := divExp.value + iterationValue.msb.asUInt
|
||||
decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized
|
||||
when(decode.divSqrtToMul.ready) {
|
||||
state := IDLE
|
||||
input.ready := True
|
||||
}
|
||||
}
|
||||
is(Y_15_XYY2){
|
||||
decode.divSqrtToMul.valid := True
|
||||
mulArg(U"1" @@ input.rs1.mantissa, mulBuffer.payload)
|
||||
when(decode.divSqrtToMul.ready) {
|
||||
mulBuffer.ready := True
|
||||
state := SQRT
|
||||
}
|
||||
}
|
||||
is(Y_15_XYY2_RESULT){
|
||||
when(iteration =/= sqrtIterationCount-1 && !input.rs1.exponent.lsb) {
|
||||
iterationValue := mulBuffer.payload
|
||||
} otherwise {
|
||||
val v = 1.0/Math.sqrt(2.0)
|
||||
val scaled = v* (BigInt(1) << mulWidth-1).toDouble
|
||||
val bigInt = BigDecimal(scaled).toBigInt()
|
||||
iterationValue := mulBuffer.payload + U(bigInt)
|
||||
}
|
||||
mulBuffer.ready := True
|
||||
when(mulBuffer.valid) {
|
||||
when(iteration =/= sqrtIterationCount-1){
|
||||
state := YY
|
||||
} otherwise {
|
||||
state := SQRT
|
||||
}
|
||||
}
|
||||
}
|
||||
is(SQRT){
|
||||
decode.divSqrtToMul.valid := True
|
||||
decode.divSqrtToMul.divSqrt := False
|
||||
decode.divSqrtToMul.rs1 := input.rs1
|
||||
decode.divSqrtToMul.rs2.sign := False
|
||||
decode.divSqrtToMul.rs2.exponent := sqrtExp.value + iterationValue.msb.asUInt
|
||||
decode.divSqrtToMul.rs2.mantissa := (iterationValue << 1).resized
|
||||
when(decode.divSqrtToMul.ready) {
|
||||
state := IDLE
|
||||
input.ready := True
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val add = new Area{
|
||||
|
@ -349,7 +544,7 @@ case class FpuCore(p : FpuParameter) extends Component{
|
|||
|
||||
|
||||
|
||||
object StreamFifoMultiChannelBench extends App{
|
||||
object FpuSynthesisBench extends App{
|
||||
val payloadType = HardType(Bits(8 bits))
|
||||
class Fpu(name : String, p : FpuParameter) extends Rtl{
|
||||
override def getName(): String = "Fpu_" + name
|
||||
|
|
|
@ -23,7 +23,7 @@ case class FpuFloat(exponentSize: Int,
|
|||
}
|
||||
|
||||
case class FpuOpcode(p : FpuParameter) extends SpinalEnum{
|
||||
val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP = newElement()
|
||||
val LOAD, STORE, MUL, ADD, FMA, I2F, F2I, CMP, DIV_SQRT = newElement()
|
||||
}
|
||||
|
||||
case class FpuParameter( internalMantissaSize : Int,
|
||||
|
|
|
@ -96,6 +96,18 @@ class FpuTest extends FunSuite{
|
|||
}
|
||||
}
|
||||
|
||||
def div(rd : Int, rs1 : Int, rs2 : Int): Unit ={
|
||||
cmdQueue += {cmd =>
|
||||
cmd.source #= id
|
||||
cmd.opcode #= cmd.opcode.spinalEnum.DIV_SQRT
|
||||
cmd.value.randomize()
|
||||
cmd.rs1 #= rs1
|
||||
cmd.rs2 #= rs2
|
||||
cmd.rs3.randomize()
|
||||
cmd.rd #= rd
|
||||
}
|
||||
}
|
||||
|
||||
def fma(rd : Int, rs1 : Int, rs2 : Int, rs3 : Int): Unit ={
|
||||
cmdQueue += {cmd =>
|
||||
cmd.source #= id
|
||||
|
@ -213,9 +225,57 @@ class FpuTest extends FunSuite{
|
|||
}
|
||||
}
|
||||
|
||||
// testAdd(0.1f, 1.6f)
|
||||
// testMul(0.1f, 1.6f)
|
||||
|
||||
def testDiv(a : Float, b : Float): Unit ={
|
||||
val rs = new RegAllocator()
|
||||
val rs1, rs2, rs3 = rs.allocate()
|
||||
val rd = Random.nextInt(32)
|
||||
load(rs1, a)
|
||||
load(rs2, b)
|
||||
|
||||
div(rd,rs1,rs2)
|
||||
storeFloat(rd){v =>
|
||||
val ref = a/b
|
||||
val error = Math.abs(ref-v)/ref
|
||||
println(f"$a / $b = $v, $ref $error")
|
||||
assert(checkFloat(ref, v))
|
||||
}
|
||||
}
|
||||
|
||||
def testSqrt(a : Float): Unit ={
|
||||
val rs = new RegAllocator()
|
||||
val rs1, rs2, rs3 = rs.allocate()
|
||||
val rd = Random.nextInt(32)
|
||||
load(rs1, a)
|
||||
|
||||
div(rd,rs1,rs2)
|
||||
storeFloat(rd){v =>
|
||||
val ref = Math.sqrt(a).toFloat
|
||||
val error = Math.abs(ref-v)/ref
|
||||
println(f"sqrt($a) = $v, $ref $error")
|
||||
assert(checkFloat(ref, v))
|
||||
}
|
||||
}
|
||||
|
||||
val b2f = lang.Float.intBitsToFloat(_)
|
||||
|
||||
|
||||
// testSqrt(2.25f)
|
||||
// dut.clockDomain.waitSampling(100)
|
||||
// simFailure()
|
||||
|
||||
testAdd(0.1f, 1.6f)
|
||||
testMul(0.1f, 1.6f)
|
||||
testFma(1.1f, 2.2f, 3.0f)
|
||||
testDiv(1.0f, 1.1f)
|
||||
testDiv(1.0f, 1.5f)
|
||||
testDiv(1.0f, 1.9f)
|
||||
testDiv(1.1f, 1.9f)
|
||||
testDiv(1.0f, b2f(0x3f7ffffe))
|
||||
testDiv(1.0f, b2f(0x3f7fffff))
|
||||
testDiv(1.0f, b2f(0x3f800000))
|
||||
testDiv(1.0f, b2f(0x3f800001))
|
||||
testDiv(1.0f, b2f(0x3f800002))
|
||||
|
||||
for(i <- 0 until 1000){
|
||||
testAdd(randomFloat(), randomFloat())
|
||||
|
@ -226,14 +286,18 @@ class FpuTest extends FunSuite{
|
|||
for(i <- 0 until 1000){
|
||||
testFma(randomFloat(), randomFloat(), randomFloat())
|
||||
}
|
||||
|
||||
for(i <- 0 until 1000){
|
||||
testDiv(randomFloat(), randomFloat())
|
||||
}
|
||||
for(i <- 0 until 1000){
|
||||
val tests = ArrayBuffer[() => Unit]()
|
||||
tests += (() =>{testAdd(randomFloat(), randomFloat())})
|
||||
tests += (() =>{testMul(randomFloat(), randomFloat())})
|
||||
tests += (() =>{testFma(randomFloat(), randomFloat(), randomFloat())})
|
||||
tests += (() =>{testDiv(randomFloat(), randomFloat())})
|
||||
tests.randomPick().apply()
|
||||
}
|
||||
|
||||
waitUntil(cpu.rspQueue.isEmpty)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,6 @@ package vexriscv.ip.fpu
|
|||
object MiaouDiv extends App{
|
||||
val input = 2.5
|
||||
var output = 1/(input*0.95)
|
||||
// def x = output
|
||||
// def y = input
|
||||
|
||||
def y = output
|
||||
def x = input
|
||||
|
@ -29,7 +27,7 @@ object MiaouSqrt extends App{
|
|||
def x = input
|
||||
|
||||
for(i <- 0 until 10) {
|
||||
output = y*(1.5-x*y*y/2)
|
||||
output = y * (1.5 - x * y * y / 2)
|
||||
println(output)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue