Skip to content

Instantly share code, notes, and snippets.

@ci7lus
Created November 10, 2024 08:30
Show Gist options
  • Select an option

  • Save ci7lus/5dbef5242d20d4c3aca5ddcdc77ef11e to your computer and use it in GitHub Desktop.

Select an option

Save ci7lus/5dbef5242d20d4c3aca5ddcdc77ef11e to your computer and use it in GitHub Desktop.
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
// https://github.com/WireGuard/wireguard-windows/blob/b279eab97a46bf8382b956b087b6922f88f95f20/conf/parser.go#L160
// https://github.com/WireGuard/wireguard-windows/blob/b279eab97a46bf8382b956b087b6922f88f95f20/conf/config.go#L35
package wgconf
import (
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"net/netip"
"strconv"
"strings"
"time"
"golang.org/x/crypto/curve25519"
)
// KeyLength is the size in bytes of every key handled by this package
// (private, public and preshared). parseKeyBase64 rejects any value that does
// not decode to exactly this many bytes.
const KeyLength = 32
// Endpoint is a peer's remote address as parsed from an "Endpoint =" line.
// Host is kept as written (hostname or IP literal), except that parseEndpoint
// strips the brackets around an IPv6 literal. So Host may contain ':' and a
// "%zone" suffix. No name resolution is performed.
type Endpoint struct {
	Host string
	Port uint16
}
type (
	// Key holds raw key bytes, e.g. as decoded by parseKeyBase64. Public
	// treats a Key as a Curve25519 private scalar.
	Key [KeyLength]byte
	// HandshakeTime and Bytes are the types of Peer's runtime statistics
	// fields. FromWgQuick never sets them. NOTE(review): the epoch/units of
	// HandshakeTime are not established in this file — confirm with whatever
	// populates it.
	HandshakeTime time.Duration
	Bytes uint64
)
// Config is a parsed wg-quick configuration: one interface plus its peers.
type Config struct {
	Name      string    // caller-supplied name (FromWgQuick's name argument)
	Interface Interface // settings from the [Interface] section(s)
	Peers     []Peer    // peers collected from [Peer] sections, in file order
}
// Interface holds the [Interface] settings of a wg-quick file. Fields whose
// key is absent from the file keep their zero value.
type Interface struct {
	PrivateKey Key            // "PrivateKey ="; FromWgQuick requires it
	Addresses  []netip.Prefix // "Address ="; bare IPs become /32 or /128 prefixes
	ListenPort uint16         // "ListenPort ="; 0 if unset
	MTU        uint16         // "MTU ="; 0 if unset, otherwise within 576..65535
	DNS        []netip.Addr   // "DNS =" entries that parse as IP addresses
	DNSSearch  []string       // "DNS =" entries that are not IP addresses
	PreUp      string         // hook command lines, stored verbatim; nothing here runs them
	PostUp     string
	PreDown    string
	PostDown   string
	TableOff   bool // true for "Table = off"; false for auto, main or a number
}
// Peer holds one [Peer] section. The first five fields come from the config
// file. RxBytes, TxBytes and LastHandshakeTime are never set by FromWgQuick.
type Peer struct {
	PublicKey           Key            // "PublicKey ="; FromWgQuick requires it
	PresharedKey        Key            // "PresharedKey ="; all zeros when unset
	AllowedIPs          []netip.Prefix // "AllowedIPs ="; bare IPs become /32 or /128
	Endpoint            Endpoint       // "Endpoint ="; zero value when unset
	PersistentKeepalive uint16         // "PersistentKeepalive ="; 0 when unset or "off"
	RxBytes             Bytes
	TxBytes             Bytes
	LastHandshakeTime   HandshakeTime
}
// ParseError describes why a piece of configuration text was rejected.
type ParseError struct {
	why      string // reason, without trailing punctuation; Error adds ": "
	offender string // the rejected input (a line, key or value)
}

// Error renders the error as `why: "offender"`. The offender is quoted with
// %q so that empty or whitespace-only input is still visible.
func (e *ParseError) Error() string {
	return fmt.Sprintf("%s: %q", e.why, e.offender)
}

// parseIPCidr parses s as a CIDR prefix such as "10.0.0.1/24" or
// "fd00::1/64". A bare address is accepted too and becomes a single-host
// prefix (/32 for IPv4, /128 for IPv6). Host bits in a CIDR are kept as
// written, not masked off.
func parseIPCidr(s string) (netip.Prefix, error) {
	if prefix, err := netip.ParsePrefix(s); err == nil {
		return prefix, nil
	}
	addr, err := netip.ParseAddr(s)
	if err != nil {
		// No trailing ": " in the reason: Error already inserts the separator,
		// and including it here used to render as `Invalid IP address: : "x"`.
		return netip.Prefix{}, &ParseError{"Invalid IP address", s}
	}
	return netip.PrefixFrom(addr, addr.BitLen()), nil
}
// parseEndpoint parses an Endpoint value of the form "host:port". The port is
// whatever follows the last ':' and is validated by parsePort.
//
// A host that contains ':' or a bracket must be a bracketed IPv6 literal such
// as "[fd00::1]:51820" or "[fe80::1%eth0]:51820". The brackets are removed from
// the returned Host, and any %zone suffix is kept. Any other host (hostname or
// IPv4 literal) is returned verbatim; no name resolution is done.
func parseEndpoint(s string) (*Endpoint, error) {
	i := strings.LastIndexByte(s, ':')
	if i < 0 {
		return nil, &ParseError{"Missing port from endpoint", s}
	}
	host, portStr := s[:i], s[i+1:]
	if len(host) < 1 {
		return nil, &ParseError{"Invalid endpoint host", host}
	}
	port, err := parsePort(portStr)
	if err != nil {
		return nil, err
	}
	// Any ':' or bracket in the host means it has to be a bracketed IPv6
	// literal. The check is hostColon >= 0, not > 0, so that a host starting
	// with ':' is rejected as well. An example is the unbracketed "::1" in
	// "::1:51820"; other unbracketed IPv6 hosts were already rejected.
	hostColon := strings.IndexByte(host, ':')
	if host[0] == '[' || host[len(host)-1] == ']' || hostColon >= 0 {
		bracketErr := &ParseError{"Brackets must contain an IPv6 address", host}
		if len(host) <= 3 || host[0] != '[' || host[len(host)-1] != ']' || hostColon < 0 {
			return nil, bracketErr
		}
		// Validate the address without its %zone suffix, if present.
		end := len(host) - 1
		if pct := strings.LastIndexByte(host, '%'); pct > 1 {
			end = pct
		}
		if addr, parseErr := netip.ParseAddr(host[1:end]); parseErr != nil || !addr.Is6() {
			return nil, bracketErr
		}
		host = host[1 : len(host)-1]
	}
	return &Endpoint{Host: host, Port: port}, nil
}
// parseMTU parses an MTU value as a decimal integer, which must lie between
// 576 and 65535 inclusive. Non-numeric input returns the strconv error
// unchanged. A value out of range returns a *ParseError.
func parseMTU(s string) (uint16, error) {
	n, err := strconv.Atoi(s)
	switch {
	case err != nil:
		return 0, err
	case n < 576, n > 65535:
		return 0, &ParseError{"Invalid MTU", s}
	default:
		return uint16(n), nil
	}
}
// parsePort parses a decimal port number in the range 0..65535. Non-numeric
// input returns the strconv error unchanged. A value out of range returns a
// *ParseError.
func parsePort(s string) (uint16, error) {
	n, err := strconv.Atoi(s)
	if err != nil {
		return 0, err
	}
	if n >= 0 && n <= 65535 {
		return uint16(n), nil
	}
	return 0, &ParseError{"Invalid port", s}
}
// parsePersistentKeepalive parses a PersistentKeepalive value. The literal
// "off" yields 0, the same as writing 0. Otherwise the value must be a
// decimal integer in 0..65535. Non-numeric input returns the strconv error
// unchanged. A value out of range returns a *ParseError.
func parsePersistentKeepalive(s string) (uint16, error) {
	if s == "off" {
		return 0, nil
	}
	n, err := strconv.Atoi(s)
	switch {
	case err != nil:
		return 0, err
	case n < 0 || n > 65535:
		return 0, &ParseError{"Invalid persistent keepalive", s}
	}
	return uint16(n), nil
}
// parseTableOff interprets a Table value and reports whether it is "off".
// "auto" and "main" yield false. Any other value must parse as an unsigned
// 32-bit decimal number; it also yields false, and only its validity is
// checked. Matching is case-sensitive.
func parseTableOff(s string) (bool, error) {
	switch s {
	case "off":
		return true, nil
	case "auto", "main":
		return false, nil
	}
	if _, err := strconv.ParseUint(s, 10, 32); err != nil {
		return false, err
	}
	return false, nil
}
// parseKeyBase64 decodes a key written in standard, padded base64. The
// decoded value must be exactly KeyLength bytes. Failures are reported as a
// *ParseError carrying the original text.
func parseKeyBase64(s string) (*Key, error) {
	raw, decodeErr := base64.StdEncoding.DecodeString(s)
	switch {
	case decodeErr != nil:
		return nil, &ParseError{fmt.Sprintf("Invalid key: %v", decodeErr), s}
	case len(raw) != KeyLength:
		return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
	}
	key := new(Key)
	copy(key[:], raw)
	return key, nil
}
// splitList splits a comma-separated value and trims whitespace from each
// item. An empty item anywhere rejects the whole value with a *ParseError.
// This covers ",," as well as leading or trailing commas and an empty s.
func splitList(s string) ([]string, error) {
	fields := strings.Split(s, ",")
	items := make([]string, 0, len(fields))
	for i := range fields {
		item := strings.TrimSpace(fields[i])
		if item == "" {
			return nil, &ParseError{"Two commas in a row", s}
		}
		items = append(items, item)
	}
	return items, nil
}
// parserState records which section FromWgQuick is currently reading.
type parserState int

const (
	inInterfaceSection parserState = iota // after an [Interface] header
	inPeerSection                         // after a [Peer] header
	notInASection                         // before any section header; FromWgQuick's initial state
)
// maybeAddPeer appends a copy of *p to c.Peers. A nil p is ignored. Because
// the value is copied, later changes through p do not affect c.
func (c *Config) maybeAddPeer(p *Peer) {
	if p == nil {
		return
	}
	c.Peers = append(c.Peers, *p)
}
// FromWgQuick parses the text of a wg-quick style configuration into a Config
// whose Name is name.
//
// Parsing is line based:
//   - Everything from the first '#' on a line is dropped, so '#' cannot
//     appear inside a value.
//   - Surrounding whitespace is trimmed and blank lines are skipped.
//   - The [Interface] and [Peer] headers and all keys are matched
//     case-insensitively; values keep their original case.
//   - Every "key = value" line must follow a section header, contain '=',
//     have a non-empty value, and use a key known for its section.
//   - Repeated [Interface] sections are merged.
//
// On success the interface has a private key and every peer has a public
// key. Errors are usually *ParseError. Malformed numbers may surface the
// underlying strconv error instead.
func FromWgQuick(s, name string) (*Config, error) {
	state := notInASection
	conf := Config{Name: name}
	sawPrivateKey := false
	var peer *Peer
	for _, line := range strings.Split(s, "\n") {
		line, _, _ = strings.Cut(line, "#")
		line = strings.TrimSpace(line)
		if len(line) == 0 {
			continue
		}
		switch strings.ToLower(line) {
		case "[interface]":
			conf.maybeAddPeer(peer)
			// Clear the finished peer. Otherwise the next [Peer] header, or the
			// final maybeAddPeer after the loop, would append it a second time.
			peer = nil
			state = inInterfaceSection
			continue
		case "[peer]":
			conf.maybeAddPeer(peer)
			peer = &Peer{}
			state = inPeerSection
			continue
		}
		if state == notInASection {
			return nil, &ParseError{"Line must occur in a section", line}
		}
		equals := strings.IndexByte(line, '=')
		if equals < 0 {
			return nil, &ParseError{"Config key is missing an equals separator", line}
		}
		// Slice the original line, then lower-case only the key.
		// strings.ToLower can change the byte length of non-ASCII text, so an
		// index found in line is not valid for a lower-cased copy. Slicing the
		// copy could panic or produce a mis-cut key.
		key := strings.ToLower(strings.TrimSpace(line[:equals]))
		val := strings.TrimSpace(line[equals+1:])
		if len(val) == 0 {
			return nil, &ParseError{"Key must have a value", line}
		}
		switch state {
		case inInterfaceSection:
			if err := applyInterfaceKey(&conf.Interface, key, val); err != nil {
				return nil, err
			}
			if key == "privatekey" {
				sawPrivateKey = true
			}
		case inPeerSection:
			if err := applyPeerKey(peer, key, val); err != nil {
				return nil, err
			}
		}
	}
	conf.maybeAddPeer(peer)
	if !sawPrivateKey {
		return nil, &ParseError{"An interface must have a private key", "[none specified]"}
	}
	for i := range conf.Peers {
		if conf.Peers[i].PublicKey.IsZero() {
			return nil, &ParseError{"All peers must have public keys", "[none specified]"}
		}
	}
	return &conf, nil
}

// applyInterfaceKey stores one [Interface] key/value pair into iface. key must
// already be trimmed and lower-cased; an unknown key yields a *ParseError.
func applyInterfaceKey(iface *Interface, key, val string) error {
	switch key {
	case "privatekey":
		k, err := parseKeyBase64(val)
		if err != nil {
			return err
		}
		iface.PrivateKey = *k
	case "listenport":
		p, err := parsePort(val)
		if err != nil {
			return err
		}
		iface.ListenPort = p
	case "mtu":
		m, err := parseMTU(val)
		if err != nil {
			return err
		}
		iface.MTU = m
	case "address":
		addresses, err := splitList(val)
		if err != nil {
			return err
		}
		for _, address := range addresses {
			a, err := parseIPCidr(address)
			if err != nil {
				return err
			}
			iface.Addresses = append(iface.Addresses, a)
		}
	case "dns":
		addresses, err := splitList(val)
		if err != nil {
			return err
		}
		for _, address := range addresses {
			// Entries that are not IP addresses are kept as search domains.
			if a, err := netip.ParseAddr(address); err == nil {
				iface.DNS = append(iface.DNS, a)
			} else {
				iface.DNSSearch = append(iface.DNSSearch, address)
			}
		}
	case "preup":
		iface.PreUp = val
	case "postup":
		iface.PostUp = val
	case "predown":
		iface.PreDown = val
	case "postdown":
		iface.PostDown = val
	case "table":
		tableOff, err := parseTableOff(val)
		if err != nil {
			return err
		}
		iface.TableOff = tableOff
	default:
		return &ParseError{"Invalid key for [Interface] section", key}
	}
	return nil
}

// applyPeerKey stores one [Peer] key/value pair into peer, which must be
// non-nil. key must already be trimmed and lower-cased; an unknown key yields
// a *ParseError.
func applyPeerKey(peer *Peer, key, val string) error {
	switch key {
	case "publickey":
		k, err := parseKeyBase64(val)
		if err != nil {
			return err
		}
		peer.PublicKey = *k
	case "presharedkey":
		k, err := parseKeyBase64(val)
		if err != nil {
			return err
		}
		peer.PresharedKey = *k
	case "allowedips":
		addresses, err := splitList(val)
		if err != nil {
			return err
		}
		for _, address := range addresses {
			a, err := parseIPCidr(address)
			if err != nil {
				return err
			}
			peer.AllowedIPs = append(peer.AllowedIPs, a)
		}
	case "persistentkeepalive":
		p, err := parsePersistentKeepalive(val)
		if err != nil {
			return err
		}
		peer.PersistentKeepalive = p
	case "endpoint":
		e, err := parseEndpoint(val)
		if err != nil {
			return err
		}
		peer.Endpoint = *e
	default:
		return &ParseError{"Invalid key for [Peer] section", key}
	}
	return nil
}
// String returns the key in standard, padded base64. This is the same
// encoding that parseKeyBase64 accepts.
func (k *Key) String() string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(KeyLength))
	enc.Encode(out, k[:])
	return string(out)
}
// IsZero reports whether every byte of the key is zero, which is the value of
// a key that was never set. The comparison uses crypto/subtle, so its running
// time does not depend on the key's contents.
func (k *Key) IsZero() bool {
	var zero Key
	return subtle.ConstantTimeCompare(k[:], zero[:]) == 1
}
// Public treats k as a Curve25519 private scalar and returns the matching
// public key, computed by multiplying the scalar with the curve's base point.
// The receiver is not modified; a freshly allocated Key is returned.
func (k *Key) Public() *Key {
	pub := new(Key)
	curve25519.ScalarBaseMult((*[KeyLength]byte)(pub), (*[KeyLength]byte)(k))
	return pub
}
// Hex returns the key as lowercase hexadecimal, two characters per byte.
func (k *Key) Hex() string {
	out := make([]byte, hex.EncodedLen(KeyLength))
	hex.Encode(out, k[:])
	return string(out)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment