# Source code for jLM.JupyterDisplay

# 
# University of Illinois Open Source License
# Copyright 2016-2018 Luthey-Schulten Group,
# All rights reserved.
# 
# Developed by: Luthey-Schulten Group
#                           University of Illinois at Urbana-Champaign
#                           http://www.scs.uiuc.edu/~schulten
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the Software), to deal with 
# the Software without restriction, including without limitation the rights to 
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 
# of the Software, and to permit persons to whom the Software is furnished to 
# do so, subject to the following conditions:
# 
# - Redistributions of source code must retain the above copyright notice, 
# this list of conditions and the following disclaimers.
# 
# - Redistributions in binary form must reproduce the above copyright notice, 
# this list of conditions and the following disclaimers in the documentation 
# and/or other materials provided with the distribution.
# 
# - Neither the names of the Luthey-Schulten Group, University of Illinois at
# Urbana-Champaign, nor the names of its contributors may be used to endorse or
# promote products derived from this Software without specific prior written
# permission.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL 
# THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR 
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 
# OTHER DEALINGS WITH THE SOFTWARE.
# 
# Author(s): Tyler M. Earnest
# 

"""SpatialModel mixins for rich display in Jupyter"""
import itertools, io, base64, random, string

import numpy as np
import warnings
from . import Template as Tmp
from . import Lattice
from . import DisplayUtils as Dsp

def _maybeJupyter(fn):
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except ImportError:
            warnings.warn("Call to {} failed due to missing module dependencies".format(fn.__name__), stacklevel=2)
            return None
    return wrapped

def _num(x, scale=1, unit=None):
    """Format a single number for display via ``Dsp.numToStr``.

    Thin wrapper that fixes ``n=3`` and ``smallest_exp=3``; ``scale`` and
    ``unit`` are passed through unchanged.
    """
    fixedOptions = dict(n=3, smallest_exp=3)
    return Dsp.numToStr(x, scale=scale, unit=unit, **fixedOptions)

def _numFmt(fmt, *nums):
    """Format several numbers with ``Dsp.numToStr`` and substitute them into ``fmt``.

    Each positional argument is either a 4-sequence ``(value, scale, unit, n)``
    or a bare value. A bare value is formatted with ``scale=1.0``,
    ``unit=None`` and ``n=3``.
    """
    def render(item):
        try:
            value, scale, unit, digits = item
        except (ValueError, TypeError):
            # Not a 4-sequence: the argument itself is the number to format
            value, scale, unit, digits = item, 1.0, None, 3
        return Dsp.numToStr(value, scale=scale, unit=unit, n=digits, smallest_exp=3)

    return fmt.format(*map(render, nums))


def _maxEntropySlice(siteLattice, axis):
    maxEntropy, planeIndex = 0,0

    for i in range(siteLattice.shape[axis]):
        s = [slice(None), slice(None), slice(None)]
        s[axis] = i
        im = siteLattice[tuple(s)]
        nbins = siteLattice.max()+1
        e = sum(-x*np.log(x) for x in np.histogram(im,bins=nbins)[0]/im.size if x>0)

        if e > maxEntropy:
            maxEntropy, planeIndex = e, i

    return planeIndex

def _siteTypeImg(siteLattice, siteColors, axis0, axis1, axis2,  planeIndex):
    latticeShape = siteLattice.shape

    stRGB = np.zeros((latticeShape[axis0],latticeShape[axis1],3))
    if planeIndex is None:
        planeIndex = latticeShape[axis2]//2

    def rgb8(i):
        return np.array([int(255*c) for c in siteColors[i]], dtype=np.uint8)

    stImg = np.zeros((latticeShape[axis0],latticeShape[axis1]),dtype=np.uint8)

    for i,j in itertools.product(range(latticeShape[axis0]),range(latticeShape[axis1])):
        idx=[0,0,0]
        idx[axis0] = latticeShape[axis0]-i - 1
        idx[axis1] = j
        idx[axis2] = planeIndex
        stImg[i,j] = siteLattice[idx[0],idx[1],idx[2]]

    stRGB = np.zeros((stImg.shape[0],stImg.shape[1],3),dtype='uint8')
    for i,j in itertools.product(*map(range,stImg.shape)):
        stRGB[i,j,:] = rgb8(stImg[i,j])

    imgFile = io.BytesIO()
    import PIL.Image
    img = PIL.Image.fromarray(stRGB)
    img.save(imgFile,'PNG')
    return base64.b64encode(imgFile.getvalue()).decode()

class FileJupyterMixin:
    """Jupyter display mixin adding trajectory data read from a simulation file.

    Extends the context builders of the next class in the MRO (via
    ``super()``) with data read from ``self.speciesCounts``,
    ``self.speciesCountTimes``, ``self.replicate`` and the HDF5 handle
    ``self.h5``.
    """

    def _speciesJ2context(self, sps, cs):
        """Add trajectory statistics and an inline SVG count plot for ``sps``.

        Args:
            sps: Species whose trajectory is summarized.
            cs (dict): Particle statistics, forwarded to ``super()``.

        Returns:
            dict: The parent context with an added ``'trajData'`` entry.
        """
        ctx = super()._speciesJ2context(sps, cs)
        ns = self.speciesCounts[sps]
        ts = self.speciesCountTimes
        minIdx = np.argmin(ns)
        maxIdx = np.argmax(ns)
        trajData = dict(replicate=self.replicate,
                        duration=ts[-1],
                        startCt=ns[0],
                        endCt=ns[-1],
                        meanCt=_numFmt("{} ± {}", np.mean(ns), np.std(ns)),
                        minCt=_numFmt("{} @ t={}", (ns[minIdx], 1, None, 0), (ts[minIdx], 1.0, 's', 3)),
                        maxCt=_numFmt("{} @ t={}", (ns[maxIdx], 1, None, 0), (ts[maxIdx], 1.0, 's', 3)))

        # Decimate long trajectories to roughly 256 points before plotting
        if ns.shape[0] > 256:
            stride = ns.shape[0]//256
            ns = ns[::stride]
            ts = ts[::stride]

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.plot(ts, ns)
            xlim = (0, 1) if ts[0] == ts[-1] else (ts[0], ts[-1])
            ns0, ns1 = int(ns.min()/1.1), int(ns.max()*1.1)
            ylim = (ns0-1, ns1+1)
            ax.set(xlim=xlim, ylim=ylim, xlabel=r"$t/\mathrm{s}$", ylabel=r"$n_{\mathrm{"+sps.name+"}}$")
            fig.tight_layout()
            imgFile = io.BytesIO()
            fig.savefig(imgFile, format="SVG", transparent=True)
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.close(fig)

        # Drop everything before the <svg> element (XML prolog, DOCTYPE) so the
        # SVG can be inlined in HTML; keep everything if no <svg> line is found.
        lines = imgFile.getvalue().decode().splitlines()
        start = next((i for i, line in enumerate(lines) if '<svg' in line), 0)
        trajData['svg'] = '\n'.join(lines[start:])
        ctx['trajData'] = trajData
        return ctx

    def _modelJ2context(self):
        """Add a per-replicate summary of the ``Simulations`` group of ``self.h5``."""
        ctx = super()._modelJ2context()
        sims = self.h5['Simulations']
        ctx['nReplicates'] = len(sims)
        ctx['replicates'] = []
        for label in sims:
            sim = sims[label]
            spTimes = sim['SpeciesCountTimes']
            ltTimes = sim['LatticeTimes']
            counts = sim['SpeciesCounts']
            tsp = spTimes[-1]
            tlt = ltTimes[-1]
            ct0 = np.sum(counts[0, :])
            ct1 = np.sum(counts[-1, :])
            ctx['replicates'].append(dict(label=label,
                                          time=_num(max(tsp, tlt), unit='s'),
                                          latticeEvals=ltTimes.shape[0],
                                          speciesEvals=spTimes.shape[0],
                                          startCt=ct0,
                                          endCt=ct1))
        return ctx
class JupyterDisplayMixin:
    """SpatialModel mixin providing rich HTML display in Jupyter notebooks.

    The methods read model state from the host class (``speciesList``,
    ``regionList``, ``reactionList``, ``rxnRateList``, ``diffRateList``,
    ``siteLattice``, ``_particlePlacement``, ``_transitionRates``,
    ``_reactionLocations``, ``particleStatistics()``, ...), which this mixin
    does not define itself.
    """

    def _repr_html_(self):
        """Rich HTML model summary, used by Jupyter's display machinery."""
        return Tmp.j2render("siminfo.html", self._modelJ2context())

    # Jinja2 context builders. If overridden, call super()._...J2context()

    def _speciesJ2context(self, sps, cs):
        """Build the template context describing one species.

        Args:
            sps: Species object.
            cs (dict): Statistics from ``self.particleStatistics()``. The keys
                ``countBySpecies``, ``countBySpeciesRegion`` and
                ``concBySpeciesRegion`` are read.

        Returns:
            dict: Context for the ``species.html`` template.
        """
        rxns = self._reactionJ2context([r for r in self.reactionList
                                        if sps in r.products or sps in r.reactants])

        # transitionTable[r0, r1, s] holds the diffRateList index used for
        # species s moving from region r0 to region r1, or -1 if none is set.
        # A None entry in _transitionRates acts as a wildcard over that axis.
        nRegions = len(self.regionList)
        transitionTable = np.full((nRegions, nRegions, len(self.speciesList)+1), -1, dtype=int)
        for sidx, r0idx, r1idx, rt in self._transitionRates:
            if sidx is None:
                sidx = slice(None)
            if r0idx is None:
                r0idx = slice(None)
            if r1idx is None:
                r1idx = slice(None)
            transitionTable[r0idx, r1idx, sidx] = rt
        txb = transitionTable[:, :, sps.idx]

        def style(x, y=None):
            d = dict(v=x)
            if y is not None:
                d['style'] = y
            return d

        def getDr(r0, r1):
            # Table cell for the diffusion rate of sps from region r0 to r1.
            # Check for "undefined" before indexing so -1 is never used as an index.
            didx = txb[r0.idx, r1.idx]
            if didx < 0:
                return style("undef", "undef")
            try:
                v = self.diffRateList[didx].value
                s = _num(v, 1e12)
            except KeyError:
                return style("undef", "undef")
            if v < 1e-20:
                return style(s, "zero")
            if v >= self.maxDiffusionRate():
                return style(s, "inf")
            return style(s)

        header = ([style('<span class="jLMnum">× µm<sup>2</sup>⋅s<sup>-1</sup></span>')]
                  + [style(r._html()) for r in self.regionList])
        dfTable = [header] + [[style(r0._html())] + [getDr(r0, r1) for r1 in self.regionList]
                              for r0 in self.regionList]

        placementList = [(x, y, z) for x, y, z, s in self._particlePlacement if s == sps.idx]

        attrs = dict()
        for attr in sps._dynamicAttrs():
            v = getattr(sps, attr)
            if isinstance(v, str):
                attrs[attr] = v
            elif np.isscalar(v):
                attrs[attr] = _num(v)
            else:
                attrs[attr] = repr(v.__class__)

        ctx = dict(label=sps.name,
                   totalCount=cs['countBySpecies'][sps],
                   otherAttrs=attrs,
                   idx=sps.idx,
                   name=sps.name,
                   tex=sps._TeXMath(),
                   annotation=sps.annotation,
                   totalReactions=len(rxns['rxns']),
                   distribution=[dict(region=st._html(),
                                      count=cs['countBySpeciesRegion'][sps][st],
                                      conc=_num(cs['concBySpeciesRegion'][sps][st], 1e6, "µM"))
                                 for st in self.regionList if cs['countBySpeciesRegion'][sps][st] > 0],
                   rxns=rxns['rxns'],
                   placement=[],
                   nPlaced=len(placementList),
                   diffTbl=dfTable)

        def lookupRegion(v):
            return self.regionList[self.siteLattice[v[2], v[1], v[0]]]._html()

        if 0 < len(placementList) <= 200:
            # Few placements: list each distinct site with its multiplicity
            placement = {v: dict(count=0, x=v[0], y=v[1], z=v[2], region=lookupRegion(v))
                         for v in placementList}
            for v in placementList:
                placement[v]['count'] += 1
            ctx['placement'] = list(placement.values())
        else:
            # Many (or no) placements: aggregate counts per region
            perRegion = dict()
            for v in placementList:
                region = lookupRegion(v)
                perRegion[region] = perRegion.get(region, 0) + 1
            ctx['placement'] = [dict(count=v, region=k) for k, v in perRegion.items()]

        return ctx

    def _reactionJ2context(self, rxns):
        """Build the reaction-table context for the reactions ``rxns``."""
        ctx = dict(rxns=[])
        annotate = False
        for r in rxns:
            annotate |= r.annotation is not None or r.rate.annotation is not None
            ctx['rxns'].append(dict(
                idx=r.idx,
                annotation=r.annotation or r.rate.annotation,
                rxn="$" + r._TeXMath() + "$",
                reg=r', '.join(self.regionList[rgidx]._html()
                               for rxidx, rgidx in self._reactionLocations if rxidx == r.idx),
                rate=_num(r.rate.value, unit=Dsp.unicodeUnits(r.rate._unit))))
        ctx['annotationCol'] = annotate
        return ctx

    def _regionJ2context(self, plane, planeIndex, maxWidth=600, maxHeight=600):
        """Build the context for an image of one site-lattice plane.

        Args:
            plane (str): Two of "x", "y", "z" naming the image axes, e.g. "xz".
            planeIndex (int or None): Index along the remaining axis; ``None``
                selects the middle plane.
            maxWidth (int): Upper bound for ``scl*shape[axis0]``.
            maxHeight (int): Upper bound for ``scl*shape[axis1]``.
        """
        axis0 = 'xyz'.index(plane[0])
        axis1 = 'xyz'.index(plane[1])
        axis2 = ({0, 1, 2} - {axis0, axis1}).pop()
        latticeShape = self.siteLattice.shape
        if planeIndex is None:
            planeIndex = latticeShape[axis2]//2

        xscl = maxWidth/latticeShape[axis0]
        yscl = maxHeight/latticeShape[axis1]

        def proc_scl(s):
            # Upscale by the limiting factor; when shrinking, use 1/2^k so
            # lattice sites map onto whole pixels.
            if s >= 1:
                return s
            inv2 = 2**int(np.ceil(np.log2(1/s)))
            return 1/inv2

        scl = proc_scl(min(xscl, yscl))

        ctx = dict(w=int(scl*latticeShape[axis0]),
                   h=int(scl*latticeShape[axis1]),
                   plane=plane,
                   reg=[r._html() for r in self.regionList],
                   b64=_siteTypeImg(self.siteLattice, [x._floatColor() for x in self.regionList],
                                    axis0, axis1, axis2, planeIndex),
                   dir="xyz"[axis2],
                   planeIndex=planeIndex,
                   renderId=_renderId())
        return ctx

    def _modelJ2context(self):
        """Build the context for the model summary template (``siminfo.html``)."""
        cs = self.particleStatistics()
        dims = np.array(self.shape)

        # Show the plane spanned by the two largest dimensions. np.argmax
        # returns the first maximum, so on ties axis0 < axis1.
        remaining = dims.copy()
        axis0 = int(np.argmax(remaining))
        remaining[axis0] = -1
        axis1 = int(np.argmax(remaining))
        axis2 = ({0, 1, 2} - {axis0, axis1}).pop()

        # find most interesting slice from image entropy
        planeIndex = _maxEntropySlice(self.siteLattice, axis2)

        ctx = self._regionJ2context("xyz"[axis0]+"xyz"[axis1], planeIndex, maxWidth=700, maxHeight=500)
        ctx['name'] = self.name
        ctx['filename'] = self.filename
        ctx['nsps'] = len(self.speciesList)
        ctx['nrxns'] = len(self.reactionList)
        ctx['nsts'] = len(self.regionList)
        ctx['nrconst'] = len(self.rxnRateList)
        ctx['ndconst'] = len(self.diffRateList)
        ctx['dims'] = _numFmt("{} × {} × {}", (dims[2], 1, None, 0), (dims[1], 1, None, 0), (dims[0], 1, None, 0))
        ctx['siteVol'] = _num(self.siteV, 1e15, "fl")
        ctx['latticeSpacing'] = _num(self.latticeSpacing, 1e9, "nm")
        ctx['nPlacedParticles'] = len(self._particlePlacement)
        ctx['pps'] = self.pps
        ctx['bytesPerParticle'] = self.bytesPerParticle
        ctx['regions'] = []

        def maybeAttr(attr, cattr, scl, unit):
            # Missing or falsy attributes are shown as "undefined"
            value = getattr(self, attr, None)
            ctx[cattr] = _num(value, scl, unit) if value else "undefined"

        maybeAttr("timestep", "timeStep", 1e6, "µs")
        maybeAttr("simulationTime", "simTime", 1e0, "s")
        maybeAttr("latticeWriteInterval", "latticeWriteInterval", 1e0, "s")
        maybeAttr("speciesWriteInterval", "speciesWriteInterval", 1e0, "s")
        maybeAttr("hookInterval", "hookInterval", 1e0, "s")

        for r in self.regionList:
            c = dict()
            c['label'] = r._html()
            c['counts'] = cs["countByRegion"][r]
            c['conc'] = _num(cs["concByRegion"][r], 1e6, r"µM")
            c['volume'] = _num(cs["regionVol"][r], 1e15, r"fl")
            c['siteCount'] = cs["regionCounts"][r]
            # Percentage of particle slots (sites * particles-per-site) that are occupied
            occScl = 100.0/(max(1.0, cs["regionCounts"][r])*self.pps)
            c['occ'] = r"${:.2f}\,\%$".format(occScl*cs["countByRegion"][r])
            ctx['regions'].append(c)

        return ctx

    def _x3dJ2context(self, filterFunctions=None):
        """Build the context for the X3D scene of the site lattice.

        Args:
            filterFunctions (dict): Optional map from a region (as found in
                ``regionList``) to a function. The function takes the
                ``np.mgrid`` coordinate arrays of the lattice and returns a
                boolean mask; only masked sites of that region are meshed.
        """
        siteLattice = self.siteLattice
        if filterFunctions is None:
            filterFunctions = dict()
        latticeDims = siteLattice.shape
        nx, ny, nz = latticeDims
        jc = dict(centroid=0.5*np.array(latticeDims),
                  scl=7/max(latticeDims),
                  renderId=_renderId(),
                  species=[],
                  sites=[],
                  modelName=self.name)

        for i, reg in enumerate(self.regionList):
            siteMatch = siteLattice == reg.idx
            if reg in filterFunctions:
                siteMatch = siteMatch & filterFunctions[reg](*np.mgrid[0:nx, 0:ny, 0:nz])
            binaryLattice = np.array(siteMatch, dtype=np.uint8)
            if not binaryLattice.any():
                continue
            verts, faces = Lattice.greedyMesh(binaryLattice)
            r, g, b = reg._floatColor()
            jc['sites'].append(dict(name=reg.name,
                                    idx=i,
                                    hexColor=reg._hexColor(),
                                    label=reg._html(),
                                    checked="" if reg.idx == 0 else " checked",
                                    faces=' '.join(str(x) for face in faces for x in face),
                                    verts=' '.join(str(x) for vert in verts for x in vert),
                                    r=r, g=g, b=b,
                                    choice="-1" if reg.idx == 0 else "0"))

        placedTypes = sorted(set(s for _, _, _, s in self._particlePlacement))
        ncolors = len(placedTypes)
        nSites = len(jc['sites'])
        for i, s in enumerate(placedTypes):
            r, g, b = Dsp.colorWheel(i/ncolors)
            # NOTE(review): placements are unpacked here as (z, y, x, species),
            # while _speciesJ2context unpacks them as (x, y, z, species) --
            # confirm which ordering is correct.
            locs = [dict(x=x+0.5, y=y+0.5, z=z+0.5)
                    for (z, y, x, sp) in self._particlePlacement if sp == s]
            jc['species'].append(dict(hexColor=Dsp.toHex((r, g, b)),
                                      r=r, g=g, b=b,
                                      label=self.speciesList[s]._html(),
                                      idx=i+nSites,
                                      radius=0.5,
                                      locs=locs))

        jc['downloadX3D'] = False
        return jc

    #### Model builder context manager

    def construct(sim):
        """Track newly created model objects and display in Notebook.

        Context manager which tracks new species, reactions, etc., and displays
        a HTML summary when used in Jupyter.

        The "before" snapshot is taken when ``construct()`` is called. The
        summary is shown only if the ``with`` block exits without an exception,
        IPython can be imported, and at least one object was added.
        """
        def names(objs):
            return set(x.name for x in objs)

        def newObjs(objList, before, after):
            # Objects whose names were added since the snapshot, ordered by index
            return sorted((objList[x] for x in after - before), key=lambda x: x.idx)

        class InteractiveConstruct:
            def _getKeys(self):
                return (names(sim.speciesList),
                        names(sim.regionList),
                        names(sim.rxnRateList),
                        names(sim.reactionList),
                        names(sim.diffRateList))

            def __init__(self):
                self._keys0 = self._getKeys()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                if exc_type is not None:
                    return None
                try:
                    import IPython.display as ipd
                except ImportError:
                    warnings.warn("Construct summary not displayed", stacklevel=2)
                    return None

                sps0, reg0, k0, rxn0, d0 = self._keys0
                sps1, reg1, k1, rxn1, d1 = self._getKeys()
                newSps = newObjs(sim.speciesList, sps0, sps1)
                newRxns = newObjs(sim.reactionList, rxn0, rxn1)
                newRegs = newObjs(sim.regionList, reg0, reg1)
                newKs = newObjs(sim.rxnRateList, k0, k1)
                newDs = newObjs(sim.diffRateList, d0, d1)

                spsCtx = dict(annotationCol=any(o.annotation is not None for o in newSps),
                              texCol=any(o._texstr is not None for o in newSps),
                              objs=[dict(id=o.idx, html=o._repr_html_(), annotation=o.annotation, tex=o._texstr)
                                    for o in newSps])
                rxnsCtx = dict(annotationCol=(any(o.annotation is not None for o in newRxns)
                                              or any(o.rate.annotation is not None for o in newRxns)),
                               objs=[dict(id=o.idx,
                                          html=o._repr_html_(),
                                          annotation=(o.annotation or o.rate.annotation),
                                          reg=r', '.join(sim.regionList[rgidx]._html()
                                                         for rxidx, rgidx in sim._reactionLocations
                                                         if rxidx == o.idx),
                                          rate=_num(o.rate.value, unit=Dsp.unicodeUnits(o.rate._unit)))
                                     for o in newRxns])
                regsCtx = dict(annotationCol=any(o.annotation is not None for o in newRegs),
                               objs=[dict(id=o.idx, html=o._repr_html_(), annotation=o.annotation)
                                     for o in newRegs])
                ksCtx = dict(annotationCol=any(o.annotation is not None for o in newKs),
                             texCol=any(o._texstr is not None for o in newKs),
                             objs=[dict(id=o.idx, html=o._repr_html_(), annotation=o.annotation,
                                        tex=o._texstr, rate=_num(o.value, unit=Dsp.unicodeUnits(o._unit)))
                                   for o in newKs])
                dsCtx = dict(annotationCol=any(o.annotation is not None for o in newDs),
                             texCol=any(o._texstr is not None for o in newDs),
                             objs=[dict(id=o.idx, html=o._repr_html_(), annotation=o.annotation,
                                        tex=o._texstr, rate=_num(o.value, unit=Dsp.unicodeUnits(o._unit)))
                                   for o in newDs])

                ctx = dict(sps=spsCtx, rxns=rxnsCtx, regs=regsCtx, ks=ksCtx, ds=dsCtx)
                rows = sum(len(c['objs']) for c in ctx.values())
                if rows > 0:
                    ipd.display(ipd.HTML(Tmp.j2render("newObject.html", ctx)))

        return InteractiveConstruct()

    #### Site lattice viewer

    @_maybeJupyter
    def displayGeometry(self, filterFunctions=None, mode="widget"):
        """3-D site type lattice viewer

        The display mode can be "widget", which displays in the notebook,
        "download_x3d", which opens a download link in the notebook to the X3D
        scene, or "download_html", which opens a download link in the notebook
        to a standalone HTML file.

        To hide parts of the lattice, `filterFunctions` can be specified. This
        option takes a dict mapping a region to a function. The function maps
        a (x,y,z) mesh grid to a [nx,ny,nz] boolean mask, and only subvolumes
        marked True are shown. For example, to only show volumes whose z
        coordinate is less than 32, use the function

        >>> def zfilter(x,y,z):
        ...     return z<32

        Here the arguments x,y,z are of type :py:class:`numpy.ndarray` and a
        boolean lattice is returned.

        Args:
            filterFunctions (dict):
                Dict of functions, keyed by region, which take a (nx,ny,nz)
                mesh to a bool [nx,ny,nz] mask
            mode (str):
                View mode

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Validate before building the (potentially expensive) mesh context
        if mode not in ('widget', 'download_x3d', 'download_html'):
            raise ValueError("mode should be {widget, download_x3d, download_html}")
        ctx = self._x3dJ2context(filterFunctions)
        if mode == 'widget':
            return Tmp.displayj2html("x3d.html", ctx)
        if mode == 'download_x3d':
            ctx['downloadX3D'] = True
            xml = Tmp.j2render("structure.x3d", ctx)
            # UTF-8 is identical to ASCII for ASCII text and does not fail on non-ASCII names
            return _downloadFile(xml.encode("utf-8"), self.name + ".x3d")
        data = Tmp.j2render("standaloneX3d.html", ctx)
        return _downloadFile(data.encode("utf-8"), self.name + "-3dView.html")

    @_maybeJupyter
    def showRegion(self, plane="xz", planeIndex=None):
        """Display a slice of the site lattice

        Args:
            plane (str):
                Viewing plane, e.g. "xy"
            planeIndex (int):
                Index along the normal direction to the plane
        """
        ctx = self._regionJ2context(plane, planeIndex)
        return Tmp.displayj2html("region.html", ctx)

    @_maybeJupyter
    def showRegionStack(self, plane='xz', scl=None, maxWidth=600, maxHeight=600):
        """Display all slices of the site lattice interactively

        Args:
            plane (str):
                Viewing plane, e.g. "xy"
            scl (int):
                Scale pixels by this amount
            maxWidth (int):
                Maximum width of image
            maxHeight (int):
                Maximum height of image
        """
        siteColors = [r._floatColor() for r in self.regionList]
        htmlNames = [r._html() for r in self.regionList]
        return _showRegionStack(self.siteLattice, htmlNames, siteColors, plane=plane,
                                scl=scl, maxWidth=maxWidth, maxHeight=maxHeight)

    # Model introspection

    @_maybeJupyter
    def showAllParameters(self):
        """Display a table of reaction rates and diffusion constants"""
        return self._paramTable("All defined parameters", list(self.rxnRateList)+list(self.diffRateList))

    @_maybeJupyter
    def showRateConstants(self):
        """Display a table of reaction rates"""
        return self._paramTable("Reaction rate constants", self.rxnRateList)

    @_maybeJupyter
    def showDiffusionConstants(self):
        """Display a table of diffusion constants"""
        return self._paramTable("Diffusion constants", self.diffRateList)

    def _paramTable(self, title, params):
        """Render a table of rate parameters together with their usage counts.

        Args:
            title (str): Table title.
            params: Rate-constant objects (reaction and/or diffusion rates).
        """
        ctx = dict(title=title, count=len(params), params=[])

        # Count the reactions / transition rules that reference each parameter
        counts = {k: 0 for k in params}
        ds = {d.idx: d for d in params if d in self.diffRateList}
        for rx in self.reactionList:
            if rx.rate in counts:
                counts[rx.rate] += 1
        for _, _, _, r in self._transitionRates:
            if r in ds:
                counts[ds[r]] += 1

        annotate = False
        for p in params:
            annotate |= p.annotation is not None
            ctx['params'].append(dict(id=p.idx,
                                      symbol=p._TeXMath(),
                                      annotation=p.annotation,
                                      value=_num(p.value),
                                      unit=Dsp.unicodeUnits(p._unit),
                                      count=counts[p]))
        ctx['annotationCol'] = annotate
        return Tmp.displayj2html("params.html", ctx)

    @_maybeJupyter
    def showSpecies(self, sps):
        """Show details on a species type

        Args:
            sps: Species object or species name.

        Raises:
            ValueError: If a name is given that is not a defined species.
        """
        if isinstance(sps, str):
            if not self.speciesList.is_defined(sps):
                raise ValueError("unknown species: {}".format(sps))
            sps = self.species(sps)
        cs = self.particleStatistics()
        return Tmp.displayj2html("species.html", self._speciesJ2context(sps, cs))

    @_maybeJupyter
    def showAllSpecies(self):
        """Inspect all species interactively"""
        cs = self.particleStatistics()
        topctx = dict(renderId=_renderId(),
                      spData=[self._speciesJ2context(sps, cs) for sps in self.speciesList])
        return Tmp.displayj2html("allSpecies.html", topctx)

    @_maybeJupyter
    def showReactions(self, rxnList=None):
        """Display a table of all reactions

        Args:
            rxnList: Reactions to show; ``None`` or an empty list shows all
                reactions in ``reactionList``.
        """
        ctx = self._reactionJ2context(rxnList or self.reactionList)
        return Tmp.displayj2html("reactionTable.html", ctx)
def _renderId(): return 'jLM_' + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(8)) def _downloadFile(data, filename): import IPython.display as ipd return ipd.Javascript(Tmp.j2render("download.js", dict(filename=filename, blob=base64.b64encode(data).decode()))) def _showRegionStack(lattice, htmlNames, siteColors, plane='xz', scl=None, maxWidth=600, maxHeight=600): axis0 = 'xyz'.index(plane[0]) axis1 = 'xyz'.index(plane[1]) axis2 = (set((0,1,2)) - set((axis0,axis1))).pop() latticeShape = lattice.shape siteTypes = np.unique(lattice) ntypes = len(siteTypes) if scl is None: scl = int(min(maxWidth/latticeShape[axis0], maxHeight/latticeShape[axis1])) else: scl = int(scl) ctx = dict(w=scl*latticeShape[axis0], h=scl*latticeShape[axis1], plane=plane, reg=htmlNames, direction="xyz"[axis2], dim=latticeShape[axis2], start=_maxEntropySlice(lattice, axis2), renderId=_renderId()) pngData=[] for planeIndex in range(latticeShape[axis2]): stRGB = np.zeros((latticeShape[axis0],latticeShape[axis1],3)) def rgb8(i): return np.array([int(255*c) for c in siteColors[i]], dtype=np.uint8) stImg = np.zeros((latticeShape[axis0],latticeShape[axis1]),dtype=np.uint8) for i,j in itertools.product(range(latticeShape[axis0]),range(latticeShape[axis1])): idx=[0,0,0] idx[axis0] = latticeShape[axis0]-i - 1 idx[axis1] = j idx[axis2] = planeIndex stImg[i,j] = lattice[idx[0],idx[1],idx[2]] scl = int(min(maxWidth/stImg.shape[0], maxHeight/stImg.shape[1])) stRGB = np.zeros((stImg.shape[0],stImg.shape[1],3),dtype='uint8') for i,j in itertools.product(*map(range,stImg.shape)): stRGB[i,j,:] = rgb8(stImg[i,j]) imgFile = io.BytesIO() import PIL.Image img = PIL.Image.fromarray(stRGB) img.save(imgFile,'PNG') pngData.append(base64.b64encode(imgFile.getvalue()).decode()) ctx['pngData'] = '[' + ','.join('"'+x+'"' for x in pngData) + ']' return Tmp.displayj2html("sitestack.html", ctx) def _showBinaryLattices(binLattices,manualColor=None, filterFunctions=None, mode="widget"): assert mode 
in ['widget', 'download_x3d', 'download_html'] if isinstance(binLattices, np.ndarray): lattices = [("lattice", binLattices)] elif isinstance(binLattices, list): lattices = [ ("lattice{:02d}".format(d), l) for d,l in enumerate(binLattices) ] elif isinstance(binLattices, dict): lattices = list(sorted(binLattices.items())) else: raise TypeError latticeDims = lattices[0][1].shape scl = 7/max(latticeDims) centroid = 0.5*np.array(latticeDims) nx,ny,nz = latticeDims ctx = dict(centroid=centroid, scl=scl, renderId=_renderId(), sites=[], species=[]) filterFunctions = filterFunctions or dict() for i,(name, binLattice) in enumerate(lattices): if manualColor is not None: r,g,b = manualColor[name] else: r,g,b = Dsp.colorWheel(i/len(binLattices)) print(name, r,g,b) if binLattice.any(): if name in filterFunctions: posMatch = filterFunctions[name](*np.mgrid[0:nx, 0:ny, 0:nz]) binLattice = np.array(binLattice&posMatch, dtype=np.uint8) else: binLattice = np.array(binLattice, dtype=np.uint8) verts, faces = Lattice.greedyMesh(binLattice) c=Dsp.toHex((r,g,b)) ctx['sites'].append(dict(name=name, idx=i, hexColor=c, label=r'<span class="jLMregion" style="color:white;background:{};">{}</span>'.format(c, name), checked=" checked", faces=' '.join(str(x) for face in faces for x in face), verts=' '.join(str(x) for vert in verts for x in vert), r=r, g=g, b=b, choice="0")) if mode == 'widget': return Tmp.displayj2html("x3d.html", ctx) elif mode == 'download_x3d': ctx['downloadX3D'] = True xml = Tmp.j2render("structure.x3d", ctx) return _downloadFile(xml.encode("ascii"), "lattice.x3d") elif mode == 'download_html': data = Tmp.j2render("standaloneX3d.html", ctx) return _downloadFile(data.encode("ascii"), "lattice-3dView.html")
def showVolumeStack(vol, plane='xz', cmap='inferno', scl=None, maxWidth=600, maxHeight=600):
    """Display slices of volumetric data interactively

    Args:
        vol (:py:class:`numpy.ndarray`):
            3-D data

    Keyword Args:
        cmap (str):
            Name of matplotlib colormap
        plane (str):
            Viewing plane, e.g. "xy"
        scl (int):
            Scale pixels by this amount (at least 1)
        maxWidth (int):
            Maximum width of image
        maxHeight (int):
            Maximum height of image
    """
    import matplotlib.pyplot as plt
    import PIL.Image

    def toPNG(rgbArray):
        # Encode a uint8 image array as a base64 PNG string
        imgFile = io.BytesIO()
        PIL.Image.fromarray(rgbArray).save(imgFile, 'PNG')
        return base64.b64encode(imgFile.getvalue()).decode()

    axis0 = 'xyz'.index(plane[0])
    axis1 = 'xyz'.index(plane[1])
    axis2 = ({0, 1, 2} - {axis0, axis1}).pop()
    vol = vol.transpose([axis0, axis1, axis2])
    shape = vol.shape

    # NOTE(review): maxWidth bounds scl*shape[0], which is reported as the
    # height `h` below -- confirm the intended width/height pairing.
    if scl is None:
        scl = int(min(maxWidth/shape[0], maxHeight/shape[1]))
    else:
        scl = int(scl)
    # Large volumes would otherwise be drawn at 0 pixels
    scl = max(1, scl)

    cm = getattr(plt.cm, cmap)
    vmin = vol.min()
    vmax = vol.max()
    vrange = vmax - vmin
    if vrange == 0:
        # Constant volume: avoid 0/0 (NaN pixels) and map everything to the colormap minimum
        vrange = 1

    # Vertical colorbar image, top = vmax
    cbw = 16
    cbh = max(1, int(scl*shape[0]*0.6666))
    gradient = np.array(cbw*[np.linspace(1, 0, cbh)]).T
    cbimg = toPNG((255*cm(gradient)).astype(np.uint8))

    ctx = dict(h=scl*shape[0], w=scl*shape[1], cbw=cbw, cbh=cbh,
               plane=plane, vmax=_num(vmax), vmin=_num(vmin),
               direction="xyz"[axis2], dim=shape[2], cbimg=cbimg,
               start=shape[2]//2, renderId=_renderId())

    pngData = []
    for planeIndex in range(shape[2]):
        im = (vol[:, :, planeIndex] - vmin)/vrange
        pngData.append(toPNG((255*cm(im)).astype(np.uint8)))
    ctx['pngData'] = '[' + ','.join('"'+x+'"' for x in pngData) + ']'
    return Tmp.displayj2html("datastack.html", ctx)
class _Report: def __init__(self): self._inctx = False def __call__(self, l, v, fmt=None): assert self._inctx if fmt: v = fmt.format(v) else: v = str(v) self._rows.append((l,v)) def __enter__(self): self._inctx = True self._rows = [] def __exit__(self, type, value, traceback): self._inctx = False html = ("<table class='jLMtbl'>" + "".join("<tr><th>{}</th><td>{}</td></tr>".format(l,v) for l,v in self._rows) + "</table>" ) import IPython.display as ipd ipd.display(ipd.HTML(html)) report = _Report()