# Source code for mpl_render.mpl_render

from math import ceil
from threading import Thread, Lock
from time import sleep

import numpy as np
from cached_property import cached_property

# this is mostly based on the `viewlims.py` example on the matplotlib website

class BaseRendering(object):
    '''
    Base class for a matplotlib plot that is re-rendered whenever the
    view limits of its axes change.

    Subclasses implement `_init_plot()`, which creates the plot artist,
    and `_render_with_new_extent(extent)`, which refreshes it for a new
    visible extent. The data itself comes from `user_render()`, which by
    default calls the `render_callback` passed to the constructor.

    Extents are tuples `(x0, x1, y0, y1)`.
    '''
    size = ...  # placeholder; subclasses or the constructor set a real size
    _update_id = 0  # bumped per threaded update request, see _thread_thunk_update

    # The "update delay" is needed because we typically get sent two
    # events, 'xlim_changed' and 'ylim_changed', in close succession. So
    # we wait a couple milliseconds until the last received one before
    # actually drawing anything. This delay is implemented by having
    # another thread sleep, then do the drawing.
    #
    # Drawing in a thread can lead to issues on non-threadsafe
    # backends, and it is unclear on which backends this is safe.
    #
    # https://stackoverflow.com/questions/48418534/which-matplotlib-methods-may-be-called-outside-the-gui-thread
    #
    # Anyway, if you run into problems, set `force_single_thread` to True

    UPDATE_DELAY = 0.01  # seconds, passed to time.sleep
    force_single_thread = False

    # For temporarily disabling updates (used by the `viewlim_extent` setter)
    _inhibit_update_plot = False

    def __init__(self, ax, size=None, extent=None, render_callback=None):
        '''
        ax: the matplotlib Axes to draw into.
        size, extent, render_callback: each is stored on the instance only
            if truthy; otherwise any class-level attribute of the same
            name is used.
        '''
        if size:
            self.size = size
        if extent:
            self.extent = extent
        if render_callback:
            self.render_callback = render_callback

        self._lock # initialize lock attribute to prevent race condition
        self._initial_extent = extent
        self.ax = ax
        self._init_plot()
        self._connect_callbacks()
        self._attach_ax_reference()

    def _init_plot(self):
        ''' create the plot artist(s); must be implemented by subclasses '''
        raise NotImplementedError()

    def _connect_callbacks(self):
        ''' call `cb_ax_update` whenever either axis' view limits change '''
        cb = self.ax.callbacks
        cb.connect('xlim_changed', self.cb_ax_update)
        cb.connect('ylim_changed', self.cb_ax_update)

    def _attach_ax_reference(self):
        ''' prevent getting garbage collected until ax disappears '''
        setattr(self.ax, "___mpl_render_ref_{}".format(id(self)), self)

    @property
    def viewlim_extent(self):
        ''' extent of the axis.viewLim; setting this attribute makes the
        plot zoom to a particular region '''
        x0, y0, xd, yd = self.ax.viewLim.bounds
        x1 = x0 + xd
        y1 = y0 + yd
        return (x0, x1, y0, y1)

    @viewlim_extent.setter
    def viewlim_extent(self, extent):
        ax = self.ax
        # Avoid triggering two update events: suppress the update caused by
        # 'xlim_changed' and let the one from 'ylim_changed' redraw once.
        # The flag is read in `_update_plot`; in threaded mode that runs
        # after UPDATE_DELAY, so it only has an effect in single-thread
        # mode. try/finally guarantees updates are re-enabled even if
        # set_xlim raises.
        self._inhibit_update_plot = True
        try:
            ax.set_xlim(extent[:2])
        finally:
            self._inhibit_update_plot = False
        ax.set_ylim(extent[2:4])

    def update_data(self):
        ''' set self.data to self.user_render(), then return it

        NOTE: to trigger an update, use the `trigger_update()` method instead
        '''
        self.data = self.user_render()
        return self.data

    def user_render(self):
        '''
        Override this function to provide your own renderer; by
        default, this just calls `self.render_callback` with arguments
        `(self.size, self.extent)`.

        This method must return the data with which the plot should be
        updated. The specific format of the data depends on the
        subclass you're using, e.g. for `imshow` you must return an
        image, and for `plot` you must return a tuple (Xs, Ys). '''
        return self.render_callback(self.size, self.extent)

    @property
    def visible_size(self):
        ''' return size in pixels `(width, height)` of the axes as it is
        shown on screen, each rounded up to a whole pixel '''
        # `Axes.axesPatch` was removed from matplotlib; `Axes.patch` is the
        # same background rectangle.
        width, height = self.ax.patch.get_window_extent().bounds[2:4]
        return (ceil(width), ceil(height))

    def user_adjust_size(self):
        ''' by default this does nothing; override this '''

    def cb_ax_update(self, ax):
        ''' callback to update axis '''
        self.trigger_update()

    def trigger_update(self):
        '''
        Update the plot by re-rendering the visible area.

        This triggers an eventual update if `self.force_single_thread`
        is False, or updates right away otherwise. '''

        if self.force_single_thread:
            self._update_plot()
        else:
            Thread(target=self._thread_thunk_update).start()

    def _update_plot(self):
        ''' re-render for the current view limits and request a redraw '''
        if self._inhibit_update_plot:
            return

        ax = self.ax
        ax.set_autoscale_on(False) # prevent infinite loop

        # adjust image resolution if need be
        self.user_adjust_size()

        # update the image object with our new data and extent
        self._render_with_new_extent(self.viewlim_extent)

        # update UI
        ax.figure.canvas.draw_idle()

        # remove reference to `self.data` to allow GC
        self._after_update_plot_del_data()

    def _render_with_new_extent(self, extent):
        ''' refresh the plot artist for `extent`; implemented by subclasses '''
        raise NotImplementedError()

    def _after_update_plot_del_data(self):
        # override this if you actually do need `self.data`
        self.data = None

    @cached_property
    def _lock(self):
        return Lock()

    def _thread_thunk_update(self):
        ''' thread body: wait UPDATE_DELAY seconds, then update unless a
        newer update request arrived in the meantime '''
        with self._lock:
            self._update_id = uid = self._update_id + 1
        sleep(self.UPDATE_DELAY)
        with self._lock:
            if self._update_id == uid: # nobody scooped us?
                self._update_plot()

class ColorBarMixin(object):
    '''
    Mixin that keeps a colorbar in sync with the image artist `self.im`.

    The host class must provide `self.ax`, `self.im` and, when rescaling,
    `self.data`.

    `colorbar` may be True (a colorbar is created by `_init_colorbar`),
    False/None (no colorbar), or an existing colorbar object to update.
    '''
    colorbar = True

    def _init_colorbar(self):
        ''' create a colorbar for `self.im` if `self.colorbar` is True '''
        if self.colorbar is True: # auto create colorbar
            self.colorbar = self.ax.get_figure().colorbar(self.im)

    def rescale_colorbar_to_data(self):
        ''' set the color limits of `self.im` to the NaN-ignoring min/max
        of `self.data`, and place 5 evenly spaced ticks on the colorbar.

        note: this assumes `self.data` is defined. Nothing is changed if
        the data is empty or entirely NaN. Ticks are only set when
        `self.colorbar` is an actual colorbar object (not True/False/None).
        '''
        V = np.asarray(self.data)
        if V.size == 0 or np.isnan(V).all():
            # nanmin/nanmax would raise (empty) or yield NaN limits
            return
        v0, v1 = np.nanmin(V), np.nanmax(V)
        self.im.set_clim(vmin=v0, vmax=v1)

        cbar = self.colorbar
        if cbar is None or isinstance(cbar, bool):
            # colorbar disabled (or not created yet): no ticks to update
            return
        cbar_ticks = np.linspace(v0, v1, num=5, endpoint=True)
        cbar.set_ticks(cbar_ticks)

    def user_adjust_colorbar(self):
        ''' by default this calls `rescale_colorbar_to_data` '''
        self.rescale_colorbar_to_data()

class RenderingImShow(ColorBarMixin, BaseRendering):
    '''
    The `user_render()` method is expected to return an image with
    size `self.size`, representing area `self.extent`, where `extent`
    describes a rectangle `(x0, x1, y0, y1)`.

    `kw` holds extra keyword arguments for `ax.imshow`. It is copied, so
    the caller's dict is never modified. `origin` defaults to 'lower' and
    `extent` defaults to `self.extent`. '''

    size = (500, 300)
    im = True  # sentinel: True means "create the imshow artist in _init_plot"

    def __init__(self, ax, size=None, extent=None, render_callback=None,
                 kw=None, colorbar=True):
        # Copy so the setdefault() calls here and in _init_plot don't leak
        # into the caller's dict (e.g. one plot's extent into another plot
        # created from the same kwargs).
        kw = dict(kw) if kw else {}
        kw.setdefault('origin', 'lower')
        self._imshow_kwargs = kw

        self.colorbar = colorbar

        super().__init__(
            ax, size=size, extent=extent, render_callback=render_callback)

    def _init_plot(self):
        ''' create the imshow artist from the first rendered image (unless
        `im` was already provided), then the colorbar '''
        if self.im is True: # auto create imshow plot
            kw = self._imshow_kwargs
            kw.setdefault('extent', self.extent)
            self.im = self.ax.imshow(self.update_data(), **kw)
        self._init_colorbar()

    def _render_with_new_extent(self, extent):
        ''' re-render the image for `extent` and rescale the colorbar '''
        im = self.im
        self.extent = extent
        im.set_extent(extent)
        self.data = self.user_render()
        im.set_data(self.data)

        self.user_adjust_colorbar()

def simple_plot_wrap(function):
    '''
    Adapt a plain `Ys = function(Xs)` function to the render-callback
    signature `(size, extent) -> (Xs, Ys)`.

    The returned callback samples `size[0]` equidistant x values over
    `[extent[0], extent[1]]` (both ends included) and evaluates
    `function` on them. The y components of `size` and `extent` are
    unpacked but not used.
    '''
    def render(size, extent):
        x_lo, x_hi, _y_lo, _y_hi = extent
        n_points, _ny = size
        # TODO: handle log scaling in Xs to compute equidistant points
        xs = np.linspace(x_lo, x_hi, n_points)
        return xs, function(xs)
    return render

class RenderingPlot(BaseRendering):
    '''
    Re-rendering line plot drawn with `ax.plot`.

    The `user_render()` method is expected to return a tuple `(Xs,
    Ys)` where Xs and Ys are 1D arrays.

    `kw` holds extra keyword arguments for `ax.plot`. If `simple_plot`
    is given, it is wrapped with `simple_plot_wrap` and used instead of
    `render_callback`. '''
    size = (2000, 2000)
    plotobj = True  # sentinel: True means "create the line in _init_plot"

    def __init__(self, ax, size=None, extent=None, render_callback=None,
                 kw=None, simple_plot=None):
        self._plot_kwargs = kw if kw else {}

        # a simple y = f(x) function takes precedence over render_callback
        callback = (simple_plot_wrap(simple_plot) if simple_plot
                    else render_callback)

        super().__init__(
            ax, size=size, extent=extent, render_callback=callback)

    def _init_plot(self):
        ''' create the line from the first rendered data, unless `plotobj`
        was already provided '''
        if self.plotobj is not True:
            return
        plot_kw = self._plot_kwargs
        xs, ys = self.update_data()
        self.plotobj = self._create_plot(xs, ys, plot_kw)

    def _create_plot(self, X, Y, kw):
        ''' plot (X, Y) on the axes and return the single line created '''
        (line,) = self.ax.plot(X, Y, **kw)
        return line

    def _render_with_new_extent(self, extent):
        ''' re-render the curve for `extent` and update the line's data '''
        line = self.plotobj
        self.extent = extent
        xs, ys = self.user_render()
        line.set_xdata(xs)
        line.set_ydata(ys)