"""Classes and functions for working with WeatherNews data."""
from __future__ import print_function, absolute_import, division
from collections import namedtuple
import logging
import mmap
import struct
import numpy as np
import wni.config as config
import wni.scan_config
from wni.scan_config import clk_to_us
from wni.util import unpackb
logger = logging.getLogger(__name__)

# Width of one range gate in meters (from the site configuration).
GATE_SIZE = config.GATE_SIZE

# The packets of each radial are followed by a trailer holding the antenna
# azimuth and elevation as two C doubles (native byte order and alignment).
AZEL_FORMAT = 'dd'
AZEL_SIZE = struct.Struct(AZEL_FORMAT).size
def parse_radials(data, nradials=1):
    """
    Parse specified number of radials from ``data``.

    ``data`` is a run of packets (each beginning with ``b'OMG!'``) followed by
    an azimuth/elevation trailer packed as ``AZEL_FORMAT``; every such
    packets-plus-trailer group becomes one :class:`Radial`.  The packet length
    is taken from the first packet header and assumed to be the same for every
    packet.

    Args:
        data: the buffer of radar data.
        nradials: the number of radials to return.
    Returns:
        (tuple): tuple containing:
            **radials**: a list of Radials.
            **bytes_read**: bytes read from ``data``.

    If the end of ``data`` is reached exactly on a packet or trailer boundary
    before ``nradials`` radials are complete, the radials parsed so far are
    returned (possibly fewer than ``nradials``, possibly none).  A buffer that
    ends part-way through a trailer raises :class:`struct.error`.
    """
    first_header = PacketHeader(data)
    packet_length = first_header.data_length + first_header.HEADER_LENGTH
    end_byte = len(data)
    offset = 0
    radials, packets = [], []
    while len(radials) < nradials:
        chunk = data[offset:offset + packet_length]
        if chunk[:4] == b'OMG!':
            packets.append(Packet(chunk))
            offset += packet_length
        else:
            # Not a packet start: this is the az/el trailer closing the radial.
            az, el = struct.unpack_from(AZEL_FORMAT, data, offset)
            offset += AZEL_SIZE
            radials.append(Radial(packets, az, el))
            # Start a fresh list so each radial owns only its own packets.
            packets = []
        if offset == end_byte:
            break
    return radials, offset
# One parsed radial: the packets received at a single antenna position plus
# the azimuth/elevation values read from the trailer that follows them.
Radial = namedtuple('Radial', ['packets', 'az', 'el'])
class IQData(object):
    """
    Parse data from the FPGA into a Python object.

    One radial is parsed from the raw buffer (see :func:`parse_radials`).  The
    samples of both receive channels are stored in ``self.i`` / ``self.q`` as
    int16 arrays indexed ``[channel, pulse, gate]``.
    """
    def __init__(self, radar_data, waveforms=None, scaninfo=()):
        """
        Args:
          radar_data (bytes): Pulse data from FPGA
          waveforms (np.array or None): Tx waveforms, indexed per channel and
            used as the matched filter (see :attr:`h`).
          scaninfo: Any metadata associated with the scan.
        """
        self._raw_data = memoryview(radar_data)
        self.az = []
        self.el = []
        self.packets = []
        # number of bytes of radar_data occupied by the parsed radial
        self._bytes_consumed = self._extract_packets()
        # Matched filters
        self.waveforms = waveforms
        # Raises IndexError if no packets were parsed.
        shape = [2, len(self.packets), len(self.packets[0].i1)]
        self.i = np.zeros(shape, dtype=np.int16)
        self.q = np.zeros(shape, dtype=np.int16)
        for idx, packet in enumerate(self.packets):
            self.i[0, idx, :] = packet.i1
            self.i[1, idx, :] = packet.i2
            self.q[0, idx, :] = packet.q1
            self.q[1, idx, :] = packet.q2
        self.numgates = self.i.shape[2]
        # set info for scan based on first header (it _should_ be the same for
        # every pulse, but the hardware does not enforce this)
        h = self.packets[0].header
        self.header = h
        self.scaninfo = scaninfo
        self.pulses = self.num_pulses
        # this is the delay between the tx start and rx start.
        _hardware_delay = h.rx_delay - h.tx_delay
        rx_delay = _hardware_delay - wni.scan_config._RX_LAG
        start_gate = (rx_delay // 4) * GATE_SIZE
        if h.fir_enable:
            gate_size = GATE_SIZE * config.FIR_DECIMATION
        else:
            gate_size = GATE_SIZE
        end_gate = start_gate + (self.numgates - 1) * gate_size
        # range (meters) of each gate
        self.gates = np.linspace(start_gate, end_gate, self.numgates)
        self._iq = None
        self._iq_filtered = None
        self._range_correction = None

    @staticmethod
    def _check_channel(channel):
        """Raise ValueError unless ``channel`` is 0 or 1."""
        if channel != 0 and channel != 1:
            raise ValueError('channel must be 0 or 1, not {}'.format(channel))

    @property
    def h(self):
        """Per-channel matched filter; an alias for ``self.waveforms``.

        Raises:
            ValueError: if no Tx waveforms were supplied.
        """
        if self.waveforms is None:
            raise ValueError('no Tx waveforms were given; cannot apply the '
                             'matched filter')
        return self.waveforms

    @h.setter
    def h(self, value):
        self.waveforms = value

    @property
    def iq_filtered(self):
        """Matched-filtered I/Q data: a list with one 2-D complex array
        (pulse x gate) per channel, computed on first access."""
        if self._iq_filtered is None:
            self._iq_filtered = [self._calc_filtered_iq(ch) for ch in (0, 1)]
        return self._iq_filtered

    @iq_filtered.setter
    def iq_filtered(self, value):
        self._iq_filtered = value

    def _calc_filtered_iq(self, channel):
        """Return the matched-filtered I/Q data of one channel."""
        return self._iq_filtered_conv(channel)

    def savemat(self, fname):
        """Dumps iq (unfiltered) data to a format readable by Matlab."""
        import scipy.io
        data = {
            'iq': self.i + 1j*self.q,
            'num_pulses': self.num_pulses,
            'prt': self.header.prt,
            'gates': self.gates,
            'az': self.az,
            'el': self.el,
            'header': [p.header for p in self.packets],
            'scaninfo': self.scaninfo,
        }
        scipy.io.savemat(fname, data)

    def _extract_packets(self):
        """Parse one radial from the raw data into packets/az/el.

        Returns the number of bytes consumed from the raw data.
        """
        radials, offset = parse_radials(self._raw_data)
        for radial in radials:
            self.packets.extend(radial.packets)
            self.az.append(radial.az)
            self.el.append(radial.el)
        self.num_pulses = len(self.packets)
        return offset

    def _az_to_degrees(self, az):
        """Convert a raw azimuth count (8000 counts per turn) to degrees."""
        return (az % 8000) * 360 / 8000

    def _el_to_degrees(self, el):
        # NOTE(review): unlike _az_to_degrees this does not scale to degrees;
        # confirm the elevation encoding before relying on it.
        return el % 8000

    @property
    def iq(self):
        """Complex I/Q samples, ``i + 1j*q``, computed on first access."""
        if self._iq is None:
            self._iq = self.i + 1j*self.q
        return self._iq

    def _iq_filtered_fft(self, channel):
        """Implement matched filter using fft.

        When there are at least as many gates as filter taps, the output is
        aligned like :meth:`_iq_filtered_conv` (numpy ``'same'`` convolution
        along the gate axis).
        """
        self._check_channel(channel)
        iq = self.iq[channel]
        h = self.h[channel]
        ngates = iq.shape[1]
        ntaps = h.shape[0]
        # Pad to at least the full linear-convolution length so the FFT
        # product does not wrap around (circular convolution).
        nbins = nextpow2(ngates + ntaps - 1)
        X = np.fft.fft(iq, nbins)   # along the last (gate) axis
        H = np.fft.fft(h, nbins)
        # H broadcasts over the pulse axis.
        y = np.fft.ifft(X * H)
        # Offset of numpy's 'same' window inside the full convolution.
        start = (ntaps - 1) - ntaps // 2
        y = y[:, start:start + ngates]
        y.real /= np.sum(np.abs(h.real))
        y.imag /= np.sum(np.abs(h.imag))
        return y

    def _iq_filtered_conv(self, channel):
        """Implement matched filter using convolution."""
        self._check_channel(channel)
        iq = self.iq[channel]
        h = self.h[channel]
        y = np.array([np.convolve(row, h, 'same') for row in iq])
        y.real /= np.sum(np.abs(h.real))
        y.imag /= np.sum(np.abs(h.imag))
        return y

    def _doppler_fft_size(self, n):
        """Return ``n``, or, if ``n`` is smaller than the pulse count, the
        next power of 2 >= the pulse count (so no pulses are dropped)."""
        npulses = len(self.packets)
        if n < npulses:
            return nextpow2(npulses)
        return n

    def _calc_doppler_spectrum(self, iq, n, normalize):
        """Implementation of _doppler_spectrum_unfiltered and doppler_spectrum.

        FFTs ``iq`` along axis 0 (pulses) with ``n`` points, centers zero
        frequency, and returns the float32 magnitude.  With ``normalize`` the
        peak is scaled to 1 (an all-zero spectrum is left as zeros).
        """
        fft = np.fft.fft(iq, n, 0)
        shifted = np.fft.fftshift(fft, 0)
        magnitude = np.abs(shifted).astype('float32')
        if normalize:
            peak = magnitude.max()
            if peak > 0:
                magnitude /= peak
        return magnitude

    def _doppler_spectrum_unfiltered(self, n=32, normalize=False):
        """Return the Doppler spectrum of the unfiltered iq data."""
        n = self._doppler_fft_size(n)
        return [self._calc_doppler_spectrum(iq, n, normalize) for iq in self.iq]

    def doppler_spectrum(self, n=32, normalize=False):
        """Return the Doppler spectrum of the data, one array per channel."""
        n = self._doppler_fft_size(n)
        return [self._calc_doppler_spectrum(iq, n, normalize)
                for iq in self.iq_filtered]

    def doppler_spectrum_windowed(self, n=32, normalize=False):
        """Like :meth:`doppler_spectrum`, with a Hanning window applied
        along the pulse axis first."""
        n = self._doppler_fft_size(n)
        window = np.hanning(self.iq_filtered[0].shape[0])
        windowed_iq = [iq * window[:, np.newaxis] for iq in self.iq_filtered]
        return [self._calc_doppler_spectrum(iq, n, normalize)
                for iq in windowed_iq]

    def calc_doppler_velocity(self, channel, wavelength=config.WAVELENGTH):
        """Calculate the doppler velocity of each gate of one channel.

        Pulse-pair estimate: the phase of the mean lag-1 product
        ``conj(x[k]) * x[k+1]`` over pulses, scaled by
        ``-wavelength / (4 * pi * PRT)``.  Output units follow the units of
        ``wavelength`` per second.
        """
        # use eqn. 6.18 and 6.19 of "Doppler Weather and Radar Observations"
        iq = self.iq[channel, :, :]
        rotations = iq[:-1, :].conj() * iq[1:, :]
        total_rotation = np.mean(rotations, 0)
        phase = np.arctan2(total_rotation.imag, total_rotation.real)
        prt_us = clk_to_us(self.header.prt, config.TXRX_CLK)
        prt = prt_us / 1E6
        factor = -wavelength / (4 * np.pi * prt)
        return phase * factor

    def calc_reflectivity(self, channel, cal=0, range_correct=True):
        """Calculate the reflectivity of each gate of one channel.

        Power (``i**2 + q**2``) is averaged over pulses, optionally multiplied
        by ``self.gates**2``, converted to dB, and offset by ``cal``.
        """
        iq = self.iq[channel]
        power = iq.real**2 + iq.imag**2
        power_avg = np.mean(power, 0)
        # range correction
        if range_correct:
            power_avg *= self.gates**2
        power_db = 10 * np.log10(power_avg)
        return power_db + cal

    @classmethod
    def fromfile(cls, filename, *args, **kwargs):
        """Construct an instance from the raw data stored in ``filename``.

        Extra positional/keyword arguments go to the constructor
        (``waveforms``, ``scaninfo``).
        """
        with open(filename, 'rb') as f:
            data = f.read()
        return cls(data, *args, **kwargs)
class IQDataReader(object):
    """
    Iterate over a file dumped by :func:`wni.processees.data_dumper`.
    Yields a :class:`IQData` instance for every iteration.

    The file is memory-mapped read-only.  Call :meth:`close` (or use the
    reader as a context manager) to release the map and the file handle.
    """
    def __init__(self, fname, has_header=True):
        self.fname = fname
        # set self.map to None because __del__ (and therefore close()) still
        # runs if the open() call fails.
        self.map = None
        self._file = open(fname, 'rb')
        try:
            self.map = mmap.mmap(self._file.fileno(),
                                 length=0,
                                 access=mmap.ACCESS_READ)
            if has_header:
                self.header = IQFileHeader.fromfile(fname)
                self.data_start = self.header.bytes_consumed
            else:
                self.header = None
                self.data_start = 0
            self._current_offset = self.data_start
            # read 20% more than one radial's estimated size per parse
            self._read_size = int(self.__guess_radial_size() * 1.2)
        except BaseException:
            # Release the file/map now instead of waiting for __del__.
            self.close()
            raise
        # a list of offsets at which radials can be parsed.
        self.radial_offsets = []
        # list of radials' offsets that could not be parsed due to a malformed
        # packet.
        self._bad_radials = []

    def __guess_radial_size(self):
        """Guess how much space (in bytes) a radial takes up.

        Based on the first packet header: (packet length) * (pulses).
        """
        head = self.map[self.data_start:self.data_start + PacketHeader.HEADER_LENGTH]
        header = PacketHeader(head)
        return (header.data_length + header.HEADER_LENGTH) * header.pulses

    def parse_radial_at(self, pos=None):
        """
        Parse radial at the specified offset in the memorymapped file.  If pos
        is None, parse the radial at ``self._current_offset``
        """
        if pos is None:
            pos = self._current_offset
        chunk = self.map[pos:pos + self._read_size]
        return IQData(chunk, scaninfo=self.header)

    def __next__(self):
        """Return the next radial as an :class:`IQData`.

        If the data at the current offset is not a packet start, the offset is
        recorded in ``self._bad_radials`` and parsing resumes at the next
        ``b'OMG!'`` marker.

        Raises:
            StopIteration: at end of file, or when no further marker exists.
        """
        while True:
            # check to see if we're actually done with the file.
            if self._current_offset >= len(self.map):
                raise StopIteration
            try:
                iqdata = self.parse_radial_at(self._current_offset)
            except ValueError as ve:
                if not str(ve).startswith('packet start should be OMG!'):
                    raise
                logger.warning('Malformed I/Q data!  Skipping ahead.')
                self._bad_radials.append(self._current_offset)
                # Search strictly past the bad offset so every retry makes
                # progress (a loop instead of recursion: no stack growth).
                next_start = self.map.find(b'OMG!', self._current_offset + 1)
                if next_start == -1:
                    raise StopIteration
                self._current_offset = next_start
                continue
            self.radial_offsets.append(self._current_offset)
            self._current_offset += iqdata._bytes_consumed
            return iqdata

    # Python 2 iterator protocol.
    next = __next__

    def __iter__(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Close the memory map and the underlying file; safe to repeat and
        safe on a partially-initialized reader."""
        if getattr(self, 'map', None) is not None:
            self.map.close()
        if getattr(self, '_file', None) is not None:
            self._file.close()

    def __del__(self):
        self.close()
class Packet(object):
    """
    Parse a single packet in the format it is in coming from the FPGA to the
    PS.  A packet consists of a packet header and I/Q data.

    The I/Q payload is int16 samples interleaved as ``i1, q1, i2, q2``
    repeating; they are split into the ``i1``/``q1``/``i2``/``q2`` arrays
    (views into the payload, no copy).

    Raises:
        ValueError: if the payload is shorter than the header's
          ``data_length`` or does not hold a whole number of sample groups.
    """
    def __init__(self, packet_data):
        self.header = PacketHeader(packet_data)
        start = self.header.HEADER_LENGTH
        stop = start + self.header.data_length
        payload = packet_data[start:stop]
        if len(payload) != self.header.data_length:
            raise ValueError(
                'truncated packet: expected {} bytes of I/Q data, got {}'
                .format(self.header.data_length, len(payload)))
        data = np.frombuffer(payload, dtype='int16')
        self.i1 = data[::4]
        self.q1 = data[1::4]
        self.i2 = data[2::4]
        self.q2 = data[3::4]
        # Explicit check (not assert) so it survives `python -O`.
        if not (self.i1.size == self.i2.size == self.q1.size == self.q2.size):
            raise ValueError(
                'I/Q payload of {} samples is not a multiple of 4'
                .format(data.size))

    def __str__(self):
        return str(self.header).replace('Header', '')
def nextpow2(x):
    """Return the next power of 2 greater than or equal to x.

    Args:
        x: a positive int or float.
    Returns:
        int: ``2**k`` for the smallest integer ``k >= 0`` with ``2**k >= x``.
    Raises:
        ValueError: if ``x`` is not positive.
    """
    if not x > 0:
        raise ValueError('nextpow2 requires a positive number, not {}'.format(x))
    # Exact ceil(x) for positive x without going through floating point.
    n = int(x)
    if n < x:
        n += 1
    # Integer bit arithmetic: a log2-based ceil can overshoot exact powers of
    # two because of floating-point rounding.
    return 1 << (n - 1).bit_length()


assert nextpow2(1) == 1
assert nextpow2(2) == 2
assert nextpow2(3) == 4
assert nextpow2(255) == 256
assert nextpow2(257) == 512
assert nextpow2(2 ** 29) == 2 ** 29