Source code for fenics_plotly.fenics_plotly

import os
from pathlib import Path

import fenics as fe
import numpy as np
import plotly
import plotly.graph_objects as go
import plotly.io as pio

try:
    _SHOW_PLOT = bool(int(os.getenv("FENICS_PLOTLY_SHOW", 1)))
except ValueError:
    _SHOW_PLOT = True


[docs]def set_renderer(renderer): pio.renderers.default = renderer
[docs]def savefig(fig, filename, save_config=None): """Save figure to file Parameters ---------- fig : `plotly.graph_objects.Figure` This figure that you want to save filename : Path or str Path to the desitnation where you want to save the figure save_config : dict, optional Additionsal cofigurations to be assed to `plotly.offline.plot`, by default None """ filename = Path(filename) outdir = filename.parent assert outdir.exists(), f"Folder {outdir} does not exist" config = { "toImageButtonOptions": { "filename": filename.stem, "width": 1500, "height": 1200, } } if save_config is not None: config.update(save_config) path = outdir.joinpath(filename) plotly.offline.plot(fig, filename=path.as_posix(), auto_open=False, config=config)
def _get_vertex_values(function): space = function.function_space() mesh = space.mesh() shape = function.ufl_shape if isinstance(space.ufl_element(), fe.FiniteElement): return function.compute_vertex_values() else: size_n = shape[0] if isinstance(space.ufl_element(), fe.VectorElement): res = np.zeros((mesh.num_entities(0), shape[0])) for i in range(size_n): res[:, i] = function.sub(i, deepcopy=True).compute_vertex_values() if isinstance(space.ufl_element(), fe.TensorElement): res = np.zeros((mesh.num_entities(0), shape[0], shape[1])) count = 0 for i in range(size_n): for j in range(size_n): res[:, i, j] = function.sub( count + j, deepcopy=True ).compute_vertex_values() count += size_n return res def _get_triangles(mesh): num_faces = mesh.num_faces() if num_faces == 0: num_faces = len(list(fe.faces(mesh))) triangle = np.zeros((3, num_faces), dtype=int) for ind, face in enumerate(fe.faces(mesh)): triangle[:, ind] = face.entities(0) return triangle def _cone_plot(function, size=10, showscale=True, normalize=False, **kwargs): mesh = function.function_space().mesh() points = mesh.coordinates() vectors = _get_vertex_values(function) if len(points[0, :]) == 2: points = np.c_[points, np.zeros(len(points[:, 0]))] if vectors.shape[1] == 2: vectors = np.c_[vectors, np.zeros(len(vectors[:, 0]))] if normalize: vectors = np.divide(vectors.T, np.linalg.norm(vectors, axis=1)).T cones = go.Cone( x=points[:, 0], y=points[:, 1], z=points[:, 2], u=vectors[:, 0], v=vectors[:, 1], w=vectors[:, 2], sizemode="absolute", sizeref=size, showscale=showscale, ) return cones def _wireframe_plot_mesh(mesh, **kwargs): coord = mesh.coordinates() triangle = mesh.cells() if len(coord[0, :]) == 2: coord = np.c_[coord, np.zeros(len(coord[:, 0]))] tri_points = coord[triangle] Xe = [] Ye = [] Ze = [] for T in tri_points: Xe.extend([T[k % 3][0] for k in range(4)] + [None]) Ye.extend([T[k % 3][1] for k in range(4)] + [None]) Ze.extend([T[k % 3][2] for k in range(4)] + [None]) # define the trace for triangle sides lines = go.Scatter3d( x=Xe, y=Ye, z=Ze, mode="lines", name="", line=dict(color="rgb(70,70,70)", width=2), hoverinfo="none", ) return lines def _surface_plot_function( function, colorscale, showscale=True, intensitymode="vertex", **kwargs ): mesh = function.function_space().mesh() if intensitymode == "vertex": val = function.compute_vertex_values() triangle = _get_triangles(mesh) else: # Spooky! val = function.vector().get_local() triangle = mesh.cells().T coord = mesh.coordinates() hoverinfo = ["val:" + "%.5f" % item for item in val] if len(coord[0, :]) == 2: coord = np.c_[coord, np.zeros(len(coord[:, 0]))] surface = go.Mesh3d( x=coord[:, 0], y=coord[:, 1], z=coord[:, 2], i=triangle[0, :], j=triangle[1, :], k=triangle[2, :], flatshading=True, intensitymode=intensitymode, intensity=val, colorscale=colorscale, lighting=dict(ambient=1), name="", hoverinfo="all", text=hoverinfo, showscale=showscale, ) return surface def _scatter_plot_function(function, colorscale, showscale=True, size=10, **kwargs): dofs_coord = function.function_space().tabulate_dof_coordinates() if len(dofs_coord[0, :]) == 2: dofs_coord = np.c_[dofs_coord, np.zeros(len(dofs_coord[:, 0]))] mesh = function.function_space().mesh() val = function.vector().get_local() coord = mesh.coordinates() hoverinfo = ["val:" + "%.5f" % item for item in val] if len(coord[0, :]) == 2: coord = np.c_[coord, np.zeros(len(coord[:, 0]))] points = go.Scatter3d( x=dofs_coord[:, 0], y=dofs_coord[:, 1], z=dofs_coord[:, 2], mode="markers", marker=dict(size=size, color=val, colorscale=colorscale), hoverinfo="all", text=hoverinfo, ) return points def _surface_plot_mesh(mesh, color, opacity=1.0, **kwargs): coord = mesh.coordinates() triangle = _get_triangles(mesh) if len(coord[0, :]) == 2: coord = np.c_[coord, np.zeros(len(coord[:, 0]))] surface = go.Mesh3d( x=coord[:, 0], y=coord[:, 1], z=coord[:, 2], i=triangle[0, :], j=triangle[1, :], k=triangle[2, :], flatshading=True, color=color, opacity=opacity, lighting=dict(ambient=1), ) return surface def _plot_dofs(functionspace, size, **kwargs): dofs_coord = functionspace.tabulate_dof_coordinates() if len(dofs_coord[0, :]) == 2: dofs_coord = np.c_[dofs_coord, np.zeros(len(dofs_coord[:, 0]))] points = go.Scatter3d( x=dofs_coord[:, 0], y=dofs_coord[:, 1], z=dofs_coord[:, 2], mode="markers", name=kwargs.get("name", None), marker=dict(size=size), ) return points def _plot_dirichlet_bc(diri, size, colorscale="inferno", **kwargs): values = diri.get_boundary_values() dofs = diri.function_space().tabulate_dof_coordinates() if len(dofs[0, :]) == 2: dofs = np.c_[dofs, np.zeros(len(dofs[:, 0]))] coords_diri = np.array([dofs[i] for i in values.keys()]) vals_diri = np.array(list(values.values())) points = go.Scatter3d( x=coords_diri[:, 0], y=coords_diri[:, 1], z=coords_diri[:, 2], mode="markers", marker=dict( size=size, color=vals_diri, colorscale=colorscale, colorbar=dict(thickness=20), ), ) return points def _surface_plot_meshfunc(meshfunc, colorscale, **kwargs): assert meshfunc.dim() == 2 mesh = meshfunc.mesh() array = meshfunc.array() coord = mesh.coordinates() if len(coord[0, :]) == 2: coord = np.c_[coord, np.zeros(len(coord[:, 0]))] triangle = _get_triangles(mesh) hoverinfo = ["val:" + "%d" % item for item in array] surface = go.Mesh3d( x=coord[:, 0], y=coord[:, 1], z=coord[:, 2], i=triangle[0, :], j=triangle[1, :], k=triangle[2, :], flatshading=True, intensity=array, colorscale=colorscale, lighting=dict(ambient=1), name="", hoverinfo="all", text=hoverinfo, intensitymode="cell", ) return surface def _handle_mesh(obj, **kwargs): data = [] wireframe = kwargs.get("wireframe", True) if wireframe: surf = _surface_plot_mesh(obj, **kwargs) data.append(surf) data.append(_wireframe_plot_mesh(obj)) return data def _handle_function( obj, **kwargs, ): data = [] scatter = kwargs.get("scatter", False) norm = kwargs.get("norm", False) component = kwargs.get("component", None) if len(obj.ufl_shape) == 0: if scatter: surface = _scatter_plot_function(obj, **kwargs) else: surface = _surface_plot_function(obj, **kwargs) data.append(surface) elif len(obj.ufl_shape) == 1: if norm or component == "magnitude": V = obj.function_space().split()[0].collapse() magnitude = fe.project(fe.sqrt(fe.inner(obj, obj)), V) else: magnitude = None if component is None: if norm: surface = _surface_plot_function(magnitude, **kwargs) data.append(surface) cones = _cone_plot(obj, **kwargs) data.append(cones) else: if component == "magnitude": surface = _surface_plot_function(magnitude, **kwargs) data.append(surface) else: for i, comp in enumerate(["x", "y", "z"]): if component not in [comp, comp.upper()]: continue surface = _surface_plot_function( obj.sub(i, deepcopy=True), **kwargs ) data.append(surface) if kwargs.get("wireframe", True): lines = _wireframe_plot_mesh(obj.function_space().mesh()) data.append(lines) return data def _handle_meshfunction(obj, **kwargs): data = [] if obj.dim() == 3: # This is a cell function of dimension 3 V = fe.FunctionSpace(obj.mesh(), "DG", 0) f = fe.Function(V) f.vector()[:] = obj.array() if kwargs.get("scatter", True): surf = _scatter_plot_function(f, **kwargs) else: surf = _surface_plot_function(f, intensitymode="cell", **kwargs) else: surf = _surface_plot_meshfunc(obj, **kwargs) data.append(surf) if kwargs.get("wireframe", True): lines = _wireframe_plot_mesh(obj.mesh(), **kwargs) data.append(lines) return data def _handle_function_space(obj, **kwargs): data = [] points = _plot_dofs(obj, **kwargs) data.append(points) if kwargs.get("wireframe", True): lines = _wireframe_plot_mesh(obj.mesh(), **kwargs) data.append(lines) return data def _handle_dirichlet_bc(obj, **kwargs): data = [] points = _plot_dirichlet_bc(obj, **kwargs) data.append(points) if kwargs.get("wireframe", True): lines = _wireframe_plot_mesh(obj.function_space().mesh(), **kwargs) data.append(lines) return data
[docs]def plot( obj, colorscale="inferno", wireframe=True, scatter=False, size=10, norm=False, name="f", color="gray", opacity=1.0, show_grid=False, size_frame=None, background=(242, 242, 242), normalize=False, component=None, showscale=True, show=True, filename=None, ): """Plot FEnICS object Parameters ---------- obj : Mesh, Function. FunctionoSpace, MeshFunction, DirichleyBC FEnicS object to be plotted colorscale : str, optional The colorscale, by default "inferno" wireframe : bool, optional Whether you want to show the mesh in wirteframe, by default True scatter : bool, optional Plot function as scatter plot, by default False size : int, optional Size of scatter points, by default 10 norm : bool, optional For vectors plot the norm as a surface, by default False name : str, optional Name to show up in legend, by default "f" color : str, optional Color to be plotted on the mesh, by default "gray" opacity : float, optional opacity of surface, by default 1.0 show_grid : bool, optional Show x, y (and z) axis grid, by default False size_frame : [type], optional Size of plot, by default None background : tuple, optional Background of plot, by default (242, 242, 242) normalize : bool, optional For vectors, normalize then to have unit length, by default False component : [type], optional Plot a componenent (["Magnitude", "x", "y", "z"]) for vector, by default None showscale : bool, optional Show colorbar, by default True show : bool, optional Show figure, by default True filename : [type], optional Path to file where you want to save the figure, by default None Raises ------ TypeError If object to be plotted is not recognized. """ if isinstance(obj, fe.Mesh): handle = _handle_mesh elif isinstance(obj, fe.Function): handle = _handle_function elif isinstance(obj, fe.cpp.mesh.MeshFunctionSizet): handle = _handle_meshfunction elif isinstance(obj, fe.FunctionSpace): handle = _handle_function_space elif isinstance(obj, fe.DirichletBC): handle = _handle_dirichlet_bc else: raise TypeError(f"Cannot plot object of type {type(obj)}") data = handle( obj, scatter=scatter, colorscale=colorscale, norm=norm, normalize=normalize, size=size, size_frame=size_frame, component=component, opacity=opacity, show_grid=show_grid, color=color, wireframe=wireframe, showscale=showscale, name=name, ) layout = go.Layout( scene_xaxis_visible=show_grid, scene_yaxis_visible=show_grid, scene_zaxis_visible=show_grid, paper_bgcolor="rgb" + str(background), margin=dict(l=80, r=80, t=50, b=50), scene=dict(aspectmode="data"), ) if size_frame is not None: layout.update(width=size_frame[0], height=size_frame[1]) fig = go.FigureWidget(data=data, layout=layout) fig.update_layout(hovermode="closest") if filename is not None: savefig(fig, filename) if show and _SHOW_PLOT: fig.show() return FEniCSPlotFig(fig)
[docs]class FEniCSPlotFig: def __init__(self, fig): self.figure = fig
[docs] def add_plot(self, fig): data = list(self.figure.data) + list(fig.figure.data) self.figure = go.FigureWidget(data=data, layout=self.figure.layout)
[docs] def show(self): if _SHOW_PLOT: self.figure.show()
[docs] def save(self, filename): savefig(self.figure, filename)