Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Self, Any

from vtkmodules.vtkCommonDataModel import vtkMultiBlockDataSet, vtkDataSet
from vtkmodules.vtkCommonDataModel import VTK_TRIANGLE,VTK_QUAD,VTK_LINE, VTK_POLY_LINE

from geos.mesh.utils.arrayModifiers import createAttribute
from geos.mesh.utils.arrayHelpers import ( getAttributeSet, getNumberOfComponents, getArrayInObject )
Expand Down Expand Up @@ -53,7 +54,7 @@

# Set the attributes to compare:
dictAttributesToCompare: dict[ Piece, set[ str ] ]
attributesDiffFilter.setDicAttributesToCompare( dicAttributesToCompare )
attributesDiffFilter.setDictAttributesToCompare( dictAttributesToCompare )

# Set the inf norm computation (if wanted):
computeInfNorm: bool
Expand All @@ -75,6 +76,8 @@ class AttributesDiff:
def __init__(
self: Self,
speHandler: bool = False,
computePoints: bool = True,
computeCells: bool = True,
) -> None:
"""Compute differences (L1 and inf norm) between two identical meshes attributes.

Expand All @@ -96,6 +99,9 @@ def __init__(

self.outputMesh: vtkMultiBlockDataSet | vtkDataSet = vtkMultiBlockDataSet()

self.computeCells: bool = computeCells
self.computePoints: bool = computePoints

# Logger.
self.logger: Logger
if not speHandler:
Expand Down Expand Up @@ -138,12 +144,14 @@ def setMeshes(
raise ValueError( "The list of meshes must contain two meshes." )

if listMeshes[ 0 ].GetClassName() != listMeshes[ 1 ].GetClassName():
raise TypeError( "The meshes must have the same type." )
raise TypeError( f"The meshes must have the same type. {listMeshes[0].GetClassName()} and {listMeshes[1].GetClassName()}" )

dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = {}
if self.computeCells:
dictMeshesMaxElementId.update({ Piece.CELLS: [ 0, 0 ]})
if self.computePoints:
dictMeshesMaxElementId.update({ Piece.POINTS: [ 0, 0 ]})

dictMeshesMaxElementId: dict[ Piece, list[ int ] ] = {
Piece.CELLS: [ 0, 0 ],
Piece.POINTS: [ 0, 0 ],
}
if isinstance( listMeshes[ 0 ], vtkDataSet ):
for meshId, mesh in enumerate( listMeshes ):
for piece in dictMeshesMaxElementId:
Expand Down Expand Up @@ -171,8 +179,10 @@ def setMeshes(
raise ValueError( f"The total number of { piece.value } in the meshes must be the same." )

self.listMeshes = listMeshes
self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1
self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1
if self.computeCells:
self.dictNbElements[ Piece.CELLS ] = dictMeshesMaxElementId[ Piece.CELLS ][ 0 ] + 1
if self.computePoints:
self.dictNbElements[ Piece.POINTS ] = dictMeshesMaxElementId[ Piece.POINTS ][ 0 ] + 1
self.outputMesh = listMeshes[ 0 ].NewInstance()
self.outputMesh.ShallowCopy( listMeshes[ 0 ] )
self._computeDictSharedAttributes()
Expand Down Expand Up @@ -219,6 +229,9 @@ def setDictAttributesToCompare( self: Self, dictAttributesToCompare: dict[ Piece
Raises:
ValueError: At least one attribute to compare is not a shared attribute.
"""
assert not ((Piece.CELLS in dictAttributesToCompare) ^ (self.computeCells))
assert not ((Piece.POINTS in dictAttributesToCompare) ^ (self.computePoints))

for piece, setSharedAttributesToCompare in dictAttributesToCompare.items():
if not setSharedAttributesToCompare.issubset( self.dictSharedAttributes[ piece ] ):
wrongAttributes: set[ str ] = setSharedAttributesToCompare.difference(
Expand Down Expand Up @@ -307,6 +320,8 @@ def _computeDictAttributesArray( self: Self ) -> None:
listMeshBlockId: list[ int ] = getBlockElementIndexesFlatten( mesh )
for meshBlockId in listMeshBlockId:
dataset: vtkDataSet = vtkDataSet.SafeDownCast( mesh.GetDataSet( meshBlockId ) )
if dataset.GetCell(0).GetCellType() in [VTK_TRIANGLE, VTK_QUAD, VTK_LINE, VTK_POLY_LINE]:
continue
arrayAttributeData = getArrayInObject( dataset, attributeName, piece )
nbAttributeComponents = getNumberOfComponents( dataset, attributeName, piece )
lToG: npt.NDArray[ Any ] = getArrayInObject( dataset, "localToGlobalMap", piece )
Expand Down
Loading