Skip to content

Instantly share code, notes, and snippets.

@dendenxu
Created October 8, 2022 14:50
Show Gist options
  • Select an option

  • Save dendenxu/ca8a1816919a0fee915d3e72187894c1 to your computer and use it in GitHub Desktop.

Select an option

Save dendenxu/ca8a1816919a0fee915d3e72187894c1 to your computer and use it in GitHub Desktop.
Differentiable pytorch implementation of A Halfedge Refinement Rule for Parallel Loop Subdivision: http://kenneth.vanhoey.free.fr/index.php?page=research&lang=en#VD22, the original paper uses OpenGL and OpenMP, we provide a pytorch implementation of half-edge and triangle conversion, parallel loop subdivision algorithm (with considerations for e…
# NOTE: THIS WON'T RUN, PLZ IMPLEMENT YOUR OWN load_mesh AND export_mesh, THEN CHANGE IMPORTS ACCORDINGLY
import time
import torch
from tqdm import tqdm
# fmt: off
import sys
sys.path.append('.')
from lib.utils.data_utils import load_mesh, export_mesh
from lib.utils.mesh_utils import triangle_to_halfedge, halfedge_to_triangle, multiple_halfedge_loop_subdivision
# fmt: on

depth = 2
repeat = 1000
input_file = 'big-sigcat.ply'
output_file = 'big-sigcat_2.ply'


def benchmark(fn, title, n=repeat, gpu_sync=True):
    """Time `fn` for `n` iterations and print min/max/mean wall time in ms.

    CUDA kernels launch asynchronously: with gpu_sync=True the device is
    synchronized inside the timed region so kernel execution is included in
    the measurement; with gpu_sync=False only the Python-side launch cost is
    timed, and a single synchronize afterwards drains the queue so the next
    benchmark starts from an idle device.
    """
    rounds = []
    for _ in tqdm(range(n)):
        start = time.time()
        fn()
        if gpu_sync:
            torch.cuda.synchronize()
        rounds.append(time.time() - start)
    if not gpu_sync:
        torch.cuda.synchronize()
    print(title)
    print(f'fastest time: {min(rounds) * 1000:.3f}ms')
    print(f'slowest time: {max(rounds) * 1000:.3f}ms')
    print(f'average time: {sum(rounds) / len(rounds) * 1000:.3f}ms')


v, f = load_mesh(input_file)
he = triangle_to_halfedge(v, f)
print(f'vert count: {he.V}')
print(f'face count: {he.F}')
print(f'edge count: {he.E}')
print(f'halfedge count: {he.HE}')

benchmark(lambda: triangle_to_halfedge(v, f, False),
          'Conversion from triangle to halfedge representation: ')
benchmark(lambda: triangle_to_halfedge(v, f, True),
          'Conversion from triangle to halfedge with manifold assumption: ')
benchmark(lambda: triangle_to_halfedge(v, f, True),
          'Conversion from triangle to halfedge with manifold assumption (no gpu sync): ',
          gpu_sync=False)

# the subdivision benchmarks operate on the manifold conversion result
he = triangle_to_halfedge(v, f, True)

benchmark(lambda: multiple_halfedge_loop_subdivision(he, depth, False),
          f'Torch Loop subdivision (depth: {depth}): ')
benchmark(lambda: multiple_halfedge_loop_subdivision(he, depth, True),
          f'Torch Loop subdivision with manifold assumption (depth: {depth}): ')
benchmark(lambda: multiple_halfedge_loop_subdivision(he, depth, True),
          f'Torch Loop subdivision with manifold assumption (depth: {depth}) (no gpu sync): ',
          gpu_sync=False)

# ~1.566ms for 2 steps on 3090
# while a loop subdiv on MeshLab takes ~232ms
# ~60ms for 5 steps
# while MeshLab takes 17991ms
nhe = multiple_halfedge_loop_subdivision(he, depth, True)
v, f = halfedge_to_triangle(nhe)
export_mesh(v, f, filename=output_file)
# NOTE: THIS WON'T RUN, PLZ IMPLEMENT YOUR OWN load_mesh AND export_mesh, THEN CHANGE IMPORTS ACCORDINGLY
import torch
from tqdm import tqdm
from largesteps.optimize import AdamUniform
from largesteps.geometry import compute_matrix
from largesteps.parameterize import from_differential, to_differential
# fmt: off
import sys
sys.path.append('.')
from lib.utils.data_utils import load_mesh, export_mesh
from lib.utils.mesh_utils import triangle_to_halfedge, halfedge_to_triangle, multiple_halfedge_loop_subdivision
# fmt: on
def forward(p: torch.Tensor, f: torch.Tensor, M: torch.sparse.FloatTensor, depth: int):
    """Map latent parameters to a loop-subdivided triangle mesh.

    Demonstrates that the halfedge loop subdivision is differentiable w.r.t.
    the vertex positions, so it composes directly with the largesteps mesh
    optimization pipeline.
    """
    base_verts = from_differential(M, p, 'Cholesky')       # recover vertex positions from the latent
    halfedge = triangle_to_halfedge(base_verts, f, True)   # manifold assumption holds for this mesh
    refined = multiple_halfedge_loop_subdivision(halfedge, depth, True)
    return halfedge_to_triangle(refined)                   # (verts, faces) of the refined mesh
def main():
    """Optimize the cat mesh toward a unit sphere through differentiable subdivision."""
    # hyper-parameters
    lr = 3e-2
    depth = 2
    ep_iter = 10
    opt_iter = 50
    lambda_smooth = 29
    input_file = 'big-sigcat.ply'
    output_file = 'big-sigcat-to-sphere.ply'

    v, f = load_mesh(input_file)
    he = triangle_to_halfedge(v, f, True)
    print(f'vert count: {he.V}')
    print(f'face count: {he.F}')
    print(f'edge count: {he.E}')
    print(f'halfedge count: {he.HE}')

    # assume no batch dim
    M = compute_matrix(v, f, lambda_smooth)
    p = to_differential(M, v)
    p.requires_grad_()
    optim = AdamUniform([p], lr=lr)

    print()
    pbar = tqdm(range(opt_iter))
    for it in range(opt_iter):
        v, _ = forward(p, f, M, depth)
        # squared distance of every vertex from the unit sphere surface
        loss = ((v.norm(dim=-1) - 1) ** 2).sum()
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
        pbar.update(1)
        if it % ep_iter == 0:
            pbar.write(f'L2 loss: {loss.item():.5g}')

    v, f = forward(p.detach(), f, M, depth)
    export_mesh(v, f, filename=output_file)


if __name__ == "__main__":
    main()
import torch
from typing import Mapping, Tuple, TypeVar, Union
# these are generic type vars to tell mapping to accept any type vars when creating a type
# these are generic type vars to tell Mapping to accept any types when creating a type
KT = TypeVar("KT")  # key type
VT = TypeVar("VT")  # value type


def make_dotdict(*args, **kwargs):
    """Factory for DotDict, kept as a function so nested construction stays uniform."""
    return DotDict(*args, **kwargs)


class DotDict(dict, Mapping[KT, VT]):
    """
    A dictionary that supports dot notation
    as well as dictionary access notation.
    usage: d = make_dotdict() or d = make_dotdict({'val1': 'first'})
    set attributes: d.val2 = 'second' or d['val2'] = 'second'
    get attributes: d.val2 or d['val2']
    """

    def update(self, dct=None, **kwargs):
        """Update in place, coercing new values to the type of existing entries.

        A string value for an existing bool entry is parsed as `v == 'True'`
        (NOTE: bool('False') would be True, hence the explicit comparison).
        The input mapping is copied first so the caller's dict is not mutated.
        """
        if dct is None:
            dct = kwargs
        else:
            dct = dict(dct)  # copy: never mutate the caller's mapping
            dct.update(kwargs)
        for k, v in dct.items():
            if k in self:
                target_type = type(self[k])
                if not isinstance(v, target_type):
                    if target_type == bool and isinstance(v, str):
                        dct[k] = v == 'True'
                    else:
                        dct[k] = target_type(v)
        dict.update(self, dct)

    def __hash__(self):
        # BUGFIX: the previous implementation hashed a freshly created
        # self.values() view object (id-based), so the hash changed between
        # calls on the same instance. Hash the key set instead: deterministic,
        # consistent with dict equality (equal dicts share key sets), and it
        # tolerates unhashable values by deliberately leaving them out.
        return hash(frozenset(self.keys()))

    def __init__(self, dct=None, **kwargs):
        if dct is None:
            dct = kwargs
        else:
            dct = dict(dct)  # copy: never mutate the caller's mapping
            dct.update(kwargs)
        for key, value in dct.items():
            # recursively wrap nested mappings so chained dot access works
            if hasattr(value, 'keys'):
                value = make_dotdict(value)
            self[key] = value

    def __getitem__(self, key):
        # Re-raise KeyError as AttributeError: __getattr__ below aliases this
        # method, and protocols such as pickling probe optional attributes
        # (e.g. __getstate__) expecting AttributeError — raising KeyError there
        # breaks multiprocessing serialization in newer pytorch versions.
        try:
            return dict.__getitem__(self, key)
        except KeyError as e:
            raise AttributeError(e)

    __getattr__ = __getitem__  # overridden dict.__getitem__ (AttributeError on miss)
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
def halfedge_loop_subdivision(halfedge: DotDict[str, torch.Tensor], is_manifold=False):
    """One parallel Loop-subdivision step on a halfedge mesh, returning a new DotDict.

    Each input triangle becomes 4, each halfedge becomes 4, and each edge gains
    a midpoint vertex; topology (twin/vert/edge) and geometry (verts) are
    rebuilt with pure tensor ops, so the result stays differentiable w.r.t.
    `halfedge.verts`. With `is_manifold=True` the non-manifold correction
    branches are skipped.
    """
    # Please just watch this: https://www.youtube.com/watch?v=mxk2HHk1NK4
    # Adapted from https://github.com/kvanhoey/ParallelHalfedgeSubdivision
    # assuming the mesh has clean topology? except for boundary edges
    # assume no boundary edge for now!
    # loading from dotdict
    verts: torch.Tensor = halfedge.verts  # V, 3
    twin: torch.Tensor = halfedge.twin  # HE, (negative entries mark non-manifold halfedges)
    vert: torch.Tensor = halfedge.vert  # HE,
    edge: torch.Tensor = halfedge.edge  # HE,
    HE = halfedge.HE
    E = halfedge.E
    F = halfedge.F
    V = halfedge.V
    # sizes after one subdivision step (see class NHE/NF/NE/NV formulas of the paper's refinement rule)
    NHE = 4 * HE
    NF = 4 * F
    NE = 2 * E + 3 * F
    NV = V + E
    # assign empty memory
    # NOTE(review): dtype=vert.dtype assumes `vert` is an integer index tensor — confirm upstream
    ntwin = torch.empty(NHE, device=vert.device, dtype=vert.dtype)
    nvert = torch.empty(NHE, device=vert.device, dtype=vert.dtype)
    nedge = torch.empty(NHE, device=vert.device, dtype=vert.dtype)
    # prepare input for topology computation
    # halfedges are stored 3 per face, so next/prev cycle within each group of 3
    hedg = torch.arange(HE, device=vert.device)
    next = hedg - (hedg % 3) + (hedg + 1) % 3
    prev = hedg - (hedg % 3) + (hedg + 2) % 3
    next_twin = next[twin]
    twin_prev = twin[prev]
    edge_prev = edge[prev]
    # assign next topology: each old halfedge h spawns new halfedges 3h, 3h+1, 3h+2
    # plus one interior halfedge 3*HE + h of the center sub-triangle
    i0 = 3*hedg + 0
    i1 = 3*hedg + 1
    i2 = 3*hedg + 2
    i3 = 3*HE + hedg
    ntwin[i0] = 3 * next_twin + 2
    ntwin[i1] = 3 * HE + hedg
    ntwin[i2] = 3 * twin_prev
    ntwin[i3] = 3 * hedg + 1
    # old edges split in two (the +0/+1 half is picked consistently from both sides),
    # new interior edges are appended after index 2*E
    nedge[i0] = 2 * edge + (hedg < twin).long()
    nedge[i1] = 2 * E + hedg
    nedge[i2] = 2 * edge_prev + (prev > twin_prev).long()
    nedge[i3] = nedge[i1]
    # old vertices keep their index; edge midpoints get indices V + edge
    nvert[i0] = vert
    nvert[i1] = V + edge
    nvert[i2] = V + edge_prev
    nvert[i3] = nvert[i2]
    if not is_manifold:
        # deal with non-manifold cases: negative twin entries have no opposite halfedge,
        # so the manifold formulas above are overwritten for them
        manifold_mask = twin >= 0
        non_manifold_mask = ~manifold_mask
        non_manifold_mask_prev = non_manifold_mask[prev]
        manifold = manifold_mask.nonzero(as_tuple=True)[0]  # indices of manifold half edges
        non_manifold = non_manifold_mask.nonzero(as_tuple=True)[0]  # only non manifold half edges are here
        non_manifold_prev = non_manifold_mask_prev.nonzero(as_tuple=True)[0]  # halfedges whose prev is non-manifold
        ntwin[i0[non_manifold]] = twin[non_manifold]
        ntwin[i2[non_manifold_prev]] = twin_prev[non_manifold_prev]  # should store the non-manifold twin (when previous is non-manifold)
        nedge[i0[non_manifold]] = 2 * edge[non_manifold]
        nedge[i2[non_manifold_prev]] = 2 * edge_prev[non_manifold_prev] + 1  # should store the non-manifold edge (when previous is non-manifold)
    # pre-compute vertex valence & beta values (Loop's beta weight per vertex)
    # NOTE: 'velance' is a misspelling of 'valence' kept for token stability
    _, inverse, velance = vert.unique(sorted=False, return_inverse=True, return_counts=True)
    # NOTE(review): `scatter` is not imported in this chunk — presumably torch_scatter.scatter; verify
    velance = scatter(velance[inverse], vert, dim=0, dim_size=V, reduce='mean')  # all verts' valence, in original order (not some sorted order)
    beta = (1 / velance) * (5/8 - (3/8 + 1/4 * torch.cos(2 * torch.pi / velance))**2)
    # prepare geometric topological variables
    vert_next = vert[next]
    vert_prev = vert[prev]
    vert_edge = V + edge  # no duplication between vert and vert_edge
    velance_vert = velance[vert]
    beta_vert = beta[vert]
    if not is_manifold:
        # prepare for computing non-manifold original vertices
        incident = -twin[non_manifold]  # twin stores -(#incident halfedges) for non-manifold edges
        non_manifold_vert_mask = scatter(non_manifold_mask.int(), inverse, dim=0, dim_size=V, reduce='max').bool()  # non_manifold vert mask, sorted
        non_manifold_vert_mask = non_manifold_vert_mask[inverse]  # non-manifold edge mask, if vert is non-manifold, this would be non-manifold
        non_manifold_vert = non_manifold_vert_mask.nonzero(as_tuple=True)[0]
        # prev and current is manifold
        manifold_prev = manifold_mask[prev].nonzero(as_tuple=True)[0]
        non_manifold_prev_twin_prev_mask = torch.zeros(HE, device=vert.device, dtype=manifold_mask.dtype)
        non_manifold_prev_twin_prev_mask[manifold_prev] = non_manifold_mask[prev[twin_prev[manifold_prev]]]
        non_manifold_prev_twin_prev = non_manifold_prev_twin_prev_mask.nonzero(as_tuple=True)[0]
    # actually distribute geometric values, if no topology change, only these should be retained
    verts_vert = verts[vert]
    verts_vert_next = verts[vert_next]
    verts_vert_prev = verts[vert_prev]
    # each edge midpoint accumulates 3/8 * endpoint + 1/8 * opposite vertex from both incident faces
    nverts_vert_edge = (3 * verts_vert + verts_vert_prev) / 8  # vertex position for vertices created on edges
    # each old vertex accumulates its smoothed position from its incident halfedges
    nverts_vert = (1/velance_vert - beta_vert)[..., None] * verts_vert + beta_vert[..., None] * verts_vert_next  # vertex position for older vertices
    if not is_manifold:
        # non-manifold edges: plain average of endpoints, normalized by incident halfedge count
        nverts_vert_edge[non_manifold] = (verts_vert[non_manifold] + verts_vert_next[non_manifold]) / (incident * 2)[..., None]
        nverts_vert[non_manifold_vert] = 0
        nverts_vert[non_manifold] += 1/8 * verts_vert_next[non_manifold] + 3/8 * verts_vert[non_manifold]
        nverts_vert[non_manifold_prev_twin_prev] += 1/8 * verts_vert_prev[twin_prev[non_manifold_prev_twin_prev]] + 3/8 * verts_vert[twin_prev[non_manifold_prev_twin_prev]]
    # original vertices and new edge vertices have disjoint index ranges, so summing the two scatters is safe
    nverts_vert_edge = scatter(nverts_vert_edge, vert_edge, dim=0, dim_size=NV)
    nverts_vert = scatter(nverts_vert, vert, dim=0, dim_size=NV)
    nverts = nverts_vert_edge + nverts_vert
    # prepare return dotdict structure
    nhalfedge = make_dotdict()
    # geometric info
    nhalfedge.verts = nverts  # V, 3
    # topological info
    nhalfedge.twin = ntwin  # HE,
    nhalfedge.vert = nvert  # HE,
    nhalfedge.edge = nedge  # HE,
    # size info (some of them could be omitted like HE)
    nhalfedge.HE = NHE
    nhalfedge.E = NE
    nhalfedge.F = NF
    nhalfedge.V = NV
    return nhalfedge
def multiple_halfedge_loop_subdivision(halfedge: DotDict[str, torch.Tensor], steps=2, is_manifold=False):
    """Apply `steps` rounds of halfedge loop subdivision, feeding each result into the next."""
    current = halfedge
    for _ in range(steps):
        current = halfedge_loop_subdivision(current, is_manifold)
    return current
def halfedge_to_triangle(halfedge: DotDict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a halfedge mesh back to a (verts, faces) triangle representation.

    Halfedges are stored face-by-face (3 consecutive halfedges per triangle, see
    triangle_to_halfedge), so the per-halfedge source-vertex array regroups
    directly into the (F, 3) face tensor.
    """
    # assuming the mesh has clean topology? except for boundary edges
    # assume no boundary edge for now!
    verts = halfedge.verts
    vert = halfedge.vert  # HE, source vertex index of each halfedge
    # NOTE: the previous version also computed vert[next] pairs with a bitwise
    # next formula (`hedg & ~3 | (hedg + 1) & 3`) that groups by 4 instead of 3;
    # that column was never read when building faces, so it is dropped here.
    faces = vert.reshape(-1, 3)  # HE == 3 * F by construction
    return verts, faces
def triangle_to_halfedge(verts: Union[torch.Tensor, None],
                         faces: torch.Tensor,
                         is_manifold: bool = False,
                         ):
    """Build a halfedge representation from a triangle mesh.

    Every face contributes 3 consecutive halfedges (so halfedge h belongs to
    face h // 3). `twin[h]` is the opposite halfedge across h's edge; for a
    non-manifold edge it instead stores the negated number of incident
    halfedges.

    Args:
        verts: (V, 3) vertex positions, or None for a topology-only conversion.
        faces: (F, 3) integer tensor of vertex indices.
        is_manifold: when True, assume every edge has exactly two halfedges,
            enabling a faster twin pairing without masking.

    Returns:
        DotDict with verts, twin, vert, edge and the sizes HE/E/F/V.
    """
    # assuming the mesh has clean topology? except for boundary edges
    # assume no boundary edge for now!
    F = len(faces)
    # BUGFIX: indices are 0-based, so the vertex count is max index + 1
    # (an undersized V would also break the uniqueness of the edge hash below)
    V = len(verts) if verts is not None else faces.max().item() + 1
    HE = 3 * F
    # create halfedges: one directed edge per face corner, stored face-by-face
    v0, v1, v2 = faces.chunk(3, dim=-1)
    e01 = torch.cat([v0, v1], dim=-1)  # (sum(F_n), 2)
    e12 = torch.cat([v1, v2], dim=-1)  # (sum(F_n), 2)
    e20 = torch.cat([v2, v0], dim=-1)  # (sum(F_n), 2)
    # stores the vertex indices for each half edge
    e = torch.empty(HE, 2, device=faces.device, dtype=faces.dtype)
    e[0::3] = e01
    e[1::3] = e12
    e[2::3] = e20
    vert = e[..., 0]  # HE, source vertex of each half edge
    vert_next = e[..., 1]
    # undirected edge id: sort the two endpoints, then hash them into one int
    edges = torch.stack([torch.minimum(vert_next, vert), torch.maximum(vert_next, vert)], dim=-1)
    edge_hash = V * edges[..., 0] + edges[..., 1]  # HE, unique per undirected edge since indices < V
    _, edge, counts = edge_hash.unique(sorted=False, return_inverse=True, return_counts=True)
    E = len(counts)
    hedg = torch.arange(HE, device=faces.device)  # HE, half edge indices
    if is_manifold:
        # every edge has exactly two halfedges: sorting by edge id pairs them up
        inds = edge.argsort()  # 00, 11, 22 ...
        twin = torch.empty_like(inds)
        twin[inds[0::2]] = inds[1::2]
        twin[inds[1::2]] = inds[0::2]
    else:
        # now we have edge indices, if it's a good mesh, each edge should have two half edges
        # in some non-manifold cases this would be broken so we need to first filter those non-manifold edges out
        manifold = counts == 2  # manifold edge mask (exactly two incident halfedges)
        manifold = manifold[edge]  # manifold half edge mask
        edge_manifold = edge[manifold]  # manifold edge indices
        args = edge_manifold.argsort()  # 00, 11, 22 ...
        inds = hedg[manifold][args]
        twin_manifold = torch.empty_like(inds)
        twin_manifold[args[0::2]] = inds[1::2]
        twin_manifold[args[1::2]] = inds[0::2]
        twin = torch.empty(HE, device=faces.device, dtype=torch.long)
        twin[manifold] = twin_manifold
        # non-manifold half edges have no unique twin: store the negated incident count instead
        twin[~manifold] = -counts[edge][~manifold]
    # should return these values
    halfedge = make_dotdict()
    # geometric info
    halfedge.verts = verts  # V, 3
    # connectivity info
    halfedge.twin = twin  # HE,
    halfedge.vert = vert  # HE,
    halfedge.edge = edge  # HE,
    halfedge.HE = HE
    halfedge.E = E
    halfedge.F = F
    halfedge.V = V
    return halfedge
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment