from __future__ import absolute_import, division, print_function
import os
import re
import shlex
import xml.etree.ElementTree as ET
from builtins import bytes, dict, int, range, str, super
from collections import defaultdict
from io import StringIO
from os import path as osp
from future.utils import PY2, PY3, native
import dolfin
from ..fem import force_ufl
from ..util import generate_base32_token
def magnitude(x):
    """Return ``x.magnitude`` if `x` has that attribute, otherwise `x`.

    This lets callers treat unit-carrying quantities (objects exposing a
    ``magnitude`` attribute, e.g. pint-style quantities) and plain values
    uniformly.
    """
    return getattr(x, 'magnitude', x)
class PlotAddMixin(object):
    """Mixin providing :meth:`add` for plot writers.

    The host class must provide a ``function_subspace_registry`` attribute
    and a ``_add_func(func)`` method; both are used by :meth:`add`.
    """

    def add(self, name, units, expr, space=None):
        """Project ``expr / units`` onto a function space and record it.

        Parameters
        ----------
        name
            Name given to the created ``dolfin.Function``.
        units
            Value `expr` is divided by before projection. If the quotient
            has an ``m_as`` method it is converted with
            ``m_as('dimensionless')``.
        expr
            Expression to plot.
        space
            Function space to project onto. When ``None``, a space is
            requested from
            ``self.function_subspace_registry.get_function_space``.

        Returns
        -------
        The ``dolfin.Function`` that was projected into and handed to
        ``self._add_func``.
        """
        fsr = self.function_subspace_registry
        if space is None:
            try:
                space = fsr.get_function_space(magnitude(expr), new=True)
            except Exception:
                # Best-effort lookup (the registry may be missing or unable
                # to infer a space). `space` stays None and is passed to
                # dolfin.Function below as-is. Only Exception is caught so
                # KeyboardInterrupt/SystemExit still propagate.
                pass
        func = dolfin.Function(space, name=name)

        scaled = expr / units
        if hasattr(scaled, 'm_as'):
            scaled = scaled.m_as('dimensionless')

        # TODO: actually use fsr (e.g. ``fsr.assign(func, scaled)``) and
        # emit a warning when that is not possible.
        scaled = force_ufl(scaled)
        dolfin.project(scaled, space, function=func)
        self._add_func(func)
        return func
class XdmfPlot(PlotAddMixin):
    """Write functions to an XDMF time series via per-timestamp partials.

    Functions added through :meth:`add` are written to a "partial" file
    named ``<filename>.p.<token>.xdmf``. While a partial is open, a sibling
    ``<filename>.p.<token>.incomplete`` marker file exists. :meth:`close`
    removes the marker and calls :func:`xdmf_combine` to regenerate
    ``<filename>`` from the partials. :meth:`new` closes the current
    partial and sets the timestamp used for subsequent writes.

    Instances can be used as context managers; leaving the ``with`` block
    calls :meth:`close`.

    NOTE(review): every process that calls :meth:`open` draws its own token
    and later runs :func:`xdmf_combine`; confirm this is the intended
    behaviour for multi-rank MPI runs.
    """

    # Currently open ``dolfin.XDMFFile`` for the active partial, or None.
    partial_xdmf_file = None
    # Set by open(): path stem (no extension) and ``.xdmf`` path of the
    # most recently opened partial.
    partial_base_filename = None
    partial_xdmf_filename = None
    # Time value attached to functions written by _add_func.
    timestamp = 0.0

    def __init__(self, filename, function_subspace_registry=None):
        """Remember the target `filename` and delete its existing partials.

        `function_subspace_registry` is stored for use by
        :meth:`PlotAddMixin.add`.
        """
        self.base_filename = filename
        self.function_subspace_registry = function_subspace_registry
        self.delete_partials(self.base_filename)

    @staticmethod
    def delete_partials(filename):
        """Delete every file beside `filename` named ``<name>.p.*``.

        Removal is best-effort: an ``OSError`` from ``os.remove`` is
        ignored. ``os.listdir`` errors (e.g. a missing directory) propagate.
        """
        dirname, name = osp.split(osp.abspath(filename))
        partial_prefix = name + '.p.'
        for entry in os.listdir(dirname):
            if not entry.startswith(partial_prefix):
                continue
            try:
                os.remove(osp.join(dirname, entry))
            except OSError:
                pass

    def mpi_comm(self):
        """Return the MPI communicator passed to ``dolfin.XDMFFile``."""
        return dolfin.mpi_comm_world()

    def new(self, timestamp):
        """Close the current partial (if any) and switch to `timestamp`."""
        self.close()
        self.timestamp = timestamp

    def _add_func(self, func):
        # Hook used by PlotAddMixin.add: write `func` at the current
        # timestamp, opening a partial file first if none is open.
        self.ensure_open()
        self.partial_xdmf_file.write(func, float(self.timestamp))

    def open(self):
        """Start a new partial file, closing the current one first."""
        if self.partial_xdmf_file is not None:
            self.close()
        while True:
            base = '{:s}.p.{:s}'.format(
                self.base_filename, generate_base32_token(20))
            xdmf_filename = base + '.xdmf'
            try:
                # Mode 'x' fails if the file already exists, so creating it
                # reserves the name. (`open` here is the builtin; this
                # method's own name is not in scope inside its body.)
                with open(xdmf_filename, 'x'):
                    pass
            except FileExistsError:
                # Name already taken: draw another token.
                continue
            else:
                break
        # Marker telling xdmf_combine to skip this partial until close().
        with open(base + '.incomplete', 'w'):
            pass
        self.partial_base_filename = base
        self.partial_xdmf_filename = xdmf_filename
        self.partial_xdmf_file = dolfin.XDMFFile(
            self.mpi_comm(), xdmf_filename)

    def ensure_open(self):
        """Open a partial file unless one is already open."""
        if self.partial_xdmf_file is None:
            self.open()

    def close(self):
        """Finish the current partial and regenerate the combined file.

        Does nothing when no partial is open.
        """
        xf = self.partial_xdmf_file
        if xf is None:
            return
        # Clear the attribute first so a failure below cannot lead to a
        # second close attempt on the same file (e.g. from __del__).
        self.partial_xdmf_file = None
        xf.close()
        os.remove(self.partial_base_filename + '.incomplete')
        xdmf_combine(self.base_filename)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the ``with`` block.
        return False

    def __del__(self):
        self.close()
def xdmf_combine(filename):
    """Merge the partial XDMF files of `filename` into one time series.

    Partials are the files ``<name>.p.<something>.xdmf`` located in the
    same directory as `filename`. A partial is skipped while a sibling
    ``<name>.p.<something>.incomplete`` file exists.

    Within each partial, the grids nested one level below the top-level
    grids of ``Domain`` are grouped by their ``Time`` value. For every
    timestamp the first grid is kept and the ``Attribute`` children of the
    remaining grids with that timestamp are appended to it. The merged
    grids of all partials are sorted by timestamp, wrapped in a temporal
    collection grid, and written to `filename` (overwriting it).

    Parameters
    ----------
    filename
        Path of the combined XDMF file to (re)generate.
    """
    dirname, name = osp.split(osp.abspath(filename))
    partial_prefix = name + '.p.'
    suffix = '.xdmf'

    # Path stems (without the '.xdmf' suffix) of every partial file. The
    # listing is sorted so the order among equal timestamps is stable.
    stems = []
    for entry in sorted(os.listdir(dirname)):
        if not entry.endswith(suffix):
            continue
        stem = entry[:-len(suffix)]
        if stem.startswith(partial_prefix):
            stems.append(osp.join(dirname, stem))

    timed_grids = []
    for stem in stems:
        if osp.exists(stem + '.incomplete'):
            continue  # marker still present: partial is not finished
        domain = ET.parse(stem + suffix).getroot().find('Domain')

        by_time = defaultdict(list)
        for top_grid in domain.findall('Grid'):
            for grid in top_grid.findall('Grid'):
                timestamp = float(grid.find('Time').attrib['Value'])
                by_time[timestamp].append(grid)

        # One merged grid per timestamp. A partial without grids
        # contributes nothing; several timestamps are all kept.
        for timestamp, same_time in by_time.items():
            merged = same_time[0]
            for extra in same_time[1:]:
                # findall: copes with grids carrying zero or several
                # attributes instead of assuming exactly one.
                merged.extend(extra.findall('Attribute'))
            timed_grids.append((timestamp, merged))

    timed_grids.sort(key=lambda kv: kv[0])

    series = ET.Element("Grid")
    series.set('Name', 'TimeSeries')
    series.set('GridType', 'Collection')
    series.set('CollectionType', 'Temporal')
    for _, grid in timed_grids:
        series.append(grid)

    new_et = ET.parse(StringIO('''\
<?xml version="1.0"?>
<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd" []>
<Xdmf Version="3.0" xmlns:xi="http://www.w3.org/2001/XInclude">
<Domain></Domain></Xdmf>'''))
    new_et.getroot().find('Domain').append(series)
    with open(filename, 'wb') as h:
        new_et.write(h, encoding="utf-8", xml_declaration=True)
if __name__ == '__main__':
    # Demo: write two named fields at two timestamps to out/zz.xdmf.
    # The relative import below only resolves when this module is run as
    # part of its package (``python -m <package>.<module>``).
    from ..util.function_subspace_registry import FunctionSubspaceRegistry
    import pint

    fsr = FunctionSubspaceRegistry()
    mesh = dolfin.UnitSquareMesh(2, 2)
    space = dolfin.FunctionSpace(mesh, "CG", 1)
    space2 = dolfin.FunctionSpace(mesh, "CG", 2)
    ur = pint.UnitRegistry()
    x = dolfin.SpatialCoordinate(mesh)
    c = dolfin.Constant(0.0)
    e1 = dolfin.sin(x[0] + c) * ur.dimensionless
    e2 = dolfin.sin(x[1] + c) * ur.dimensionless

    # XdmfPlot lists the target directory on construction, so it has to
    # exist before the plot object is created.
    if not osp.isdir("out"):
        os.makedirs("out")

    xp = XdmfPlot("out/zz.xdmf", fsr)
    try:
        for i in range(2):
            c.assign(i/10.)
            xp.new(float(i))
            xp.add("FOO", 1, e1, space=space)
            xp.add("BAR", 1, e2, space=space2)
    finally:
        # Finalize the last partial even if a step above failed.
        xp.close()