plugin_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. // Copyright 2018-present the CoreDHCP Authors. All rights reserved
  2. // This source code is licensed under the MIT license found in the
  3. // LICENSE file in the root directory of this source tree.
  4. package file
  5. import (
  6. "net"
  7. "os"
  8. "testing"
  9. "time"
  10. "github.com/insomniacslk/dhcp/dhcpv4"
  11. "github.com/insomniacslk/dhcp/dhcpv6"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  15. func TestLoadDHCPv4Records(t *testing.T) {
  16. t.Run("valid leases", func(t *testing.T) {
  17. // setup temp leases file
  18. tmp, err := os.CreateTemp("", "test_plugin_file")
  19. require.NoError(t, err)
  20. defer func() {
  21. tmp.Close()
  22. os.Remove(tmp.Name())
  23. }()
  24. // fill temp file with valid lease lines and some comments
  25. _, err = tmp.WriteString("00:11:22:33:44:55 192.0.2.100\n")
  26. require.NoError(t, err)
  27. _, err = tmp.WriteString("11:22:33:44:55:66 192.0.2.101\n")
  28. require.NoError(t, err)
  29. _, err = tmp.WriteString("# this is a comment\n")
  30. require.NoError(t, err)
  31. records, err := LoadDHCPv4Records(tmp.Name())
  32. if !assert.NoError(t, err) {
  33. return
  34. }
  35. if assert.Equal(t, 2, len(records)) {
  36. if assert.Contains(t, records, "00:11:22:33:44:55") {
  37. assert.Equal(t, net.ParseIP("192.0.2.100"), records["00:11:22:33:44:55"])
  38. }
  39. if assert.Contains(t, records, "11:22:33:44:55:66") {
  40. assert.Equal(t, net.ParseIP("192.0.2.101"), records["11:22:33:44:55:66"])
  41. }
  42. }
  43. })
  44. t.Run("missing field", func(t *testing.T) {
  45. // setup temp leases file
  46. tmp, err := os.CreateTemp("", "test_plugin_file")
  47. require.NoError(t, err)
  48. defer func() {
  49. tmp.Close()
  50. os.Remove(tmp.Name())
  51. }()
  52. // add line with too few fields
  53. _, err = tmp.WriteString("foo\n")
  54. require.NoError(t, err)
  55. _, err = LoadDHCPv4Records(tmp.Name())
  56. assert.Error(t, err)
  57. })
  58. t.Run("invalid MAC", func(t *testing.T) {
  59. // setup temp leases file
  60. tmp, err := os.CreateTemp("", "test_plugin_file")
  61. require.NoError(t, err)
  62. defer func() {
  63. tmp.Close()
  64. os.Remove(tmp.Name())
  65. }()
  66. // add line with invalid MAC address to trigger an error
  67. _, err = tmp.WriteString("abcd 192.0.2.102\n")
  68. require.NoError(t, err)
  69. _, err = LoadDHCPv4Records(tmp.Name())
  70. assert.Error(t, err)
  71. })
  72. t.Run("invalid IP address", func(t *testing.T) {
  73. // setup temp leases file
  74. tmp, err := os.CreateTemp("", "test_plugin_file")
  75. require.NoError(t, err)
  76. defer func() {
  77. tmp.Close()
  78. os.Remove(tmp.Name())
  79. }()
  80. // add line with invalid MAC address to trigger an error
  81. _, err = tmp.WriteString("22:33:44:55:66:77 bcde\n")
  82. require.NoError(t, err)
  83. _, err = LoadDHCPv4Records(tmp.Name())
  84. assert.Error(t, err)
  85. })
  86. t.Run("lease with IPv6 address", func(t *testing.T) {
  87. // setup temp leases file
  88. tmp, err := os.CreateTemp("", "test_plugin_file")
  89. require.NoError(t, err)
  90. defer func() {
  91. tmp.Close()
  92. os.Remove(tmp.Name())
  93. }()
  94. // add line with IPv6 address instead to trigger an error
  95. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  96. require.NoError(t, err)
  97. _, err = LoadDHCPv4Records(tmp.Name())
  98. assert.Error(t, err)
  99. })
  100. }
  101. func TestLoadDHCPv6Records(t *testing.T) {
  102. t.Run("valid leases", func(t *testing.T) {
  103. // setup temp leases file
  104. tmp, err := os.CreateTemp("", "test_plugin_file")
  105. require.NoError(t, err)
  106. defer func() {
  107. tmp.Close()
  108. os.Remove(tmp.Name())
  109. }()
  110. // fill temp file with valid lease lines and some comments
  111. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  112. require.NoError(t, err)
  113. _, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
  114. require.NoError(t, err)
  115. _, err = tmp.WriteString("# this is a comment\n")
  116. require.NoError(t, err)
  117. records, err := LoadDHCPv6Records(tmp.Name())
  118. if !assert.NoError(t, err) {
  119. return
  120. }
  121. if assert.Equal(t, 2, len(records)) {
  122. if assert.Contains(t, records, "00:11:22:33:44:55") {
  123. assert.Equal(t, net.ParseIP("2001:db8::10:1"), records["00:11:22:33:44:55"])
  124. }
  125. if assert.Contains(t, records, "11:22:33:44:55:66") {
  126. assert.Equal(t, net.ParseIP("2001:db8::10:2"), records["11:22:33:44:55:66"])
  127. }
  128. }
  129. })
  130. t.Run("missing field", func(t *testing.T) {
  131. // setup temp leases file
  132. tmp, err := os.CreateTemp("", "test_plugin_file")
  133. require.NoError(t, err)
  134. defer func() {
  135. tmp.Close()
  136. os.Remove(tmp.Name())
  137. }()
  138. // add line with too few fields
  139. _, err = tmp.WriteString("foo\n")
  140. require.NoError(t, err)
  141. _, err = LoadDHCPv6Records(tmp.Name())
  142. assert.Error(t, err)
  143. })
  144. t.Run("invalid MAC", func(t *testing.T) {
  145. // setup temp leases file
  146. tmp, err := os.CreateTemp("", "test_plugin_file")
  147. require.NoError(t, err)
  148. defer func() {
  149. tmp.Close()
  150. os.Remove(tmp.Name())
  151. }()
  152. // add line with invalid MAC address to trigger an error
  153. _, err = tmp.WriteString("abcd 2001:db8::10:3\n")
  154. require.NoError(t, err)
  155. _, err = LoadDHCPv6Records(tmp.Name())
  156. assert.Error(t, err)
  157. })
  158. t.Run("invalid IP address", func(t *testing.T) {
  159. // setup temp leases file
  160. tmp, err := os.CreateTemp("", "test_plugin_file")
  161. require.NoError(t, err)
  162. defer func() {
  163. tmp.Close()
  164. os.Remove(tmp.Name())
  165. }()
  166. // add line with invalid MAC address to trigger an error
  167. _, err = tmp.WriteString("22:33:44:55:66:77 bcde\n")
  168. require.NoError(t, err)
  169. _, err = LoadDHCPv6Records(tmp.Name())
  170. assert.Error(t, err)
  171. })
  172. t.Run("lease with IPv4 address", func(t *testing.T) {
  173. // setup temp leases file
  174. tmp, err := os.CreateTemp("", "test_plugin_file")
  175. require.NoError(t, err)
  176. defer func() {
  177. tmp.Close()
  178. os.Remove(tmp.Name())
  179. }()
  180. // add line with IPv4 address instead to trigger an error
  181. _, err = tmp.WriteString("00:11:22:33:44:55 192.0.2.100\n")
  182. require.NoError(t, err)
  183. _, err = LoadDHCPv6Records(tmp.Name())
  184. assert.Error(t, err)
  185. })
  186. }
  187. func TestHandler4(t *testing.T) {
  188. t.Run("unknown MAC", func(t *testing.T) {
  189. // prepare DHCPv4 request
  190. mac := "00:11:22:33:44:55"
  191. claddr, _ := net.ParseMAC(mac)
  192. req := &dhcpv4.DHCPv4{
  193. ClientHWAddr: claddr,
  194. }
  195. resp := &dhcpv4.DHCPv4{}
  196. assert.Nil(t, resp.ClientIPAddr)
  197. // if we handle this DHCP request, nothing should change since the lease is
  198. // unknown
  199. result, stop := Handler4(req, resp)
  200. assert.Same(t, result, resp)
  201. assert.False(t, stop)
  202. assert.Nil(t, result.YourIPAddr)
  203. })
  204. t.Run("known MAC", func(t *testing.T) {
  205. // prepare DHCPv4 request
  206. mac := "00:11:22:33:44:55"
  207. claddr, _ := net.ParseMAC(mac)
  208. req := &dhcpv4.DHCPv4{
  209. ClientHWAddr: claddr,
  210. }
  211. resp := &dhcpv4.DHCPv4{}
  212. assert.Nil(t, resp.ClientIPAddr)
  213. // add lease for the MAC in the lease map
  214. clIPAddr := net.ParseIP("192.0.2.100")
  215. StaticRecords = map[string]net.IP{
  216. mac: clIPAddr,
  217. }
  218. // if we handle this DHCP request, the YourIPAddr field should be set
  219. // in the result
  220. result, stop := Handler4(req, resp)
  221. assert.Same(t, result, resp)
  222. assert.True(t, stop)
  223. assert.Equal(t, clIPAddr, result.YourIPAddr)
  224. // cleanup
  225. StaticRecords = make(map[string]net.IP)
  226. })
  227. }
  228. func TestHandler6(t *testing.T) {
  229. t.Run("unknown MAC", func(t *testing.T) {
  230. // prepare DHCPv6 request
  231. mac := "11:22:33:44:55:66"
  232. claddr, _ := net.ParseMAC(mac)
  233. req, err := dhcpv6.NewSolicit(claddr)
  234. require.NoError(t, err)
  235. resp, err := dhcpv6.NewAdvertiseFromSolicit(req)
  236. require.NoError(t, err)
  237. assert.Equal(t, 0, len(resp.GetOption(dhcpv6.OptionIANA)))
  238. // if we handle this DHCP request, nothing should change since the lease is
  239. // unknown
  240. result, stop := Handler6(req, resp)
  241. assert.False(t, stop)
  242. assert.Equal(t, 0, len(result.GetOption(dhcpv6.OptionIANA)))
  243. })
  244. t.Run("known MAC", func(t *testing.T) {
  245. // prepare DHCPv6 request
  246. mac := "11:22:33:44:55:66"
  247. claddr, _ := net.ParseMAC(mac)
  248. req, err := dhcpv6.NewSolicit(claddr)
  249. require.NoError(t, err)
  250. resp, err := dhcpv6.NewAdvertiseFromSolicit(req)
  251. require.NoError(t, err)
  252. assert.Equal(t, 0, len(resp.GetOption(dhcpv6.OptionIANA)))
  253. // add lease for the MAC in the lease map
  254. clIPAddr := net.ParseIP("2001:db8::10:1")
  255. StaticRecords = map[string]net.IP{
  256. mac: clIPAddr,
  257. }
  258. // if we handle this DHCP request, there should be a specific IANA option
  259. // set in the resulting response
  260. result, stop := Handler6(req, resp)
  261. assert.False(t, stop)
  262. if assert.Equal(t, 1, len(result.GetOption(dhcpv6.OptionIANA))) {
  263. opt := result.GetOneOption(dhcpv6.OptionIANA)
  264. assert.Contains(t, opt.String(), "IP=2001:db8::10:1")
  265. }
  266. // cleanup
  267. StaticRecords = make(map[string]net.IP)
  268. })
  269. }
  270. func TestSetupFile(t *testing.T) {
  271. // too few arguments
  272. _, _, err := setupFile(false)
  273. assert.Error(t, err)
  274. // empty file name
  275. _, _, err = setupFile(false, "")
  276. assert.Error(t, err)
  277. // trigger error in LoadDHCPv*Records
  278. _, _, err = setupFile(false, "/foo/bar")
  279. assert.Error(t, err)
  280. _, _, err = setupFile(true, "/foo/bar")
  281. assert.Error(t, err)
  282. // setup temp leases file
  283. tmp, err := os.CreateTemp("", "test_plugin_file")
  284. require.NoError(t, err)
  285. defer func() {
  286. tmp.Close()
  287. os.Remove(tmp.Name())
  288. }()
  289. t.Run("typical case", func(t *testing.T) {
  290. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  291. require.NoError(t, err)
  292. _, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
  293. require.NoError(t, err)
  294. assert.Equal(t, 0, len(StaticRecords))
  295. // leases should show up in StaticRecords
  296. _, _, err = setupFile(true, tmp.Name())
  297. if assert.NoError(t, err) {
  298. assert.Equal(t, 2, len(StaticRecords))
  299. }
  300. })
  301. t.Run("autorefresh enabled", func(t *testing.T) {
  302. _, _, err = setupFile(true, tmp.Name(), autoRefreshArg)
  303. if assert.NoError(t, err) {
  304. assert.Equal(t, 2, len(StaticRecords))
  305. }
  306. // we add more leases to the file
  307. // this should trigger an event to refresh the leases database
  308. // without calling setupFile again
  309. _, err = tmp.WriteString("22:33:44:55:66:77 2001:db8::10:3\n")
  310. require.NoError(t, err)
  311. // since the event is processed asynchronously, give it a little time
  312. time.Sleep(time.Millisecond * 100)
  313. // an additional record should show up in the database
  314. // but we should respect the locking first
  315. recLock.RLock()
  316. defer recLock.RUnlock()
  317. assert.Equal(t, 3, len(StaticRecords))
  318. })
  319. }