wireguardCommon.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135
  1. package wireguardcommon
import (
	"bytes"
	"crypto"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/netip"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strings"
	"time"

	"github.com/BurntSushi/toml"
	"github.com/aead/ecdh"
	"github.com/vishvananda/netlink"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.zx2c4.com/wireguard/wgctrl"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
  30. const (
  31. InitialPort = 32323
  32. WireguardListenPort = 32322
  33. PersistentKeepAliveInterval = 32
  34. InterfaceName = "wg3"
  35. Ip2 = "5.161.133.68"
  36. Ip1 = "192.168.1.152"
  37. Ip3 = "127.0.0.1"
  38. IPRange = "192.168.88.1/24"
  39. ClientConfigFile = "/etc/3nets.conf"
  40. pkFilename = "resources/managerPrivateKey"
  41. ConfFilename = "resources/wg3.conf"
  42. ConfFilenameAlt = "resources/wg3.alt.conf"
  43. deviceFormatStr = `interface: %s (%s)
  44. public key: %s
  45. private key: (hidden)
  46. listening port: %d
  47. `
  48. peerFormatStr = `peer: %s
  49. endpoint: %s
  50. allowed ips: %s
  51. latest handshake: %s
  52. transfer: %d B received, %d B sent
  53. `
  54. deviceFormatStrWithSecrets = `[Interface] # %s (%s)
  55. ListenPort = %d
  56. PrivateKey = %s
  57. Address = %s
  58. `
  59. peerFormatStrWithSecrets = `[Peer]
  60. PublicKey = %s
  61. PresharedKey = %s
  62. AllowedIps = %s
  63. PersistentKeepalive = %d
  64. `
  65. )
  66. const (
  67. ExitSetupSuccess = 0
  68. ExitSetupFailed = 1
  69. )
  70. func GenerateKeyPair() (privateKey crypto.PrivateKey, publicKey crypto.PublicKey, err error) {
  71. if _, err = os.Stat(pkFilename); errors.Is(err, os.ErrNotExist) {
  72. fmt.Printf("managerPrivateKey not found\n")
  73. c25519 := ecdh.X25519()
  74. privateKey, publicKey, err = c25519.GenerateKey(rand.Reader)
  75. if err != nil {
  76. fmt.Printf("Failed to generate manager privateKey/public key pair: %s\n", err)
  77. return
  78. }
  79. var managerPrivateBytes [32]byte
  80. if ok := CheckType(&managerPrivateBytes, privateKey); !ok {
  81. panic("ecdh: unexpected type of manager public key")
  82. }
  83. var managerPublicBytes [32]byte
  84. if ok := CheckType(&managerPublicBytes, publicKey); !ok {
  85. panic("ecdh: unexpected type of manager public key")
  86. }
  87. managerDHPrivate := base64.StdEncoding.EncodeToString(managerPrivateBytes[:])
  88. managerDHPublic := base64.StdEncoding.EncodeToString(managerPublicBytes[:])
  89. fmt.Printf("managerDHPrivate: %s\n", managerDHPrivate)
  90. fmt.Printf("managerDHPublic: %s\n", managerDHPublic)
  91. dir := filepath.Dir(pkFilename)
  92. err = os.MkdirAll(dir, 0555)
  93. if err != nil {
  94. log.Fatalf("MkdirAll for privateKey key failed, %s", err)
  95. }
  96. err = os.WriteFile(pkFilename, []byte(managerDHPrivate), 0600)
  97. if err != nil {
  98. log.Fatalf("Write file for privateKey key failed, %s", err)
  99. }
  100. } else {
  101. fmt.Printf("managerPrivateKey found\n")
  102. var managerBytes []byte
  103. var managerPrivateBytes []byte
  104. var managerPublicBytes [32]byte
  105. managerBytes, err = os.ReadFile(pkFilename)
  106. if err != nil {
  107. log.Fatalf("Read file for privateKey key failed, %s", err)
  108. }
  109. fmt.Printf("managerBytes = %s\n", managerBytes)
  110. managerPrivateBytes, err = base64.StdEncoding.DecodeString(string(managerBytes))
  111. if err != nil {
  112. log.Fatalf("Could not decode managerBytes, err:%s", err)
  113. }
  114. privateKey = managerPrivateBytes
  115. // build public key from privateKey key
  116. publicKey, err = curve25519.X25519(managerPrivateBytes, curve25519.Basepoint)
  117. if err != nil {
  118. log.Fatalf("Could not build public from managerPrivate, err:%s", err)
  119. }
  120. if ok := CheckType(&managerPublicBytes, publicKey); !ok {
  121. panic("ecdh: unexpected type of manager public key")
  122. }
  123. managerDHPrivate := base64.StdEncoding.EncodeToString(managerPrivateBytes[:])
  124. managerDHPublic := base64.StdEncoding.EncodeToString(managerPublicBytes[:])
  125. fmt.Printf("managerDHPublic: %s\n", managerDHPublic)
  126. fmt.Printf("managerDHPrivate: %s\n", managerDHPrivate)
  127. }
  128. return
  129. }
  130. func CheckType(key *[32]byte, typeToCheck interface{}) (ok bool) {
  131. switch t := typeToCheck.(type) {
  132. case [32]byte:
  133. copy(key[:], t[:])
  134. ok = true
  135. case *[32]byte:
  136. copy(key[:], t[:])
  137. ok = true
  138. case []byte:
  139. if len(t) == 32 {
  140. copy(key[:], t)
  141. ok = true
  142. }
  143. case *[]byte:
  144. if len(*t) == 32 {
  145. copy(key[:], *t)
  146. ok = true
  147. }
  148. }
  149. return
  150. }
  151. func GenerateSignature(url string, message string) (signature string, err error) {
  152. if url == "" {
  153. url = "http://test1.jagat.me:8123/nacl-sign"
  154. }
  155. requestJson := "{\"message\": \"" + message + "\"}"
  156. response, err := http.Post(url, "application/json", bytes.NewBuffer([]byte(requestJson)))
  157. if err != nil {
  158. log.Fatalf("Cant reach the url for generating signature, %s", err)
  159. }
  160. signatureBytes, err := io.ReadAll(response.Body)
  161. signature = string(signatureBytes)
  162. return
  163. }
  164. func VerifySignature(url string, signature string) (message string, status bool, err error) {
  165. if url == "" {
  166. url = "http://test1.jagat.me:8123/nacl-verify"
  167. }
  168. status = false
  169. requestJson := "{\"signature\":\"" + signature + "\"}"
  170. response, err := http.Post(url, "application/json", bytes.NewBuffer([]byte(requestJson)))
  171. if err != nil {
  172. log.Fatalf("Cant reach the url for verifying signature, %s", err)
  173. }
  174. messageBytes, err := io.ReadAll(response.Body)
  175. message = string(messageBytes)
  176. if response.StatusCode == 200 {
  177. status = true
  178. return
  179. }
  180. return
  181. }
  182. func VerifySignatureRegex(url string, regexMatchString string, encrypted string) (status bool, err error) {
  183. status = false
  184. var message string
  185. message, err = DecryptMessage(url, encrypted)
  186. fmt.Printf("descrypted message : --%s--\n", message)
  187. if err == nil {
  188. var re = regexp.MustCompile(regexMatchString)
  189. if re.MatchString(message) {
  190. err = nil
  191. status = true
  192. }
  193. }
  194. return
  195. }
  196. func EncryptMessage(url string, message string) (encrypted string, err error) {
  197. if url == "" {
  198. url = "http://test1.jagat.me:8123/nacl-encrypt"
  199. }
  200. requestJson := "{\"message\": \"" + message + "\"}"
  201. response, err := http.Post(url, "application/json", bytes.NewBuffer([]byte(requestJson)))
  202. if err != nil {
  203. log.Fatalf("Cant reach the url for encrypting, %s", err)
  204. }
  205. encryptedBytes, err := io.ReadAll(response.Body)
  206. encrypted = string(encryptedBytes)
  207. return
  208. }
  209. func DecryptMessage(url string, encrypted string) (message string, err error) {
  210. if url == "" {
  211. url = "http://test1.jagat.me:8123/nacl-decrypt"
  212. }
  213. requestJson := "{\"encrypted\":\"" + encrypted + "\"}"
  214. response, err := http.Post(url, "application/json", bytes.NewBuffer([]byte(requestJson)))
  215. if err != nil {
  216. log.Fatalf("Cant reach the url for decrypting, %s", err)
  217. }
  218. decryptedBytes, err := io.ReadAll(response.Body)
  219. message = string(decryptedBytes)
  220. return
  221. }
  222. func GenPskOnEdge(managerDHPublicStr string) (presharedKey string, edgeDHPublic string, err error) {
  223. c25519 := ecdh.X25519()
  224. var edgePrivate crypto.PrivateKey
  225. var edgePublic crypto.PublicKey
  226. edgePrivate, edgePublic, err = c25519.GenerateKey(rand.Reader)
  227. var edgePublicBytes [32]byte
  228. if ok := CheckType(&edgePublicBytes, edgePublic); !ok {
  229. panic("ecdh: unexpected type of edge public key")
  230. }
  231. edgeDHPublic = base64.StdEncoding.EncodeToString(edgePublicBytes[:])
  232. //fmt.Printf("edgeDHPublic: %s\n", edgeDHPublic)
  233. if err != nil {
  234. fmt.Printf("Failed to generate edge private/public key pair: %s\n", err)
  235. return
  236. }
  237. managerDHPublic, err := base64.StdEncoding.DecodeString(managerDHPublicStr)
  238. if err != nil {
  239. fmt.Printf("manager public key is could not be decoded: %s\n", err)
  240. }
  241. if err = c25519.Check(managerDHPublic); err != nil {
  242. fmt.Printf("manager public key is not on the curve: %s\n", err)
  243. }
  244. secret := c25519.ComputeSecret(edgePrivate, managerDHPublic)
  245. presharedKey = base64.StdEncoding.EncodeToString(secret)
  246. return
  247. }
  248. func GenPskOnManager(edgeDHPublicStr string,
  249. managerPrivate crypto.PrivateKey) (presharedKey string, err error) {
  250. c25519 := ecdh.X25519()
  251. edgeDHPublic, err := base64.StdEncoding.DecodeString(edgeDHPublicStr)
  252. if err != nil {
  253. fmt.Printf("edge public key is could not be decoded: %s\n", err)
  254. }
  255. if err = c25519.Check(edgeDHPublic); err != nil {
  256. fmt.Printf("edge public key is not on the curve: %s\n", err)
  257. }
  258. secret := c25519.ComputeSecret(managerPrivate, edgeDHPublic)
  259. presharedKey = base64.StdEncoding.EncodeToString(secret)
  260. return
  261. }
  262. func WriteConn(conn net.Conn, message []byte) (writeLen int, err error) {
  263. // Send manager DH public
  264. writeLen, err = conn.Write(message)
  265. if err != nil {
  266. fmt.Printf("Errored when send data on %s, error=%d\n", conn, err)
  267. return
  268. }
  269. return
  270. }
  271. func ReadConn(conn net.Conn, maxLength int) (responseStr string, responseLen int, err error) {
  272. // Make a buffer to hold incoming data.
  273. edgeResponse := make([]byte, maxLength+32)
  274. // Read the incoming connection into the buffer.
  275. responseLen, err = conn.Read(edgeResponse)
  276. if err != nil {
  277. fmt.Println("Error reading:", err.Error())
  278. CloseConn(conn)
  279. return
  280. }
  281. if responseLen > maxLength {
  282. fmt.Printf("read more than needed. responseLen = %d\n", responseLen)
  283. CloseConn(conn)
  284. return
  285. }
  286. responseStr = string(edgeResponse[:responseLen])
  287. return
  288. }
  289. func CloseConn(conn net.Conn) {
  290. err := conn.Close()
  291. if err != nil {
  292. fmt.Printf("error closing conn. ignored\n")
  293. return
  294. }
  295. }
  296. func GetCurrentConfWithSecrets() (conf string, err error) {
  297. c, err := wgctrl.New()
  298. if err != nil {
  299. log.Fatalf("failed to open wgctrl: %v", err)
  300. return
  301. }
  302. defer c.Close()
  303. var retStr strings.Builder
  304. var devices []*wgtypes.Device
  305. devices, err = c.Devices()
  306. if err != nil {
  307. log.Fatalf("failed to get devices: %v", err)
  308. return
  309. }
  310. for _, d := range devices {
  311. var link netlink.Link
  312. var addrs []net.IP
  313. var netaddrs []netlink.Addr
  314. link, err = netlink.LinkByName(d.Name)
  315. if err != nil {
  316. fmt.Printf("netlink.LinkByName(%s) fialed, err: %s\n", d.Name, err)
  317. } else {
  318. netaddrs, err = netlink.AddrList(link, netlink.FAMILY_V4)
  319. if err != nil {
  320. fmt.Printf("netlink.AddrList(%s) fialed, err: %s\n", d.Name, err)
  321. }
  322. for _, a := range netaddrs {
  323. addrs = append(addrs, a.IPNet.IP)
  324. }
  325. }
  326. retStr.WriteString(fmt.Sprintf(
  327. deviceFormatStrWithSecrets,
  328. d.Name,
  329. d.Type.String(),
  330. d.ListenPort,
  331. d.PrivateKey.String(),
  332. addrs[0],
  333. ))
  334. for _, p := range d.Peers {
  335. retStr.WriteString(fmt.Sprintf(
  336. peerFormatStrWithSecrets,
  337. p.PublicKey.String(),
  338. p.PresharedKey.String(),
  339. // TODO: get right endpoint with getnameinfo.
  340. ipsString(p.AllowedIPs),
  341. int(p.PersistentKeepaliveInterval.Seconds()),
  342. ))
  343. }
  344. }
  345. conf = retStr.String()
  346. return
  347. }
  348. func GetCurrentConf() (conf string, err error) {
  349. c, err := wgctrl.New()
  350. if err != nil {
  351. log.Fatalf("failed to open wgctrl: %v", err)
  352. return
  353. }
  354. defer c.Close()
  355. var retStr strings.Builder
  356. var devices []*wgtypes.Device
  357. devices, err = c.Devices()
  358. if err != nil {
  359. log.Fatalf("failed to get devices: %v", err)
  360. return
  361. }
  362. for _, d := range devices {
  363. retStr.WriteString(fmt.Sprintf(
  364. deviceFormatStr,
  365. d.Name,
  366. d.Type.String(),
  367. d.PublicKey.String(),
  368. d.ListenPort))
  369. for _, p := range d.Peers {
  370. retStr.WriteString(fmt.Sprintf(
  371. peerFormatStr,
  372. p.PublicKey.String(),
  373. // TODO(mdlayher): get right endpoint with getnameinfo.
  374. p.Endpoint.String(),
  375. ipsString(p.AllowedIPs),
  376. p.LastHandshakeTime.String(),
  377. p.ReceiveBytes,
  378. p.TransmitBytes))
  379. }
  380. }
  381. conf = retStr.String()
  382. return
  383. }
  384. func ipsString(ipns []net.IPNet) string {
  385. ss := make([]string, 0, len(ipns))
  386. for _, ipn := range ipns {
  387. ss = append(ss, ipn.String())
  388. }
  389. return strings.Join(ss, ", ")
  390. }
  391. func ResolveOutgoingInterface() string {
  392. routes, _ := netlink.RouteGet(net.ParseIP("1.1.1.1"))
  393. for _, route := range routes {
  394. log.Printf("connected on linkindexlink %d\n", route.LinkIndex)
  395. link, err := netlink.LinkByIndex(route.LinkIndex)
  396. if err != nil {
  397. fmt.Printf("unable to get the link. error = %s\n", err)
  398. } else {
  399. log.Printf("connected on olink %s\n", link.Attrs().Name)
  400. return link.Attrs().Name
  401. }
  402. }
  403. return ""
  404. }
  405. func ResolveHostIp() string {
  406. netInterfaceAddresses, err := net.InterfaceAddrs()
  407. if err != nil {
  408. return ""
  409. }
  410. for _, netInterfaceAddress := range netInterfaceAddresses {
  411. fmt.Printf("netInterfaceAddress = %s\n", netInterfaceAddress)
  412. networkIp, ok := netInterfaceAddress.(*net.IPNet)
  413. if ok && !networkIp.IP.IsLoopback() && networkIp.IP.To4() != nil {
  414. ip := networkIp.IP.String()
  415. fmt.Println("Resolved Host IP: " + ip)
  416. return ip
  417. }
  418. }
  419. return ""
  420. }
  421. func CreateKeyPairFromSeed(seed string) (privateKeyStr string, publicKeyStr string, err error) {
  422. var edgePublicKey []byte
  423. seedHash := sha256.Sum256([]byte(seed))
  424. //create a curve25519 point (for privateKey)
  425. seedHash[0] &= 248
  426. seedHash[31] &= 127
  427. seedHash[31] |= 64
  428. privateKeyStr = base64.StdEncoding.EncodeToString(seedHash[:])
  429. fmt.Printf("privateKey = %s\n", privateKeyStr)
  430. // build publickey from privatekey
  431. edgePublicKey, err = curve25519.X25519(seedHash[:], curve25519.Basepoint)
  432. if err != nil {
  433. log.Fatalf("curve25519.X25519() failed: %v", err)
  434. return
  435. }
  436. publicKeyStr = base64.StdEncoding.EncodeToString(edgePublicKey)
  437. fmt.Printf("publicKey = %s\n", publicKeyStr)
  438. return
  439. }
  440. func Chacha20poly1305Decrypt(key []byte, encryptedMsg []byte) (message []byte, err error) {
  441. var aead cipher.AEAD
  442. aead, err = chacha20poly1305.NewX(key)
  443. if err != nil {
  444. fmt.Printf("failed to create aead from key")
  445. return
  446. }
  447. if len(encryptedMsg) < aead.NonceSize() {
  448. panic("ciphertext too short")
  449. }
  450. // Split nonce and ciphertext.
  451. nonce, ciphertext := encryptedMsg[:aead.NonceSize()], encryptedMsg[aead.NonceSize():]
  452. // Decrypt the message and check it wasn't tampered with.
  453. message, err = aead.Open(nil, nonce, ciphertext, nil)
  454. if err != nil {
  455. return
  456. }
  457. return
  458. }
  459. func Chacha20poly1305Encrypt(key []byte, message []byte) (encryptedMsg []byte, err error) {
  460. var aead cipher.AEAD
  461. aead, err = chacha20poly1305.NewX(key)
  462. if err != nil {
  463. fmt.Printf("failed to create aead from key")
  464. return
  465. }
  466. // Select a random nonce, and leave capacity for the ciphertext.
  467. nonce := make([]byte, aead.NonceSize(), aead.NonceSize()+len(message)+aead.Overhead())
  468. if _, err = rand.Read(nonce); err != nil {
  469. return
  470. }
  471. // Encrypt the message and append the ciphertext to the nonce.
  472. encryptedMsg = aead.Seal(nonce, nonce, message, nil)
  473. return
  474. }
  475. type WireguardLink struct {
  476. netlink.LinkAttrs
  477. }
  478. func (generic *WireguardLink) Attrs() *netlink.LinkAttrs {
  479. return &generic.LinkAttrs
  480. }
  481. func (generic *WireguardLink) Type() string {
  482. return "wireguard"
  483. }
  484. func MakeWireguardInterface() (managerPrivate crypto.PrivateKey, managerPublic crypto.PrivateKey, outgoingLink netlink.Link, wireguardLink netlink.Link, err error) {
  485. la := netlink.NewLinkAttrs()
  486. la.Name = InterfaceName
  487. linkExists := true
  488. link, err := netlink.LinkByName(la.Name)
  489. if err != nil || link == nil {
  490. linkExists = false
  491. }
  492. var outgoingInterfaceName string
  493. managerPrivate, managerPublic, err = GenerateKeyPair()
  494. if !linkExists {
  495. fmt.Printf("interface %s does not exist. Creating\n", la.Name)
  496. link := WireguardLink{la}
  497. err = netlink.LinkAdd(&link)
  498. if err != nil {
  499. fmt.Printf("Failed to add link for %s, err: %s\n", la.Name, err)
  500. return nil, nil, nil, nil, err
  501. }
  502. outgoingInterfaceName = ResolveOutgoingInterface()
  503. wireguardLink, err = netlink.LinkByName(la.Name)
  504. c, _ := wgctrl.New()
  505. defer c.Close()
  506. if _, err = os.Stat(ConfFilename); err == nil {
  507. _, err = ConfigureWgFromSavedConfig()
  508. if err != nil {
  509. fmt.Println("Error loading with saved config, err:", err.Error())
  510. }
  511. var device *wgtypes.Device
  512. device, err = c.Device(InterfaceName)
  513. if err != nil {
  514. fmt.Printf("Error finding device at %s, err:%s", InterfaceName, err)
  515. }
  516. privateBytes, _ := base64.StdEncoding.DecodeString(device.PrivateKey.String())
  517. publicBytes, _ := base64.StdEncoding.DecodeString(device.PrivateKey.PublicKey().String())
  518. managerPrivate = privateBytes
  519. managerPublic = publicBytes
  520. } else {
  521. //////////////////////////////////////////////////////////////////////
  522. //configure wireguard interface with specific details
  523. //////////////////////////////////////////////////////////////////////
  524. var privateBytes [32]byte
  525. if ok := CheckType(&privateBytes, managerPrivate); !ok {
  526. panic("ecdh: unexpected type of manager public key")
  527. }
  528. var key wgtypes.Key
  529. key, err = wgtypes.NewKey(privateBytes[:])
  530. port := WireguardListenPort
  531. cfg := wgtypes.Config{
  532. PrivateKey: &key,
  533. ListenPort: &port,
  534. ReplacePeers: true,
  535. }
  536. if err != nil {
  537. return
  538. }
  539. err = c.ConfigureDevice(la.Name, cfg)
  540. if err != nil {
  541. fmt.Println("wgctrlclient ConfigureDevice failed:", err.Error())
  542. return
  543. }
  544. //////////////////////////////////////////////////////////////////////
  545. //end of configure wireguard interface with manager specific details
  546. //////////////////////////////////////////////////////////////////////
  547. }
  548. //Add the ip addr
  549. managerIp, _ := GetFirstAddr()
  550. fmt.Printf("managerIP : %s\n", managerIp)
  551. err = AddLinkAddress(wireguardLink, managerIp.String(), "")
  552. if err != nil {
  553. fmt.Printf("11 Failed to set address to %s, err: %s\n", la.Name, err)
  554. return nil, nil, nil, nil, err
  555. }
  556. } else {
  557. fmt.Printf("interface %s exists\n", la.Name)
  558. outgoingInterfaceName = ResolveOutgoingInterface()
  559. wireguardLink, _ = netlink.LinkByName(la.Name)
  560. }
  561. // bring the interface up
  562. err = netlink.LinkSetUp(wireguardLink)
  563. if err != nil {
  564. fmt.Printf("Could not set the link up: %s\n", err)
  565. return nil, nil, nil, nil, err
  566. }
  567. // add a default route for the wireguard interface
  568. foundRoute := false
  569. _, dst, _ := net.ParseCIDR(IPRange)
  570. route := netlink.Route{LinkIndex: wireguardLink.Attrs().Index, Dst: dst, Scope: netlink.SCOPE_LINK}
  571. //fmt.Printf("to be added route= %s\n", route.String())
  572. routes, err := netlink.RouteList(link, netlink.FAMILY_V4)
  573. for _, r := range routes {
  574. //fmt.Printf("existing route= %s\n", r.String())
  575. if CompareRoutes(route, r) {
  576. foundRoute = true
  577. break
  578. }
  579. }
  580. if !foundRoute {
  581. if err := netlink.RouteAdd(&route); err != nil {
  582. fmt.Printf("Failed to add a default route. err=%s\n", err)
  583. return nil, nil, nil, nil, err
  584. }
  585. }
  586. outgoingLink, _ = netlink.LinkByName(outgoingInterfaceName)
  587. return
  588. }
  589. // ipNetEqual returns true iff both IPNet are equal
  590. func ipNetEqual(ipn1 *net.IPNet, ipn2 *net.IPNet) bool {
  591. if ipn1 == ipn2 {
  592. return true
  593. }
  594. if ipn1 == nil || ipn2 == nil {
  595. return false
  596. }
  597. m1, _ := ipn1.Mask.Size()
  598. m2, _ := ipn2.Mask.Size()
  599. return m1 == m2 && ipn1.IP.Equal(ipn2.IP)
  600. }
  601. func printRoute(r netlink.Route) {
  602. fmt.Printf("LinkIndex %d\n", r.LinkIndex)
  603. fmt.Printf("ILinkIndex %d\n", r.ILinkIndex)
  604. fmt.Printf("Scope %s\n", r.Scope)
  605. fmt.Printf("Dst %s\n", r.Dst)
  606. fmt.Printf("Src %s\n", r.Src)
  607. fmt.Printf("Gw %s\n", r.Gw)
  608. fmt.Printf("MultiPath %s\n", r.MultiPath)
  609. fmt.Printf("Protocol %d\n", r.Protocol)
  610. fmt.Printf("Priority %d\n", r.Priority)
  611. fmt.Printf("Table %d\n", r.Table)
  612. fmt.Printf("Type %d\n", r.Type)
  613. fmt.Printf("Tos %d\n", r.Tos)
  614. fmt.Printf("Flags %d\n", r.Flags)
  615. fmt.Printf("MPLSDst %d\n", r.MPLSDst)
  616. fmt.Printf("NewDst %s\n", r.NewDst)
  617. fmt.Printf("Encap %s\n", r.Encap)
  618. fmt.Printf("MTU %d\n", r.MTU)
  619. fmt.Printf("AdvMSS %d\n", r.AdvMSS)
  620. fmt.Printf("Hoplimit %d\n", r.Hoplimit)
  621. }
  622. // Copied from netlink.Route.Equal and removed checks
  623. // nexthopInfoSlice(r.MultiPath).Equal(x.MultiPath) &&
  624. // r.Table == x.Table
  625. func CompareRoutes(r netlink.Route, x netlink.Route) bool {
  626. return r.LinkIndex == x.LinkIndex &&
  627. r.ILinkIndex == x.ILinkIndex &&
  628. r.Scope == x.Scope &&
  629. ipNetEqual(r.Dst, x.Dst) &&
  630. r.Src.Equal(x.Src) &&
  631. r.Gw.Equal(x.Gw) &&
  632. r.Priority == x.Priority &&
  633. r.Tos == x.Tos &&
  634. r.Hoplimit == x.Hoplimit &&
  635. r.Flags == x.Flags &&
  636. (r.MPLSDst == x.MPLSDst || (r.MPLSDst != nil && x.MPLSDst != nil && *r.MPLSDst == *x.MPLSDst)) &&
  637. (r.NewDst == x.NewDst || (r.NewDst != nil && r.NewDst.Equal(x.NewDst))) &&
  638. (r.Encap == x.Encap || (r.Encap != nil && r.Encap.Equal(x.Encap)))
  639. }
  640. // CreateWireguardInterface TODO: use this in the function MakeWireguardInterface
  641. func CreateWireguardInterface() (outgoingLink netlink.Link, wireguardLink netlink.Link, err error) {
  642. la := netlink.NewLinkAttrs()
  643. la.Name = InterfaceName
  644. linkExists := true
  645. link, err := netlink.LinkByName(la.Name)
  646. if err != nil || link == nil {
  647. linkExists = false
  648. }
  649. var outgoingInterfaceName string
  650. if !linkExists {
  651. fmt.Printf("interface %s does not exist. Creating\n", la.Name)
  652. link := WireguardLink{la}
  653. err = netlink.LinkAdd(&link)
  654. if err != nil {
  655. fmt.Printf("Failed to add link for %s, err: %s\n", la.Name, err)
  656. return
  657. }
  658. } else {
  659. fmt.Printf("interface %s exists\n", la.Name)
  660. }
  661. outgoingInterfaceName = ResolveOutgoingInterface()
  662. wireguardLink, err = netlink.LinkByName(la.Name)
  663. if err != nil {
  664. fmt.Printf("Could not find the link %S, err: %s\n", la.Name, err)
  665. }
  666. outgoingLink, _ = netlink.LinkByName(outgoingInterfaceName)
  667. routes, _ := netlink.RouteList(link, netlink.FAMILY_V4)
  668. for _, r := range routes {
  669. fmt.Printf("reote= %s\n", r.String())
  670. }
  671. // bring the interface up
  672. err = netlink.LinkSetUp(wireguardLink)
  673. if err != nil {
  674. fmt.Printf("Could not set the link up: %s\n", err)
  675. }
  676. // add a default route for the wireguard interface
  677. _, dst, _ := net.ParseCIDR(IPRange)
  678. route := netlink.Route{LinkIndex: wireguardLink.Attrs().Index, Dst: dst, Scope: netlink.SCOPE_LINK}
  679. if err1 := netlink.RouteAdd(&route); err1 != nil {
  680. fmt.Printf("Failed to add a default route. err=%s\n", err1)
  681. return
  682. }
  683. return
  684. }
  685. func ProvisionWireguardInterface(privateKey crypto.PrivateKey, ipaddrAndMask string) (outgoingLink netlink.Link, wireguardLink netlink.Link, err error) {
  686. outgoingLink, wireguardLink, err = CreateWireguardInterface()
  687. if err != nil {
  688. fmt.Printf("Failed to add link for %s, err: %s\n", InterfaceName, err)
  689. return
  690. }
  691. err = AddLinkAddress(wireguardLink, ipaddrAndMask, "")
  692. if err != nil {
  693. fmt.Printf("12 Failed to set address to %s, err: %s\n", InterfaceName, err)
  694. return nil, nil, err
  695. }
  696. var privateBytes [32]byte
  697. if ok := CheckType(&privateBytes, privateKey); !ok {
  698. panic("ecdh: unexpected type of manager public key")
  699. }
  700. key, _ := wgtypes.NewKey(privateBytes[:])
  701. port := WireguardListenPort
  702. cfg := wgtypes.Config{PrivateKey: &key, ListenPort: &port, ReplacePeers: true}
  703. var wgctrlClient *wgctrl.Client
  704. wgctrlClient, err = wgctrl.New()
  705. if err != nil {
  706. return
  707. }
  708. err = wgctrlClient.ConfigureDevice(InterfaceName, cfg)
  709. if err != nil {
  710. fmt.Printf("wgctrlclient ConfigureDevice failed, err=%s\n", err)
  711. return
  712. }
  713. // bring the interface up
  714. err = netlink.LinkSetUp(wireguardLink)
  715. if err != nil {
  716. fmt.Printf("Could not set the link up: %s\n", err)
  717. return nil, nil, err
  718. }
  719. return
  720. }
  721. func AddPeer(presharedKey string, edgePublicStr string, endpointIp string, allowedIp string, replace bool) {
  722. ipNet, err := ConstructIPAndMask(allowedIp)
  723. if err != nil {
  724. fmt.Println("ParseCIDR() failed:", err.Error())
  725. return
  726. }
  727. var allowedIps []net.IPNet
  728. allowedIps = append(allowedIps, *ipNet)
  729. presharedKeyBytes, _ := base64.StdEncoding.DecodeString(presharedKey)
  730. edgePublicBytes, _ := base64.StdEncoding.DecodeString(edgePublicStr)
  731. key, _ := wgtypes.NewKey(edgePublicBytes)
  732. pk, _ := wgtypes.NewKey(presharedKeyBytes)
  733. var dur = time.Second * PersistentKeepAliveInterval
  734. peerConfig := wgtypes.PeerConfig{
  735. PublicKey: key,
  736. PresharedKey: &pk,
  737. ReplaceAllowedIPs: true,
  738. PersistentKeepaliveInterval: &dur,
  739. AllowedIPs: allowedIps,
  740. }
  741. var ipEndpoint net.IP
  742. if endpointIp != "" {
  743. ipEndpoint = net.ParseIP(endpointIp)
  744. peerConfig.Endpoint = &net.UDPAddr{
  745. IP: ipEndpoint,
  746. Port: WireguardListenPort,
  747. }
  748. }
  749. var peers []wgtypes.PeerConfig
  750. peers = append(peers, peerConfig)
  751. cfg := wgtypes.Config{
  752. ReplacePeers: replace,
  753. Peers: peers,
  754. }
  755. wgctrlClient, err := wgctrl.New()
  756. if err != nil {
  757. fmt.Println("wgctrl.New() failed:", err.Error())
  758. return
  759. }
  760. err = wgctrlClient.ConfigureDevice(InterfaceName, cfg)
  761. if err != nil {
  762. fmt.Println("wgctrlclient ConfigureDevice failed:", err.Error())
  763. return
  764. }
  765. }
  766. func HasLinkAddress(link netlink.Link, address string) (exists bool, err error) {
  767. exists = false
  768. fmt.Printf("AddLinkAddress(linkAddressAndMask=%s)\n", address)
  769. var netmask *net.IPNet
  770. netmask, _ = ConstructIPAndMask(address)
  771. var addr = &netlink.Addr{IPNet: netmask}
  772. var addrs []netlink.Addr
  773. addrs, err = netlink.AddrList(link, netlink.FAMILY_V4)
  774. if err != nil {
  775. fmt.Printf("14 Failed to get list of addresses from %s, err: %s\n", InterfaceName, err)
  776. return
  777. }
  778. for _, a := range addrs {
  779. if addr.Equal(a) {
  780. exists = true
  781. return
  782. }
  783. }
  784. return
  785. }
  786. func AddLinkAddress(link netlink.Link, linkAddressAndMask string, peerAddress string) (err error) {
  787. //////////////////////////////////////////////////////////////////////
  788. //configure wireguard interface with manager specific details
  789. //////////////////////////////////////////////////////////////////////
  790. fmt.Printf("AddLinkAddress(linkAddressAndMask=%s)\n", linkAddressAndMask)
  791. var netmask *net.IPNet
  792. netmask, _ = ConstructIPAndMask(linkAddressAndMask)
  793. var addr = &netlink.Addr{IPNet: netmask}
  794. if peerAddress != "" {
  795. fmt.Printf("peerAddress %s\n", peerAddress)
  796. peerAddr, _ := ConstructIPAndMask(peerAddress)
  797. fmt.Printf("peerAddr %s\n", peerAddr)
  798. addr.Peer = peerAddr
  799. }
  800. err = netlink.AddrAdd(link, addr)
  801. if err != nil {
  802. fmt.Printf("14 Failed to set address to %s, err: %s\n", InterfaceName, err)
  803. return
  804. }
  805. return
  806. }
  807. func ConstructIPAndMask(inputStr string) (netipnet *net.IPNet, err error) {
  808. fmt.Printf("inputStr=%s\n", inputStr)
  809. var netipAddr net.IP
  810. netipAddr, netipnet, err = net.ParseCIDR(inputStr)
  811. if err != nil {
  812. netipAddr = net.ParseIP(inputStr)
  813. if netipAddr == nil {
  814. fmt.Printf("net.ParseIP(inputStr=%s) failed, err: %s\n", inputStr, err)
  815. return
  816. } else {
  817. netipAddr, netipnet, _ = net.ParseCIDR(inputStr + "/32")
  818. err = nil
  819. }
  820. }
  821. netipnet.IP = netipAddr
  822. return
  823. }
  824. func FindPeerByPublicKey(pubkey string) (*wgtypes.Peer, error) {
  825. peerKeyMap, err := GetPeersByPublicKey()
  826. if err != nil {
  827. return nil, err
  828. }
  829. peer1, found := peerKeyMap[pubkey]
  830. if found {
  831. return &peer1, nil
  832. }
  833. return nil, errors.New("peer not found")
  834. }
  835. func GetPeersByPublicKey() (peerKeyMap map[string]wgtypes.Peer, err error) {
  836. wgctrlClient, err := wgctrl.New()
  837. if err != nil {
  838. fmt.Printf("Failed to find the default wireguard device %s, err: %s\n", InterfaceName, err)
  839. return nil, err
  840. }
  841. defaultDevice, err := wgctrlClient.Device(InterfaceName)
  842. if err != nil {
  843. fmt.Printf("Failed to find the default wireguard device %s, err: %s\n", InterfaceName, err)
  844. return nil, err
  845. }
  846. devices, err := wgctrlClient.Devices()
  847. if err != nil {
  848. return
  849. }
  850. for _, device := range devices {
  851. fmt.Printf("device = %s\n", device.Name)
  852. fmt.Printf("Peers(%d):\n", len(device.Peers))
  853. for _, p := range device.Peers {
  854. fmt.Printf(" public=%s, psk=%s\n", p.PublicKey, p.PresharedKey)
  855. fmt.Printf(" ips=%s, endpoint=%s\n", p.AllowedIPs, p.Endpoint)
  856. }
  857. }
  858. peerKeyMap = make(map[string]wgtypes.Peer)
  859. var peers = defaultDevice.Peers
  860. for _, p := range peers {
  861. peerPubkey := base64.StdEncoding.EncodeToString(p.PublicKey[:])
  862. peerKeyMap[peerPubkey] = p
  863. }
  864. return
  865. }
  866. func GetPeersByIp() (peerIpMap map[netip.Addr]wgtypes.Peer, err error) {
  867. wgctrlClient, err := wgctrl.New()
  868. if err != nil {
  869. fmt.Printf("Failed to find the default wireguard device %s, err: %s\n", InterfaceName, err)
  870. return nil, err
  871. }
  872. defaultDevice, err := wgctrlClient.Device(InterfaceName)
  873. if err != nil {
  874. fmt.Printf("Failed to find the default wireguard device %s, err: %s\n", InterfaceName, err)
  875. return nil, err
  876. }
  877. devices, err := wgctrlClient.Devices()
  878. if err != nil {
  879. return
  880. }
  881. for _, device := range devices {
  882. fmt.Printf("device = %s\n", device.Name)
  883. fmt.Printf("Peers(%d):\n", len(device.Peers))
  884. for _, p := range device.Peers {
  885. fmt.Printf(" public=%s, psk=%s\n", p.PublicKey, p.PresharedKey)
  886. fmt.Printf(" ips=%s, endpoint=%s\n", p.AllowedIPs, p.Endpoint)
  887. }
  888. }
  889. peerIpMap = make(map[netip.Addr]wgtypes.Peer)
  890. var peers = defaultDevice.Peers
  891. for _, p := range peers {
  892. allowedIp, err := netip.ParseAddr(p.AllowedIPs[0].IP.String())
  893. if err != nil {
  894. fmt.Printf("failed to parse %s\n", p.AllowedIPs[0].IP.String())
  895. }
  896. peerIpMap[allowedIp] = p
  897. }
  898. return
  899. }
  900. func GetFirstAddr() (netip.Addr, *net.IPNet) {
  901. _, iPRange, _ := net.ParseCIDR(IPRange)
  902. ipFirst, _ := netip.ParseAddr(iPRange.IP.String())
  903. return ipFirst.Next(), iPRange
  904. }
  905. func FindUnusedIp() (ipNext netip.Addr, err error) {
  906. peers, _ := GetPeersByIp()
  907. ipNext, ipRange := GetFirstAddr()
  908. /*
  909. for ip, p := range peers {
  910. fmt.Printf("ip = %s, peer=%s\n", ip, p)
  911. }
  912. */
  913. ipNext = ipNext.Next()
  914. ip := net.ParseIP(ipNext.String())
  915. for {
  916. //fmt.Printf("ipNext = %s, ip=%s\n", ipNext, ip)
  917. if _, ok := peers[ipNext]; !ok {
  918. //fmt.Printf("ipNext=%s, ip=%s is not in peers\n", ipNext, ip)
  919. if ipRange.Contains(ip) {
  920. //fmt.Printf("ipRange container ip=%s\n", ip)
  921. return ipNext, nil
  922. } else {
  923. //fmt.Printf("ipRange does not contain ip=%s\n", ip)
  924. break
  925. }
  926. } else {
  927. //fmt.Printf("ipNext=%s, ip=%s is in peers\n", ipNext, ip)
  928. }
  929. ipNext = ipNext.Next()
  930. ip = net.ParseIP(ipNext.String())
  931. }
  932. return ipNext, errors.New("cant find unused ip address")
  933. }
  934. type ClientConfig struct {
  935. Signature string
  936. InitialPort int
  937. ManagerIP string
  938. }
  939. func ReadClientConfig(path string) (conf ClientConfig, meta toml.MetaData, err error) {
  940. meta, err = toml.DecodeFile(path, &conf)
  941. if err != nil {
  942. fmt.Fprintln(os.Stderr, err)
  943. }
  944. return
  945. }
  946. func VerifyVmInstanceId(instanceId string) (valid bool) {
  947. //TODO:check the instanceId with the list of instances we created
  948. //For now return true
  949. if len(instanceId) > 2 {
  950. return true
  951. } else {
  952. return false
  953. }
  954. }
  955. func StoreWgConfig() (err error) {
  956. //////////////////////////////////////////////
  957. //Write everything we have, to a file
  958. currentConf, err := GetCurrentConfWithSecrets()
  959. if err != nil {
  960. log.Printf("error getting config, %s", err)
  961. return
  962. }
  963. err = os.WriteFile(ConfFilename, []byte(currentConf), 0600)
  964. if err != nil {
  965. log.Printf("Write file for config key failed, %s", err)
  966. }
  967. return
  968. }
  969. func StoreWgConfigUsingWgCommand() (err error) {
  970. /// get conf using /usr/bin/wg
  971. cmd := exec.Command("/usr/bin/wg", "showconf", InterfaceName)
  972. var stdout, stderr bytes.Buffer
  973. cmd.Stdout = &stdout
  974. cmd.Stderr = &stderr
  975. err = cmd.Run()
  976. if err != nil {
  977. fmt.Printf("Failed to run wg showconf wg3: %s\n", err)
  978. fmt.Printf("stderr %s\n", stderr.String())
  979. os.Exit(ExitSetupFailed)
  980. return
  981. }
  982. err = os.WriteFile(ConfFilename, stdout.Bytes(), 0600)
  983. return
  984. }
  985. func ConfigureWgFromSavedConfig() (specialContent []string, err error) {
  986. /// set conf using /usr/bin/wg
  987. confBytes, err := os.ReadFile(ConfFilename)
  988. if err != nil {
  989. fmt.Printf("Error reading file %s, err: %s\n", ConfFilename, err)
  990. return
  991. }
  992. var re = regexp.MustCompile(`(?m)^\s*Address\s*=\s*(?:\d{1,3}\.){3}\d{1,3}\s*$`)
  993. confStr := string(confBytes)
  994. var removedLines []string
  995. removedLines = append(removedLines, re.FindAllString(confStr, -1)...)
  996. fmt.Printf("removedLines = %s\n", removedLines)
  997. for _, s := range removedLines {
  998. specialContent = append(specialContent, strings.ReplaceAll(s, " ", ""))
  999. }
  1000. confStr1 := re.ReplaceAllString(confStr, "")
  1001. fmt.Printf("confStr1 = %s\n", confStr1)
  1002. tmpFileName := ConfFilename + ".tmp"
  1003. fmt.Printf("tmpFileName = %s\n", tmpFileName)
  1004. err1 := os.WriteFile(tmpFileName, []byte(confStr1), 0660)
  1005. if err1 != nil {
  1006. fmt.Printf("Error writing curated file %s, err: %s\n", tmpFileName, err)
  1007. return
  1008. }
  1009. cmd := exec.Command("/usr/bin/wg", "setconf", InterfaceName, tmpFileName)
  1010. var stdout, stderr bytes.Buffer
  1011. cmd.Stdout = &stdout
  1012. cmd.Stderr = &stderr
  1013. fmt.Printf("running %s\n", cmd)
  1014. err = cmd.Run()
  1015. _ = os.Remove(tmpFileName)
  1016. if err != nil {
  1017. fmt.Printf("Failed to run %s, err: %s\n", cmd, err)
  1018. fmt.Printf("stderr %s\n", stderr.String())
  1019. os.Exit(ExitSetupFailed)
  1020. }
  1021. return
  1022. }