Created
October 8, 2022 14:50
-
-
Save dendenxu/ca8a1816919a0fee915d3e72187894c1 to your computer and use it in GitHub Desktop.
Differentiable PyTorch implementation of "A Halfedge Refinement Rule for Parallel Loop Subdivision" (http://kenneth.vanhoey.free.fr/index.php?page=research&lang=en#VD22). The original paper uses OpenGL and OpenMP; we provide a PyTorch implementation of half-edge/triangle conversion and of the parallel Loop subdivision algorithm (with considerations for e…
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # NOTE: THIS WON'T RUN, PLZ IMPLEMENT YOUR OWN load_mesh AND export_mesh, THEN CHANGE IMPORTS ACCORDINGLY | |
| import time | |
| import torch | |
| from tqdm import tqdm | |
| # fmt: off | |
| import sys | |
| sys.path.append('.') | |
| from lib.utils.data_utils import load_mesh, export_mesh | |
| from lib.utils.mesh_utils import triangle_to_halfedge, halfedge_to_triangle, multiple_halfedge_loop_subdivision | |
| # fmt: on | |
depth = 2
repeat = 1000
input_file = 'big-sigcat.ply'
output_file = 'big-sigcat_2.ply'


def _sync():
    """Wait for all queued CUDA work to finish; no-op when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _benchmark(title, fn, sync=True):
    """Call ``fn`` ``repeat`` times and print the fastest/slowest/average wall time.

    With ``sync=True`` the GPU is synchronized inside the timed region, so the
    numbers include kernel execution. With ``sync=False`` the GPU is synchronized
    only after the clock stops, so the numbers reflect host-side (launch) time.

    Returns the value produced by the last call of ``fn``.
    """
    rounds = []
    result = None
    for _ in tqdm(range(repeat)):
        start = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
        result = fn()
        if sync:
            _sync()
        end = time.perf_counter()
        rounds.append(end - start)
        if not sync:
            _sync()  # drain the queue outside the timed region
    fastest_label = 'fastest time' if sync else 'fastest cpu time'
    print(title)
    print(f'{fastest_label}: {min(rounds) * 1000:.3f}ms')
    print(f'slowest time: {max(rounds) * 1000:.3f}ms')
    print(f'average time: {sum(rounds) / len(rounds) * 1000:.3f}ms')
    return result


v, f = load_mesh(input_file)
he = triangle_to_halfedge(v, f)
print(f'vert count: {he.V}')
print(f'face count: {he.F}')
print(f'edge count: {he.E}')
print(f'halfedge count: {he.HE}')

he = _benchmark('Conversion from triangle to halfedge representation: ',
                lambda: triangle_to_halfedge(v, f, False))
he = _benchmark('Conversion from triangle to halfedge with manifold assumption: ',
                lambda: triangle_to_halfedge(v, f, True))
he = _benchmark('Conversion from triangle to halfedge with manifold assumption (no gpu sync): ',
                lambda: triangle_to_halfedge(v, f, True), sync=False)

# the subdivision lambdas read the global `he` at call time, i.e. the last conversion above
nhe = _benchmark(f'Torch Loop subdivision (depth: {depth}): ',
                 lambda: multiple_halfedge_loop_subdivision(he, depth, False))
nhe = _benchmark(f'Torch Loop subdivision with manifold assumption (depth: {depth}): ',
                 lambda: multiple_halfedge_loop_subdivision(he, depth, True))
nhe = _benchmark(f'Torch Loop subdivision with manifold assumption (depth: {depth}) (no gpu sync): ',
                 lambda: multiple_halfedge_loop_subdivision(he, depth, True), sync=False)

# ~1.566ms for 2 steps on 3090
# while a loop subdiv on MeshLab takes ~232ms
# ~60ms for 5 steps
# while MeshLab takes 17991ms
v, f = halfedge_to_triangle(nhe)
export_mesh(v, f, filename=output_file)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # NOTE: THIS WON'T RUN, PLZ IMPLEMENT YOUR OWN load_mesh AND export_mesh, THEN CHANGE IMPORTS ACCORDINGLY | |
| import torch | |
| from tqdm import tqdm | |
| from largesteps.optimize import AdamUniform | |
| from largesteps.geometry import compute_matrix | |
| from largesteps.parameterize import from_differential, to_differential | |
| # fmt: off | |
| import sys | |
| sys.path.append('.') | |
| from lib.utils.data_utils import load_mesh, export_mesh | |
| from lib.utils.mesh_utils import triangle_to_halfedge, halfedge_to_triangle, multiple_halfedge_loop_subdivision | |
| # fmt: on | |
def forward(p: torch.Tensor, f: torch.Tensor, M: torch.sparse.FloatTensor, depth: int):
    """Decode differential coordinates, Loop-subdivide ``depth`` times, return (verts, faces).

    This demonstrates that the halfedge Loop subdivision is differentiable with
    respect to the vertex positions, so it plugs directly into the largesteps
    mesh optimization pipeline.
    """
    coarse_verts = from_differential(M, p, 'Cholesky')
    coarse_he = triangle_to_halfedge(coarse_verts, f, True)
    refined_he = multiple_halfedge_loop_subdivision(coarse_he, depth, True)
    fine_verts, fine_faces = halfedge_to_triangle(refined_he)
    return fine_verts, fine_faces
def main():
    """Optimize a mesh so that its Loop-subdivided surface approaches the unit sphere.

    The coarse vertices are parameterized with largesteps differential
    coordinates. The loss is the summed squared distance of the subdivided
    vertices from radius 1. The subdivided result is exported at the end.
    """
    lr = 3e-2
    depth = 2
    ep_iter = 10
    opt_iter = 50
    lambda_smooth = 29
    input_file = 'big-sigcat.ply'
    output_file = 'big-sigcat-to-sphere.ply'

    v, f = load_mesh(input_file)
    he = triangle_to_halfedge(v, f, True)  # only used to report mesh statistics
    print(f'vert count: {he.V}')
    print(f'face count: {he.F}')
    print(f'edge count: {he.E}')
    print(f'halfedge count: {he.HE}')

    # assume no batch dim
    M = compute_matrix(v, f, lambda_smooth)
    p = to_differential(M, v)
    p.requires_grad_()
    optim = AdamUniform([p], lr=lr)

    print()
    # the context manager guarantees the progress bar is closed (and its line finalized)
    with tqdm(total=opt_iter) as pbar:
        for i in range(opt_iter):
            v, _ = forward(p, f, M, depth)
            loss = ((v.norm(dim=-1) - 1) ** 2).sum()
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            pbar.update(1)
            if i % ep_iter == 0:
                pbar.write(f'L2 loss: {loss.item():.5g}')

    # final evaluation only: no autograd graph is needed
    with torch.no_grad():
        v, f = forward(p.detach(), f, M, depth)
    export_mesh(v, f, filename=output_file)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Mapping, Tuple, TypeVar, Union

import torch
# Generic type variables so that DotDict can be subscripted like a Mapping,
# e.g. ``DotDict[str, torch.Tensor]`` in function annotations.
KT = TypeVar("KT")  # key type
VT = TypeVar("VT")  # value type


def make_dotdict(*args, **kwargs):
    """Construct a :class:`DotDict` (thin factory taking the same arguments)."""
    return DotDict(*args, **kwargs)


class DotDict(dict, Mapping[KT, VT]):
    """
    A dictionary that supports dot notation as well as dictionary access notation.

    usage: d = make_dotdict() or d = make_dotdict({'val1': 'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']

    Values that look like mappings (have ``keys``) passed to the constructor are
    converted to DotDicts. Looking up a missing key raises AttributeError for
    both ``d.key`` and ``d['key']``.
    """

    def update(self, dct=None, **kwargs):
        """Update in place, coercing each value to the type of the existing entry.

        ``kwargs`` override entries of ``dct``; neither argument is modified.
        Strings assigned over a bool are parsed as ``value == 'True'`` (because
        ``bool('False')`` would be True). Keys whose current value is None accept
        any new value without coercion (``NoneType(v)`` cannot be called).
        """
        merged = {} if dct is None else dict(dct)  # copy: never mutate the caller's dict
        merged.update(kwargs)
        for key, value in merged.items():
            if key not in self:
                continue
            current = dict.__getitem__(self, key)
            target_type = type(current)
            if current is None or isinstance(value, target_type):
                continue
            if target_type is bool and isinstance(value, str):
                merged[key] = value == 'True'
            else:
                merged[key] = target_type(value)
        dict.update(self, merged)

    def __hash__(self):
        # The contents are mutable, so hash by identity, which is stable for the
        # object's lifetime. NOTE: equality is still dict equality, so equal but
        # distinct DotDicts hash differently.
        return hash(id(self))

    def __init__(self, dct=None, **kwargs):
        merged = {} if dct is None else dict(dct)  # copy: never mutate the caller's dict
        merged.update(kwargs)
        for key, value in merged.items():
            if hasattr(value, 'keys'):
                value = make_dotdict(value)  # nested mappings also get dot access
            self[key] = value

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError as e:
            # AttributeError (not KeyError) because this also serves as __getattr__:
            # hasattr() and pickle's lookups of optional hooks such as __getstate__
            # only tolerate AttributeError for missing attributes.
            raise AttributeError(e) from e

    __getattr__ = __getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def halfedge_loop_subdivision(halfedge: DotDict[str, torch.Tensor], is_manifold=False):
    """Perform one step of parallel Loop subdivision on a triangle halfedge mesh.

    Adapted from https://github.com/kvanhoey/ParallelHalfedgeSubdivision
    (overview: https://www.youtube.com/watch?v=mxk2HHk1NK4). Each halfedge ``h``
    spawns ``3h``, ``3h+1`` and ``3h+2`` (the corner triangle at ``vert[h]``) plus
    ``3*HE + h`` (the central triangle). Every update is therefore an
    independent per-halfedge tensor operation.

    Only native torch ops are used (``bincount`` / ``index_add``), so the new
    vertex positions stay differentiable w.r.t. ``halfedge.verts``.

    Args:
        halfedge: dotdict with ``verts`` (V, 3) positions, per-halfedge
            ``twin``/``vert``/``edge`` index tensors and the counts
            ``HE``/``E``/``F``/``V``. The halfedges of face ``f`` are expected
            at indices ``3f``, ``3f+1``, ``3f+2``.
        is_manifold: if True, every halfedge is assumed to have a twin and the
            boundary / non-manifold handling is skipped. Otherwise a negative
            ``twin`` value ``-k`` marks an edge shared by ``k != 2`` halfedges.

    Returns:
        A new dotdict with the same fields describing the subdivided mesh
        (``4*HE`` halfedges, ``4*F`` faces, ``2*E + 3*F`` edges, ``V + E`` verts).
    """
    # loading from dotdict
    verts: torch.Tensor = halfedge.verts  # V, 3
    twin: torch.Tensor = halfedge.twin  # HE,
    vert: torch.Tensor = halfedge.vert  # HE,
    edge: torch.Tensor = halfedge.edge  # HE,
    HE = halfedge.HE
    E = halfedge.E
    F = halfedge.F
    V = halfedge.V

    # sizes after one subdivision step
    NHE = 4 * HE
    NF = 4 * F
    NE = 2 * E + 3 * F
    NV = V + E

    # assign empty memory (every entry is written by the rules below)
    ntwin = torch.empty(NHE, device=vert.device, dtype=vert.dtype)
    nvert = torch.empty(NHE, device=vert.device, dtype=vert.dtype)
    nedge = torch.empty(NHE, device=vert.device, dtype=vert.dtype)

    # prepare input for topology computation: halfedges of face f are 3f, 3f+1, 3f+2
    hedg = torch.arange(HE, device=vert.device)
    next_he = hedg - (hedg % 3) + (hedg + 1) % 3
    prev_he = hedg - (hedg % 3) + (hedg + 2) % 3
    next_twin = next_he[twin]  # meaningless where twin < 0; those entries are overwritten below
    twin_prev = twin[prev_he]
    edge_prev = edge[prev_he]

    # assign next topology: h -> corner halfedges 3h, 3h+1, 3h+2 and central halfedge 3HE+h
    i0 = 3 * hedg + 0
    i1 = 3 * hedg + 1
    i2 = 3 * hedg + 2
    i3 = 3 * HE + hedg
    ntwin[i0] = 3 * next_twin + 2
    ntwin[i1] = 3 * HE + hedg
    ntwin[i2] = 3 * twin_prev
    ntwin[i3] = 3 * hedg + 1
    nedge[i0] = 2 * edge + (hedg < twin).long()
    nedge[i1] = 2 * E + hedg
    nedge[i2] = 2 * edge_prev + (prev_he > twin_prev).long()
    nedge[i3] = nedge[i1]
    nvert[i0] = vert
    nvert[i1] = V + edge
    nvert[i2] = V + edge_prev
    nvert[i3] = nvert[i2]

    if not is_manifold:
        # deal with non-manifold cases (twin < 0 marks a halfedge without a proper twin)
        manifold_mask = twin >= 0
        non_manifold_mask = ~manifold_mask
        non_manifold_mask_prev = non_manifold_mask[prev_he]
        non_manifold = non_manifold_mask.nonzero(as_tuple=True)[0]  # halfedges without a proper twin
        non_manifold_prev = non_manifold_mask_prev.nonzero(as_tuple=True)[0]  # halfedges whose prev has no proper twin
        ntwin[i0[non_manifold]] = twin[non_manifold]
        ntwin[i2[non_manifold_prev]] = twin_prev[non_manifold_prev]  # keep the previous halfedge's negative marker
        nedge[i0[non_manifold]] = 2 * edge[non_manifold]
        nedge[i2[non_manifold_prev]] = 2 * edge_prev[non_manifold_prev] + 1

    # pre-compute vertex valence (number of outgoing halfedges; 0 for unreferenced vertices) & Loop beta
    valence = torch.bincount(vert, minlength=V)
    beta = (1 / valence) * (5/8 - (3/8 + 1/4 * torch.cos(2 * torch.pi / valence))**2)

    # prepare geometric topological variables
    vert_next = vert[next_he]
    vert_prev = vert[prev_he]
    vert_edge = V + edge  # edge vertices are numbered after all original vertices: no overlap
    valence_vert = valence[vert]
    beta_vert = beta[vert]

    if not is_manifold:
        # prepare for computing non-manifold original vertices
        incident = -twin[non_manifold]  # number of halfedges sharing each such edge
        # a vertex counts as non-manifold if any of its outgoing halfedges is
        non_manifold_vert_flag = torch.zeros(V, device=vert.device, dtype=torch.bool)
        non_manifold_vert_flag[vert[non_manifold]] = True
        non_manifold_vert_mask = non_manifold_vert_flag[vert]  # per halfedge: origin vertex is non-manifold
        non_manifold_vert = non_manifold_vert_mask.nonzero(as_tuple=True)[0]
        # prev is manifold, but prev(twin(prev(h))) -- another halfedge ending at vert[h] -- is not
        manifold_prev = manifold_mask[prev_he].nonzero(as_tuple=True)[0]
        non_manifold_prev_twin_prev_mask = torch.zeros(HE, device=vert.device, dtype=manifold_mask.dtype)
        non_manifold_prev_twin_prev_mask[manifold_prev] = non_manifold_mask[prev_he[twin_prev[manifold_prev]]]
        non_manifold_prev_twin_prev = non_manifold_prev_twin_prev_mask.nonzero(as_tuple=True)[0]

    # actually distribute geometric values as per-halfedge contributions (summed per target vertex below)
    verts_vert = verts[vert]
    verts_vert_next = verts[vert_next]
    verts_vert_prev = verts[vert_prev]
    # edge vertex: each halfedge adds 3/8 of its origin and 1/8 of the opposite vertex
    nverts_vert_edge = (3 * verts_vert + verts_vert_prev) / 8
    # original vertex: summed over its outgoing halfedges gives (1 - n * beta) * v + beta * sum(neighbors)
    nverts_vert = (1/valence_vert - beta_vert)[..., None] * verts_vert + beta_vert[..., None] * verts_vert_next

    if not is_manifold:
        # boundary / non-manifold edges: midpoint, split evenly among the incident halfedges
        nverts_vert_edge[non_manifold] = (verts_vert[non_manifold] + verts_vert_next[non_manifold]) / (incident * 2)[..., None]
        # boundary vertices: rebuilt from 3/8 * v + 1/8 * neighbor along each boundary halfedge
        nverts_vert[non_manifold_vert] = 0
        nverts_vert[non_manifold] += 1/8 * verts_vert_next[non_manifold] + 3/8 * verts_vert[non_manifold]
        nverts_vert[non_manifold_prev_twin_prev] += 1/8 * verts_vert_prev[twin_prev[non_manifold_prev_twin_prev]] + 3/8 * verts_vert[twin_prev[non_manifold_prev_twin_prev]]

    # accumulate per target vertex; original (< V) and edge (>= V) targets never overlap, so just add
    nverts_vert_edge = nverts_vert_edge.new_zeros((NV,) + tuple(nverts_vert_edge.shape[1:])).index_add(0, vert_edge.long(), nverts_vert_edge)
    nverts_vert = nverts_vert.new_zeros((NV,) + tuple(nverts_vert.shape[1:])).index_add(0, vert.long(), nverts_vert)
    nverts = nverts_vert_edge + nverts_vert

    # prepare return dotdict structure
    nhalfedge = make_dotdict()
    # geometric info
    nhalfedge.verts = nverts  # NV, 3
    # topological info
    nhalfedge.twin = ntwin  # NHE,
    nhalfedge.vert = nvert  # NHE,
    nhalfedge.edge = nedge  # NHE,
    # size info (some of them could be omitted like HE)
    nhalfedge.HE = NHE
    nhalfedge.E = NE
    nhalfedge.F = NF
    nhalfedge.V = NV
    return nhalfedge
def multiple_halfedge_loop_subdivision(halfedge: DotDict[str, torch.Tensor], steps=2, is_manifold=False):
    """Apply ``halfedge_loop_subdivision`` ``steps`` times in sequence.

    Each pass consumes the previous pass's output. With ``steps <= 0`` the
    input object is returned as-is.
    """
    current = halfedge
    for _ in range(steps):
        current = halfedge_loop_subdivision(current, is_manifold)
    return current
def halfedge_to_triangle(halfedge: 'DotDict[str, torch.Tensor]') -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert the halfedge representation back to ``(verts, faces)``.

    Halfedges are stored face by face (``3f``, ``3f + 1``, ``3f + 2`` belong to
    face ``f``), so the corners of face ``f`` are simply the origin vertices of
    its three halfedges. No successor lookup is needed.

    Args:
        halfedge: dotdict providing ``verts`` and the per-halfedge ``vert``.

    Returns:
        ``(verts, faces)``: ``verts`` is ``halfedge.verts`` itself; ``faces`` is a
        new (F, 3) tensor of vertex indices.
    """
    verts = halfedge.verts
    vert = halfedge.vert  # HE,
    # stack (rather than view) so that faces does not alias halfedge.vert
    faces = torch.stack([vert[0::3], vert[1::3], vert[2::3]], dim=-1)
    return verts, faces
def triangle_to_halfedge(verts: Union[torch.Tensor, None],
                         faces: torch.Tensor,
                         is_manifold: bool = False,
                         ):
    """Build the halfedge representation of a triangle mesh.

    Halfedge ``3f + k`` of face ``f`` runs from ``faces[f, k]`` to
    ``faces[f, (k + 1) % 3]``.

    Args:
        verts: vertex positions, or None for a topology-only conversion (the
            vertex count is then inferred as ``faces.max() + 1``).
        faces: (F, 3) integer vertex indices, assumed to lie in ``[0, V)``.
        is_manifold: if True, assume every edge is shared by exactly two
            halfedges and pair them without checking. If False, halfedges on
            edges shared by ``k != 2`` halfedges get ``twin = -k`` instead of a
            twin index.

    Returns:
        dotdict with ``verts``, per-halfedge ``twin``/``vert``/``edge`` tensors
        and the counts ``HE``, ``E``, ``F``, ``V``.
    """
    F = len(faces)
    if verts is not None:
        V = len(verts)
    else:
        # V is a vertex *count* (subdivision numbers new edge vertices from V upward),
        # so it must be the largest index + 1, not the largest index itself
        V = faces.max().item() + 1 if faces.numel() else 0
    HE = 3 * F

    # create halfedges
    v0, v1, v2 = faces.chunk(3, dim=-1)
    e01 = torch.cat([v0, v1], dim=-1)  # (F, 2)
    e12 = torch.cat([v1, v2], dim=-1)  # (F, 2)
    e20 = torch.cat([v2, v0], dim=-1)  # (F, 2)

    # stores the vertex indices for each half edge
    e = torch.empty(HE, 2, device=faces.device, dtype=faces.dtype)
    e[0::3] = e01
    e[1::3] = e12
    e[2::3] = e20
    vert = e[..., 0]  # HE, origin vertex of each halfedge
    vert_next = e[..., 1]  # HE, destination vertex of each halfedge

    # undirected edge id: V * min + max is injective while all indices are in [0, V)
    edges = torch.stack([torch.minimum(vert_next, vert), torch.maximum(vert_next, vert)], dim=-1)
    edge_key = V * edges[..., 0] + edges[..., 1]  # HE,
    _, edge, counts = edge_key.unique(sorted=False, return_inverse=True, return_counts=True)
    E = len(counts)

    if is_manifold:
        # every edge id appears exactly twice: sort by edge id and pair neighbours
        inds = edge.argsort()  # 00, 11, 22 ...
        twin = torch.empty_like(inds)
        twin[inds[0::2]] = inds[1::2]
        twin[inds[1::2]] = inds[0::2]
    else:
        # a good edge has exactly two halfedges; filter out the others (boundary / non-manifold)
        hedg = torch.arange(HE, device=faces.device)  # HE, halfedge indices
        manifold = counts == 2  # per-edge manifold mask
        manifold = manifold[edge]  # per-halfedge manifold mask
        edge_manifold = edge[manifold]  # edge ids of manifold halfedges
        args = edge_manifold.argsort()  # 00, 11, 22 ...
        inds = hedg[manifold][args]  # global halfedge ids, grouped in pairs by edge
        twin_manifold = torch.empty_like(inds)
        twin_manifold[args[0::2]] = inds[1::2]
        twin_manifold[args[1::2]] = inds[0::2]
        twin = torch.empty(HE, device=faces.device, dtype=torch.long)
        twin[manifold] = twin_manifold
        twin[~manifold] = -counts[edge][~manifold]  # store -(number of halfedges on that edge)

    # should return these values
    halfedge = make_dotdict()
    # geometric info
    halfedge.verts = verts  # V, 3
    # connectivity info
    halfedge.twin = twin  # HE,
    halfedge.vert = vert  # HE,
    halfedge.edge = edge  # HE,
    halfedge.HE = HE
    halfedge.E = E
    halfedge.F = F
    halfedge.V = V
    return halfedge
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment