from migen.fhdl.std import *

"""
Encoders and decoders between binary and one-hot representation
"""


class Encoder(Module):
    """Encode one-hot to binary

    If `n` is low, the `o` th bit in `i` is asserted, else none or
    multiple bits are asserted.

    Parameters
    ----------
    width : int
        Bit width of the input

    Attributes
    ----------
    i : Signal(width), in
        One-hot input
    o : Signal(max=width), out
        Encoded binary
    n : Signal(1), out
        Invalid, either none or multiple input bits are asserted
    """
    def __init__(self, width):
        self.i = Signal(width)  # one-hot
        self.o = Signal(max=max(2, width))  # binary
        self.n = Signal()  # invalid: none or multiple
        act = dict((1<<j, self.o.eq(j)) for j in range(width))
        act["default"] = self.n.eq(1)
        self.comb += Case(self.i, act)


class PriorityEncoder(Module):
    """Priority encode requests to binary

    If `n` is low, the `o` th bit in `i` is asserted and the bits below
    `o` are unasserted, else `o == 0`. The LSB has priority.

    Parameters
    ----------
    width : int
        Bit width of the input

    Attributes
    ----------
    i : Signal(width), in
        Input requests
    o : Signal(max=width), out
        Encoded binary
    n : Signal(1), out
        Invalid, no input bits are asserted
    """
    def __init__(self, width):
        self.i = Signal(width)  # one-hot, lsb has priority
        self.o = Signal(max=max(2, width))  # binary
        self.n = Signal()  # none
        for j in range(width)[::-1]:  # last has priority
            self.comb += If(self.i[j], self.o.eq(j))
        self.comb += self.n.eq(self.i == 0)


class Decoder(Module):
    """Decode binary to one-hot

    If `n` is low, the `i` th bit in `o` is asserted, the others are
    not, else `o == 0`.

    Parameters
    ----------
    width : int
        Bit width of the output

    Attributes
    ----------
    i : Signal(max=width), in
        Input binary
    o : Signal(width), out
        Decoded one-hot
    n : Signal(1), in
        Invalid, no output bits are to be asserted
    """

    def __init__(self, width):
        self.i = Signal(max=max(2, width))  # binary
        self.n = Signal()  # none/invalid
        self.o = Signal(width)  # one-hot
        act = dict((j, self.o.eq(1<<j)) for j in range(width))
        self.comb += Case(self.i, act)
        self.comb += If(self.n, self.o.eq(0))


class PriorityDecoder(Decoder):
    pass  # same