Skip to content

Structure

Source code in src/baderkit/plotting/core/plotter.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
class StructurePlotter:
    """
    A convenience class for creating plots of crystal structures using
    pyvista's package for VTK.

    Each site is rendered as its own actor, named after the site's label, and
    the unit-cell outline is rendered as an actor named "lattice". Setting most
    properties updates the live ``self.plotter`` in place.
    """

    def __init__(
        self,
        structure: Structure,
        off_screen: bool = False,
    ):
        """
        A convenience class for creating plots of crystal structures using
        pyvista's package for VTK.

        Parameters
        ----------
        structure : Structure
            The pymatgen Structure object to plot. A sorted, relabeled copy is
            stored; the caller's object is not modified.
        off_screen : bool, optional
            Whether or not the plotter should be in offline mode. The default is False.

        Returns
        -------
        None.

        """
        # sort and relabel structure for consistency. Site labels are used as
        # actor names throughout this class.
        structure = structure.copy()
        structure.sort()
        structure.relabel_sites()
        # create initial class variables
        self.structure = structure
        self._off_screen = off_screen
        self._visible_atoms = list(range(len(self.structure)))
        self._show_lattice = True
        self._wrap_atoms = True
        self._lattice_thickness = 0.1
        self._atom_metallicness = 0.0
        self._background = "#FFFFFF"
        self._view_indices = [1, 0, 0]
        # BUG-FIX: this was previously the 1-tuple ``(0.0,)``. The rotation is a
        # scalar angle in degrees.
        self._camera_rotation = 0.0
        self._show_axes = True
        self._parallel_projection = True
        self._radii = [s.specie.atomic_radius for s in structure]
        self._colors = [ATOM_COLORS.get(s.specie.symbol, "#FFFFFF") for s in structure]
        # generate initial plotter
        self.plotter = self._create_structure_plot(off_screen)
        self.view_indices = [1, 0, 0]
        self.up_indices = [0, 0, 1]

    ###########################################################################
    # Properties and Setters
    ###########################################################################
    @property
    def visible_atoms(self) -> list[int]:
        """

        Returns
        -------
        list[int]
            A list of atom indices to display in the plot.

        """
        return self._visible_atoms

    @visible_atoms.setter
    def visible_atoms(self, visible_atoms: list[int]):
        # Materialize once, so a generator is not consumed before it is stored,
        # and build a set for O(1) membership tests.
        visible_atoms = list(visible_atoms)
        visible_set = set(visible_atoms)
        # update visibility of atoms
        for i, site in enumerate(self.structure):
            actor = self.plotter.actors[f"{site.label}"]
            actor.visibility = i in visible_set
        # set visible atoms
        self._visible_atoms = visible_atoms

    @property
    def show_lattice(self) -> bool:
        """

        Returns
        -------
        bool
            Whether or not to display the outline of the unit cell.

        """
        return self._show_lattice

    @show_lattice.setter
    def show_lattice(self, show_lattice: bool):
        actor = self.plotter.actors["lattice"]
        actor.visibility = show_lattice
        self._show_lattice = show_lattice

    # @property
    # def wrap_atoms(self):
    #     return self._wrap_atoms

    # TODO: Make two sets of atoms with and without wraps?
    # @wrap_atoms.setter
    # def wrap_atoms(self, wrap_atoms: bool):
    #     actor = self.plotter.

    @property
    def lattice_thickness(self) -> float:
        """

        Returns
        -------
        float
            The thickness of the lines outlining the unit cell.

        """
        return self._lattice_thickness

    @lattice_thickness.setter
    def lattice_thickness(self, lattice_thickness: float):
        actor = self.plotter.actors["lattice"]
        actor.prop.line_width = lattice_thickness
        self._lattice_thickness = lattice_thickness

    @property
    def atom_metallicness(self) -> float:
        """

        Returns
        -------
        float
            The amount of metallic character in the atom display.

        """
        return self._atom_metallicness

    @atom_metallicness.setter
    def atom_metallicness(self, atom_metallicness: float):
        # update all atoms
        for site in self.structure:
            actor = self.plotter.actors[f"{site.label}"]
            actor.prop.metallic = atom_metallicness
        self._atom_metallicness = atom_metallicness

    @property
    def background(self) -> str:
        """

        Returns
        -------
        str
            The color of the plot background as a hex code.

        """
        return self._background

    @background.setter
    def background(self, background: str):
        self.plotter.set_background(background)
        self._background = background

    @property
    def show_axes(self) -> bool:
        """

        Returns
        -------
        bool
            Whether or not to show the axis widget. Note this currently only
            displays the cartesian axes.

        """
        return self._show_axes

    @show_axes.setter
    def show_axes(self, show_axes: bool):
        if show_axes:
            self.plotter.add_axes()
        else:
            self.plotter.hide_axes()
        self._show_axes = show_axes

    @property
    def parallel_projection(self) -> bool:
        """

        Returns
        -------
        bool
            If True, a parallel projection scheme will be used rather than
            perspective.

        """
        return self._parallel_projection

    @parallel_projection.setter
    def parallel_projection(self, parallel_projection: bool):
        if parallel_projection:
            self.plotter.renderer.enable_parallel_projection()
        else:
            self.plotter.renderer.disable_parallel_projection()
        self._parallel_projection = parallel_projection

    @property
    def radii(self) -> list[float]:
        """

        Returns
        -------
        list[float]
            The radius to display for each atom in the structure. The actual
            displayed radius will be 0.3*radius.

        """
        return self._radii

    @radii.setter
    def radii(self, radii: list[float]):
        # Convert to a list and clamp non-positive values to 0.01, so every
        # sphere still has a valid radius.
        radii = [0.01 if val <= 0 else val for val in radii]
        # keep the previous radii to detect which atoms actually changed
        old_radii = self._radii
        self._radii = radii
        visible_set = set(self.visible_atoms)
        # For each site whose radius changed, remove its actor and rebuild it.
        # Unchanged sites are left alone to avoid needless mesh generation.
        for i, (site, old_r, new_r, color) in enumerate(
            zip(self.structure, old_radii, radii, self.colors)
        ):
            if old_r == new_r:
                continue
            self.plotter.remove_actor(f"{site.label}")
            atom_mesh = self.get_site_mesh(i)
            actor = self.plotter.add_mesh(
                atom_mesh,
                color=color,
                metallic=self.atom_metallicness,
                pbr=True,  # enable physical based rendering
                name=f"{site.label}",
            )
            # BUG-FIX: a freshly added actor is visible, so restore the
            # visibility this atom had before its radius changed.
            actor.visibility = i in visible_set

    @property
    def colors(self) -> list[str]:
        """

        Returns
        -------
        list[str]
            The colors to use for each atom as hex codes.

        """
        return self._colors

    @colors.setter
    def colors(self, colors: list[str]):
        # Normalize to a plain list, so e.g. a pandas Series is not stored.
        colors = list(colors)
        # only touch actors whose color actually changed
        for site, old_color, new_color in zip(self.structure, self.colors, colors):
            if old_color == new_color:
                continue
            actor = self.plotter.actors[f"{site.label}"]
            actor.prop.color = new_color
        self._colors = colors

    @property
    def atom_df(self) -> pd.DataFrame:
        """

        Returns
        -------
        atom_df : pd.DataFrame
            A dataframe summarizing the properties of the atom meshes, with
            columns "Label", "Visible", "Color" and "Radius" (one row per site).

        """
        visible_set = set(self.visible_atoms)
        visible = [i in visible_set for i in range(len(self.structure))]
        atom_df = pd.DataFrame(
            {
                "Label": self.structure.labels,
                "Visible": visible,
                "Color": self.colors,
                "Radius": self.radii,
            }
        )
        return atom_df

    @atom_df.setter
    def atom_df(self, atom_df: pd.DataFrame):
        # Strict equality with True is intentional: it preserves the original
        # semantics for non-boolean cell values.
        visible_atoms = [
            i for i, val in enumerate(atom_df["Visible"]) if val == True  # noqa: E712
        ]
        # set each property from the dataframe
        self.visible_atoms = visible_atoms
        self.colors = atom_df["Color"]
        self.radii = atom_df["Radius"]

    @property
    def view_indices(self) -> NDArray[int]:
        """

        Returns
        -------
        NDArray[int]
            The miller indices of the plane that the camera is perpendicular to.

        """
        return self._view_indices

    @view_indices.setter
    def view_indices(self, view_indices: NDArray[int]):
        view_indices = list(view_indices)
        # Accept Python and numpy integers, so numpy arrays work as documented.
        # Booleans are rejected, matching the original ``type(i) == int`` check.
        assert len(view_indices) == 3 and all(
            isinstance(i, (int, np.integer)) and not isinstance(i, bool)
            for i in view_indices
        ), "View indices must be an array or list of miller indices"
        view_indices = [int(i) for i in view_indices]
        h, k, l = view_indices
        camera_position = self.get_camera_position_from_miller(
            h, k, l, self.camera_rotation
        )
        self.camera_position = camera_position
        # reset the camera zoom so that it fits the screen
        self.plotter.reset_camera()
        self._view_indices = view_indices

    @property
    def camera_rotation(self) -> float:
        """

        Returns
        -------
        float
            The rotation of the camera from the default. The default is to set
            the camera so that the upwards view is as close to the z axis as
            possible, or the y axis if the view indices are parallel to z.

        """
        return self._camera_rotation

    @camera_rotation.setter
    def camera_rotation(self, camera_rotation: float):
        h, k, l = self.view_indices
        camera_position = self.get_camera_position_from_miller(h, k, l, camera_rotation)
        self.camera_position = camera_position
        # reset the camera zoom so that it fits the screen
        self.plotter.reset_camera()
        self._camera_rotation = camera_rotation

    @property
    def camera_position(self) -> list[tuple, tuple, tuple]:
        """

        Returns
        -------
        list[tuple, tuple, tuple]
            The set of tuples defining the camera position. In order, this is
            the camera's position, the focal point, and the view up vector.

        """
        pos = self.plotter.camera_position
        # convert to list for serializability
        return [list(pos[0]), list(pos[1]), list(pos[2])]

    @camera_position.setter
    def camera_position(self, camera_position: NDArray):
        camera_position = np.array(camera_position).astype(float)
        if camera_position.ndim == 1:
            # a single vector is interpreted as miller indices
            h, k, l = camera_position
            camera_pos = self.get_camera_position_from_miller(h, k, l)
            self.plotter.camera_position = camera_pos
        else:
            # (position, focal point, view up) rows -> list of tuples
            self.plotter.camera_position = [
                tuple(camera_position[0]),
                tuple(camera_position[1]),
                tuple(camera_position[2]),
            ]

    @staticmethod
    def get_edge_atom_fracs(frac_coord: NDArray, tol: float = 1e-08) -> list[NDArray]:
        """
        Generates translationally equivalent atoms if coords are exactly on an edge
        of the lattice.

        Parameters
        ----------
        frac_coord : NDArray
            The fractional coordinates of a single atom to wrap.
        tol : float, optional
            The tolerance in fractional coords to consider an atom on an edge
            of the unit cell. The default is 1e-08.

        Returns
        -------
        list[NDArray]
            The fractional coordinates of the atom and of every periodic image
            lying on the opposite face/edge/corner of the cell. The original
            position is always included.

        """
        # Coordinates at 0 also get an image at +1, coordinates at 1 also get
        # an image at -1, and anything else stays put.
        transforms = [
            [0, 1] if abs(x) < tol else [0, -1] if abs(x - 1) < tol else [0]
            for x in frac_coord
        ]

        shifts = set(product(*transforms))
        return [np.array(frac_coord) + np.array(shift) for shift in shifts]

    def get_camera_position_from_miller(
        self,
        h: int,
        k: int,
        l: int,
        rotation: float = 0,
    ) -> list[tuple, tuple, tuple]:
        """
        Creates a camera position list from a list of miller indices.

        Parameters
        ----------
        h : int
            First miller index.
        k : int
            Second miller index.
        l : int
            Third miller index.
        rotation: float
            The rotation in degrees of the camera. The default of 0 will arrange
            the camera as close to Z=1 as possible, or in the case that this
            is parallel to the view direction, it will default to close to Y=1.

        Returns
        -------
        list[tuple, tuple, tuple]
            The set of tuples defining the camera position. In order, this is
            the camera's position, the focal point, and the view up vector.

        """
        # check for all 0s and adjust
        if all(x == 0 for x in (h, k, l)):
            h, k, l = 1, 0, 0
        # convert to vector perpendicular to the miller plane
        view_direction = np.asarray(
            self.structure.get_cart_from_miller(h, k, l), dtype=float
        )
        # Normalize defensively: the Gram-Schmidt projection and the rotation
        # below are only correct for a unit view direction. This is a no-op if
        # the vector is already normalized.
        view_direction = view_direction / np.linalg.norm(view_direction)
        # Calculate a distance to the camera that doesn't clip any bodies. It's
        # fine if this is very large, because methods using this function should
        # reset the camera afterwards. The sum of all lattice side lengths plus
        # the largest atom radius should always be well outside the structure.
        camera_distance = sum(self.structure.lattice.lengths) + max(self.radii)

        # Set focal point as center of lattice
        matrix = self.structure.lattice.matrix
        focal_point = np.sum(matrix, axis=0) / 2
        # Set the camera's position by moving from the focal point along the
        # view direction by the desired distance.
        camera_position = focal_point + view_direction * camera_distance

        # Find an orthogonal vector with the maximum z component using
        # Gram-Schmidt orthogonalization.
        z_axis = np.array([0.0, 0.0, 1.0])
        view_up = z_axis - np.dot(z_axis, view_direction) * view_direction
        if np.linalg.norm(view_up) < 1e-10:
            # Fall back to the y-axis when the view direction is (numerically)
            # parallel to z and the projection vanishes.
            y_axis = np.array([0.0, 1.0, 0.0])
            view_up = y_axis - np.dot(y_axis, view_direction) * view_direction
        view_up = view_up / np.linalg.norm(view_up)

        # Now rotate the view-up vector about the view direction. We rotate
        # counter clockwise so the structure appears to rotate clockwise.
        angle_rad = np.deg2rad(rotation)
        view_up_rot = view_up * np.cos(angle_rad) + np.cross(
            view_direction, view_up
        ) * np.sin(angle_rad)
        return [
            tuple(camera_position),  # where the camera is
            tuple(focal_point),  # where it's looking
            tuple(view_up_rot),  # which direction is up
        ]

    def get_site_mesh(self, site_idx: int) -> pv.PolyData:
        """
        Generates a mesh for the provided site index.

        Parameters
        ----------
        site_idx : int
            The index of the atom to create the mesh for.

        Returns
        -------
        pv.PolyData
            A pyvista mesh representing an atom. Atoms on cell edges also
            include their periodic images when wrapping is enabled.

        """
        site = self.structure[site_idx]
        radius = self.radii[site_idx]
        frac_coords = site.frac_coords
        # wrap atom if on edge
        if self._wrap_atoms:
            all_frac_coords = self.get_edge_atom_fracs(frac_coords)
        else:
            all_frac_coords = [frac_coords]
        # convert to cart coords
        cart_coords = np.asarray(all_frac_coords) @ self.structure.lattice.matrix
        # generate one sphere per image and merge into a single mesh
        spheres = [
            pv.Sphere(
                radius=radius * 0.3,
                center=cart_coord,
                theta_resolution=30,
                phi_resolution=30,
            )
            for cart_coord in cart_coords
        ]
        return pv.merge(spheres)

    def get_all_site_meshes(self) -> list[pv.PolyData]:
        """
        Gets a list of pyvista meshes representing the atoms in the structure.

        Returns
        -------
        meshes : list[pv.PolyData]
            A list of pyvista meshes representing each atom, in site order.

        """
        return [self.get_site_mesh(i) for i in range(len(self.structure))]

    def get_lattice_mesh(self) -> pv.PolyData:
        """
        Generates the mesh representing the outline of the unit cell.

        Returns
        -------
        pv.PolyData
            A pyvista mesh representing the outline of the unit cell.

        """
        # get the lattice vectors
        a, b, c = self.structure.lattice.matrix
        # get the 8 corners of the cell
        corners = [np.array([0, 0, 0]), a, b, c, a + b, a + c, b + c, a + b + c]
        # pairs of corner indices forming the 12 cell edges
        edges = [
            (0, 1),
            (0, 2),
            (0, 3),
            (1, 4),
            (1, 5),
            (2, 4),
            (2, 6),
            (3, 5),
            (3, 6),
            (4, 7),
            (5, 7),
            (6, 7),
        ]
        lines = [pv.Line(corners[i], corners[j]) for i, j in edges]
        # combine and return
        return pv.merge(lines)

    def _create_structure_plot(self, off_screen: bool) -> pv.Plotter:
        """
        Generates a pyvista.Plotter object from the current class variables.
        This is called when the class is first instanced and by ``rebuild``.

        Parameters
        ----------
        off_screen : bool
            Whether or not the plotter should run in off_screen mode.

        Returns
        -------
        plotter : pv.Plotter
            A pyvista Plotter object representing the provided Structure object.

        """
        plotter = pv.Plotter(off_screen=off_screen)
        # set background
        plotter.set_background(self.background)
        # add atoms
        visible_set = set(self.visible_atoms)
        atom_meshes = self.get_all_site_meshes()
        for i, (site, atom_mesh, color) in enumerate(
            zip(self.structure, atom_meshes, self.colors)
        ):
            actor = plotter.add_mesh(
                atom_mesh,
                color=color,
                metallic=self.atom_metallicness,
                pbr=True,  # enable physical based rendering
                name=f"{site.label}",
            )
            if i not in visible_set:
                actor.visibility = False

        # add lattice
        lattice_mesh = self.get_lattice_mesh()
        lattice_actor = plotter.add_mesh(
            lattice_mesh,
            line_width=self.lattice_thickness,
            color="k",
            name="lattice",
        )
        # BUG-FIX: honor the current show_lattice state, so rebuilt plotters
        # (e.g. for screenshots) match the live one.
        lattice_actor.visibility = self.show_lattice

        # set camera perspective type
        if self.parallel_projection:
            plotter.renderer.enable_parallel_projection()

        return plotter

    def show(self):
        """
        Renders the plot to a window. After closing the window, a new instance
        must be created to plot again. Pressing q pauses the rendering allowing
        changes to be made without fully exiting.

        Returns
        -------
        None.

        """
        self.plotter.show(auto_close=False)

    def update(self):
        """
        Updates the pyvista plotter object when linked with a render window in
        Trame. Generally this is not needed outside of Trame applications.

        Returns
        -------
        None.

        """
        self.plotter.update()

    def rebuild(self) -> pv.Plotter:
        """
        Builds a new pyvista plotter object representing the current state of
        the Plotter class. ``self.plotter`` itself is not replaced.

        Returns
        -------
        pv.Plotter
            A pyvista Plotter object representing the current state of the
            StructurePlotter class.

        """
        return self._create_structure_plot(self._off_screen)

    def get_plot_html(self) -> str:
        """
        Creates an html string representing the current state of the StructurePlotter
        class.

        Returns
        -------
        str
            The html string representing the current StructurePlotter class.

        """
        if sys.platform == "win32":
            # We can return the html directly without opening a subprocess. And
            # we need to because the "fork" start method doesn't work
            html_plotter = self.plotter.export_html(filename=None)
            return html_plotter.read()
        # BUG-FIX: On Linux and maybe MacOS, pyvista's export_html must be run
        # as a main process. To do this within our streamlit apps, we use python's
        # multiprocess to run the process as is done in [stpyvista](https://github.com/edsaac/stpyvista/blob/main/src/stpyvista/trame_backend.py)
        queue = Queue(maxsize=1)
        process = Process(target=_export_html, args=(queue, self.plotter))
        process.start()
        try:
            # read from the queue before joining so the child can flush its output
            html_plotter = queue.get().read()
        finally:
            process.join()
        return html_plotter

    # def get_plot_vtksz(self):
    #     # Create temp file path manually
    #     with tempfile.NamedTemporaryFile(suffix=".vtkjs", delete=False) as f:
    #         temp_path = f.name  # Just get the path, don't use the open file
    #     # Now write to it
    #     self.plotter.export_vtksz(temp_path)
    #     # Read the contents
    #     with open(temp_path, "rb") as f:
    #         content = f.read()
    #     # Clean up
    #     os.remove(temp_path)

    #     return content

    def get_plot_screenshot(
        self,
        filename: str | Path | io.BytesIO = None,
        transparent_background: bool = None,
        return_img: bool = True,
        window_size: tuple[int, int] = None,
        scale: int = None,
    ) -> NDArray[float]:
        """
        Creates a screenshot of the current state of the StructurePlotter class.
        This is a wrapper around pyvista's screenshot method.

        Parameters
        ----------
        filename: str | Path | io.BytesIO
            Location to write image to. If None, no image is written.

        transparent_background: bool
            Whether to make the background transparent.
            The default is looked up on the plotter’s theme.

        return_img: bool
            If True, a numpy.ndarray of the image will be returned. Defaults to
            True.

        window_size: tuple[int, int]
            Set the plotter’s size to this (width, height) before taking the
            screenshot.

        scale: int
            Set the factor to scale the window size to make a higher resolution
            image. If None this will use the image_scale property on this
            plotter which defaults to one.

        Returns
        -------
        NDArray[float]
            Array containing pixel RGB and alpha. Sized:

            [Window height x Window width x 3] if transparent_background is set to False.

            [Window height x Window width x 4] if transparent_background is set to True.

        """
        # render from a fresh plotter that copies the live camera
        plotter = self.rebuild()
        try:
            plotter.camera = self.plotter.camera.copy()
            screenshot = plotter.screenshot(
                filename=filename,
                transparent_background=transparent_background,
                return_img=return_img,
                window_size=window_size,
                scale=scale,
            )
        finally:
            # always release the temporary render window
            plotter.close()
        return screenshot

atom_df property writable

Returns:

Name Type Description
atom_df pd.DataFrame

A dataframe summarizing the properties of the atom meshes.

atom_metallicness property writable

Returns:

Type Description
float

The amount of metallic character in the atom display.

background property writable

Returns:

Type Description
str

The color of the plot background as a hex code.

camera_position property writable

Returns:

Type Description
list[tuple, tuple, tuple]

The set of tuples defining the camera position. In order, this is the camera's position, the focal point, and the view up vector.

camera_rotation property writable

Returns:

Type Description
float

The rotation of the camera from the default. The default is to set the camera so that the upwards view is as close to the z axis as possible, or the y axis if the view indices are perpendicular to z.

colors property writable

Returns:

Type Description
list[str]

The colors to use for each atom as hex codes.

lattice_thickness property writable

Returns:

Type Description
float

The thickness of the lines outlining the unit cell.

parallel_projection property writable

Returns:

Type Description
bool

If True, a parallel projection scheme will be used rather than perspective.

radii property writable

Returns:

Type Description
list[float]

The radius to display for each atom in the structure. The actual displayed radius will be 0.3*radius.

show_axes property writable

Returns:

Type Description
bool

Whether or not to show the axis widget. Note this currently only displays the cartesian axes.

show_lattice property writable

Returns:

Type Description
bool

Whether or not to display the outline of the unit cell.

view_indices property writable

Returns:

Type Description
NDArray[int]

The miller indices of the plane that the camera is perpendicular to.

visible_atoms property writable

Returns:

Type Description
list[int]

A list of atom indices to display in the plot.

__init__(structure, off_screen=False)

A convenience class for creating plots of crystal structures using pyvista's package for VTK.

Parameters:

Name Type Description Default
structure Structure

The pymatgen Structure object to plot.

required
off_screen bool

Whether or not the plotter should be in offline mode. The default is False.

False

Returns:

Type Description
None.
Source code in src/baderkit/plotting/core/plotter.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def __init__(
    self,
    structure: Structure,
    off_screen: bool = False,
):
    """
    A convenience class for creating plots of crystal structures using
    pyvista's package for VTK.

    Parameters
    ----------
    structure : Structure
        The pymatgen Structure object to plot. A sorted, relabeled copy is
        stored; the caller's object is not modified.
    off_screen : bool, optional
        Whether or not the plotter should be in offline mode. The default is False.

    Returns
    -------
    None.

    """
    # sort and relabel structure for consistency. Site labels are used as
    # actor names by the plotter.
    structure = structure.copy()
    structure.sort()
    structure.relabel_sites()
    # create initial class variables
    self.structure = structure
    self._off_screen = off_screen
    self._visible_atoms = list(range(len(self.structure)))
    self._show_lattice = True
    self._wrap_atoms = True
    self._lattice_thickness = 0.1
    self._atom_metallicness = 0.0
    self._background = "#FFFFFF"
    self._view_indices = [1, 0, 0]
    # BUG-FIX: this was previously the 1-tuple ``(0.0,)``. The rotation is a
    # scalar angle in degrees.
    self._camera_rotation = 0.0
    self._show_axes = True
    self._parallel_projection = True
    self._radii = [s.specie.atomic_radius for s in structure]
    self._colors = [ATOM_COLORS.get(s.specie.symbol, "#FFFFFF") for s in structure]
    # generate initial plotter
    self.plotter = self._create_structure_plot(off_screen)
    self.view_indices = [1, 0, 0]
    self.up_indices = [0, 0, 1]

get_all_site_meshes()

Gets a list of pyvista meshes representing the atoms in the structure

Returns:

Name Type Description
meshes list[PolyData]

A list of pyvista meshes representing each atom.

Source code in src/baderkit/plotting/core/plotter.py
557
558
559
560
561
562
563
564
565
566
567
568
def get_all_site_meshes(self) -> list[pv.PolyData]:
    """
    Gets a list of pyvista meshes representing the atoms in the structure.

    Returns
    -------
    meshes : list[pv.PolyData]
        One mesh per site, in the same order as ``self.structure``, as
        produced by ``get_site_mesh``.

    """
    return [
        self.get_site_mesh(site_idx) for site_idx in range(len(self.structure))
    ]

get_camera_position_from_miller(h, k, l, rotation=0)

Creates a camera position list from a list of miller indices.

Parameters:

Name Type Description Default
h int

First miller index.

required
k int

Second miller index.

required
l int

Third miller index.

required
rotation float

The rotation in degrees of the camera. The default of 0 will arrange the camera's up vector as close to Z=1 as possible, or, if the view direction is parallel to Z, as close to Y=1 as possible.

0

Returns:

Type Description
list[tuple, tuple, tuple]

The set of tuples defining the camera position. In order, this is the camera's position, the focal point, and the view up vector.

Source code in src/baderkit/plotting/core/plotter.py
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
def get_camera_position_from_miller(
    self,
    h: int,
    k: int,
    l: int,
    rotation: float = 0,
) -> list[tuple, tuple, tuple]:
    """
    Creates a camera position list from a list of miller indices.

    Parameters
    ----------
    h : int
        First miller index.
    k : int
        Second miller index.
    l : int
        Third miller index. If all three indices are 0, (1, 0, 0) is used.
    rotation: float
        The rotation in degrees of the camera. The default of 0 will arrange
        the camera's up vector as close to Z=1 as possible, or, if the view
        direction is parallel to Z, as close to Y=1 as possible.

    Returns
    -------
    list[tuple, tuple, tuple]
        The set of tuples defining the camera position. In order, this is
        the camera's position, the focal point, and the (unit length) view
        up vector.

    """
    # (0, 0, 0) does not define a plane; fall back to the (1, 0, 0) view
    if h == 0 and k == 0 and l == 0:
        h, k, l = 1, 0, 0
    # Vector perpendicular to the miller plane. The Gram-Schmidt step below
    # is only correct for a unit vector, so normalize it explicitly.
    view_direction = np.asarray(
        self.structure.get_cart_from_miller(h, k, l), dtype=float
    )
    view_direction = view_direction / np.linalg.norm(view_direction)

    # Calculate a distance to the camera that doesn't clip any bodies. It's
    # fine if this is very large as methods using this function should reset
    # the camera after. The sum of all lattice lengths plus the largest atom
    # radius is always well outside the cell.
    camera_distance = sum(self.structure.lattice.lengths) + max(
        self.radii, default=0.0
    )

    # Set focal point as center of lattice
    matrix = self.structure.lattice.matrix
    focal_point = np.sum(matrix, axis=0) / 2
    # Place the camera along the (unit) view direction at the chosen distance
    camera_position = focal_point + view_direction * camera_distance

    # Gram-Schmidt: remove the view-direction component from the z axis to
    # get the orthogonal vector with the largest z value.
    z_axis = np.array([0.0, 0.0, 1.0])
    view_up = z_axis - np.dot(z_axis, view_direction) * view_direction
    if np.linalg.norm(view_up) < 1e-8:
        # The view direction is parallel to z, so the projection vanishes;
        # use the y axis instead.
        y_axis = np.array([0.0, 1.0, 0.0])
        view_up = y_axis - np.dot(y_axis, view_direction) * view_direction
    view_up = view_up / np.linalg.norm(view_up)

    # Rotate the up vector about the view direction. view_up is orthogonal
    # to view_direction, so cos/sin of the two orthogonal unit vectors is a
    # proper rotation. We intentionally rotate counter clockwise to make the
    # structure appear to rotate clockwise.
    angle_rad = np.deg2rad(rotation)
    view_up_rot = view_up * np.cos(angle_rad) + np.cross(
        view_direction, view_up
    ) * np.sin(angle_rad)
    return [
        tuple(camera_position),  # where the camera is
        tuple(focal_point),  # where it's looking
        tuple(view_up_rot),  # which direction is up
    ]

get_edge_atom_fracs(frac_coord, tol=1e-08) staticmethod

Generates translationally equivalent atoms if coords are exactly on an edge of the lattice

Parameters:

Name Type Description Default
frac_coord NDArray

The fractional coordinates of a single atom to wrap.

required
tol float

The tolerance in fractional coords to consider an atom on an edge of the unit cell. The default is 1e-08.

1e-08

Returns:

Type Description
list[NDArray]

The fractional coordinates of the atom, plus a translated copy on the opposite side of the cell for each edge it lies on.

Source code in src/baderkit/plotting/core/plotter.py
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
@staticmethod
def get_edge_atom_fracs(
    frac_coord: NDArray, tol: float = 1e-08
) -> list[NDArray]:
    """
    Generates translationally equivalent atoms if coords are exactly on an
    edge of the lattice.

    Parameters
    ----------
    frac_coord : NDArray
        The fractional coordinates of a single atom to wrap.
    tol : float, optional
        The tolerance in fractional coords to consider an atom on an edge
        of the unit cell. The default is 1e-08.

    Returns
    -------
    list[NDArray]
        The fractional coordinates of the atom plus every translated copy on
        the opposite cell faces/edges/corners. The first entry is always the
        unshifted input coordinate.

    """
    coord = np.asarray(frac_coord)
    # Per axis: a coordinate at 0 also gets a copy at +1, a coordinate at 1
    # also gets a copy at -1, and anything else is not shifted.
    transforms = [
        (0, 1) if abs(x) < tol else (0, -1) if abs(x - 1) < tol else (0,)
        for x in coord
    ]
    # product() already yields unique shifts, beginning with (0, 0, 0), so
    # the output order is deterministic with the original coordinate first.
    return [coord + np.array(shift) for shift in product(*transforms)]

get_lattice_mesh()

Generates the mesh representing the outline of the unit cell.

Returns:

Type Description
PolyData

A pyvista mesh representing the outline of the unit cell.

Source code in src/baderkit/plotting/core/plotter.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def get_lattice_mesh(self) -> pv.PolyData:
    """
    Generates the mesh representing the outline of the unit cell.

    The outline is the 12 edges of the parallelepiped spanned by the three
    lattice vectors, each built as a pyvista line and merged into one mesh.

    Returns
    -------
    pv.PolyData
        A pyvista mesh representing the outline of the unit cell.

    """
    vec_a, vec_b, vec_c = self.structure.lattice.matrix
    # cell corners, indexed 0-7 in the order the edge table below expects
    vertices = (
        np.array([0, 0, 0]),
        vec_a,
        vec_b,
        vec_c,
        vec_a + vec_b,
        vec_a + vec_c,
        vec_b + vec_c,
        vec_a + vec_b + vec_c,
    )
    # pairs of corner indices joined by a cell edge
    edge_pairs = (
        (0, 1), (0, 2), (0, 3),
        (1, 4), (1, 5),
        (2, 4), (2, 6),
        (3, 5), (3, 6),
        (4, 7), (5, 7), (6, 7),
    )
    segments = [
        pv.Line(vertices[start], vertices[end]) for start, end in edge_pairs
    ]
    return pv.merge(segments)

get_plot_html()

Creates an html string representing the current state of the StructurePlotter class.

Returns:

Type Description
str

The html string representing the current StructurePlotter class.

Source code in src/baderkit/plotting/core/plotter.py
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
def get_plot_html(self) -> str:
    """
    Creates an html string representing the current state of the StructurePlotter
    class.

    Returns
    -------
    str
        The html string representing the current StructurePlotter class.

    Raises
    ------
    RuntimeError
        On non-Windows platforms, if the export subprocess exits without
        returning any html.

    """
    if sys.platform == "win32":
        # We can return the html directly without opening a subprocess. And
        # we need to because the "fork" start method doesn't work
        html_plotter = self.plotter.export_html(filename=None)
        return html_plotter.read()
    # BUG-FIX: On Linux and maybe MacOS, pyvista's export_html must be run
    # as a main process. To do this within our streamlit apps, we use python's
    # multiprocess to run the process as is done in [stpyvista](https://github.com/edsaac/stpyvista/blob/main/src/stpyvista/trame_backend.py)
    # (assumes Queue is multiprocessing.Queue, whose timed get raises
    # queue.Empty)
    from queue import Empty

    result_queue = Queue(maxsize=1)
    process = Process(target=_export_html, args=(result_queue, self.plotter))
    process.start()
    try:
        # Poll instead of a bare blocking get(): if the child dies before
        # putting a result, a bare get() would hang forever.
        while True:
            # Sample liveness *before* waiting. A child that has already
            # exited has flushed anything it queued, so an empty wait after
            # that means no result is coming.
            child_alive = process.is_alive()
            try:
                html_buffer = result_queue.get(timeout=1.0)
                break
            except Empty:
                if not child_alive:
                    raise RuntimeError(
                        "HTML export subprocess exited with code "
                        f"{process.exitcode} without returning a result."
                    ) from None
    finally:
        # Join only after the queue is drained (or the child is dead) to
        # avoid the queue/join deadlock described in the multiprocessing docs.
        process.join()
    return html_buffer.read()

get_plot_screenshot(filename=None, transparent_background=None, return_img=True, window_size=None, scale=None)

Creates a screenshot of the current state of the StructurePlotter class. This is a wrapper around pyvista's screenshot method.

Parameters:

Name Type Description Default
filename str | Path | BytesIO

Location to write image to. If None, no image is written.

None
transparent_background bool

Whether to make the background transparent. The default is looked up on the plotter’s theme.

None
return_img bool

If True, a numpy.ndarray of the image will be returned. Defaults to True.

True
window_size tuple[int, int]

Set the plotter’s size to this (width, height) before taking the screenshot.

None
scale int

Set the factor to scale the window size to make a higher resolution image. If None this will use the image_scale property on this plotter which defaults to one.

None

Returns:

Type Description
NDArray[float]

Array containing pixel RGB and alpha. Sized:

[Window height x Window width x 3] if transparent_background is set to False.

[Window height x Window width x 4] if transparent_background is set to True.

Source code in src/baderkit/plotting/core/plotter.py
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
def get_plot_screenshot(
    self,
    filename: str | Path | io.BytesIO = None,
    transparent_background: bool = None,
    return_img: bool = True,
    window_size: tuple[int, int] = None,
    scale: int = None,
) -> NDArray[float]:
    """
    Creates a screenshot of the current state of the StructurePlotter class.
    This is a wraparound of pyvista's screenshot method

    Parameters
    ----------
    filename: str | Path | io.BytesIO
        Location to write image to. If None, no image is written.

    transparent_background: bool
        Whether to make the background transparent.
        The default is looked up on the plotter’s theme.

    return_img: bool
        If True, a numpy.ndarray of the image will be returned. Defaults to
        True.

    window_size: tuple[int, int]
        Set the plotter’s size to this (width, height) before taking the
        screenshot.

    scale: int
        Set the factor to scale the window size to make a higher resolution image. If None this will use the image_scale property on this plotter which defaults to one.

    Returns
    -------
    NDArray[float]
        Array containing pixel RGB and alpha. Sized:

        [Window height x Window width x 3] if transparent_background is set to False.

        [Window height x Window width x 4] if transparent_background is set to True.

    """
    plotter = self.rebuild()
    plotter.camera = self.plotter.camera.copy()
    screenshot = plotter.screenshot(
        filename=filename,
        transparent_background=transparent_background,
        return_img=return_img,
        window_size=window_size,
        scale=scale,
    )
    plotter.close()
    return screenshot

get_site_mesh(site_idx)

Generates a mesh for the provided site index.

Parameters:

Name Type Description Default
site_idx int

The index of the atom to create the mesh for.

required

Returns:

Type Description
PolyData

A pyvista mesh representing an atom.

Source code in src/baderkit/plotting/core/plotter.py
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
def get_site_mesh(self, site_idx: int) -> pv.PolyData:
    """
    Generates a mesh for the provided site index.

    When atom wrapping is enabled, a site lying on a cell edge produces one
    sphere per translationally equivalent position, merged into one mesh.

    Parameters
    ----------
    site_idx : int
        The index of the atom to create the mesh for.

    Returns
    -------
    pv.PolyData
        A pyvista mesh representing an atom.

    """
    frac = self.structure[site_idx].frac_coords
    # spheres are drawn at 0.3x the stored radius for this site
    sphere_radius = self.radii[site_idx] * 0.3
    positions = self.get_edge_atom_fracs(frac) if self._wrap_atoms else [frac]
    # fractional -> cartesian coordinates
    centers = positions @ self.structure.lattice.matrix
    return pv.merge(
        [
            pv.Sphere(
                radius=sphere_radius,
                center=center,
                theta_resolution=30,
                phi_resolution=30,
            )
            for center in centers
        ]
    )

rebuild()

Builds a new pyvista plotter object representing the current state of the Plotter class.

Returns:

Type Description
Plotter

A pyvista Plotter object representing the current state of the StructurePlotter class.

Source code in src/baderkit/plotting/core/plotter.py
685
686
687
688
689
690
691
692
693
694
695
696
697
def rebuild(self) -> pv.Plotter:
    """
    Builds a new pyvista plotter object representing the current state of
    the Plotter class, using the off-screen setting chosen at construction.

    Returns
    -------
    pv.Plotter
        A pyvista Plotter object representing the current state of the
        StructurePlotter class.

    """
    off_screen = self._off_screen
    return self._create_structure_plot(off_screen)

show()

Renders the plot to a window. After closing the window, a new instance must be created to plot again. Pressing q pauses the rendering allowing changes to be made without fully exiting.

Returns:

Type Description
None.
Source code in src/baderkit/plotting/core/plotter.py
660
661
662
663
664
665
666
667
668
669
670
671
def show(self):
    """
    Renders the plot to a window. After closing the window, a new instance
    must be created to plot again. Pressing q pauses the rendering allowing
    changes to be made without fully exiting.

    Returns
    -------
    None.

    """
    # auto_close=False asks pyvista not to close the plotter automatically
    # when the window is closed
    active_plotter = self.plotter
    active_plotter.show(auto_close=False)

update()

Updates the pyvista plotter object when linked with a render window in Trame. Generally this is not needed outside of Trame applications.

Returns:

Type Description
None.
Source code in src/baderkit/plotting/core/plotter.py
673
674
675
676
677
678
679
680
681
682
683
def update(self):
    """
    Updates the pyvista plotter object when linked with a render window in
    Trame. Generally this is not needed outside of Trame applications.

    Returns
    -------
    None.

    """
    # delegate to the underlying pyvista plotter
    active_plotter = self.plotter
    active_plotter.update()