plugin_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. // Copyright 2018-present the CoreDHCP Authors. All rights reserved
  2. // This source code is licensed under the MIT license found in the
  3. // LICENSE file in the root directory of this source tree.
  4. package file
  5. import (
  6. "io/ioutil"
  7. "net"
  8. "os"
  9. "testing"
  10. "time"
  11. "github.com/insomniacslk/dhcp/dhcpv4"
  12. "github.com/insomniacslk/dhcp/dhcpv6"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestLoadDHCPv4Records(t *testing.T) {
  17. t.Run("valid leases", func(t *testing.T) {
  18. // setup temp leases file
  19. tmp, err := ioutil.TempFile("", "test_plugin_file")
  20. require.NoError(t, err)
  21. defer func() {
  22. tmp.Close()
  23. os.Remove(tmp.Name())
  24. }()
  25. // fill temp file with valid lease lines and some comments
  26. _, err = tmp.WriteString("00:11:22:33:44:55 192.0.2.100\n")
  27. require.NoError(t, err)
  28. _, err = tmp.WriteString("11:22:33:44:55:66 192.0.2.101\n")
  29. require.NoError(t, err)
  30. _, err = tmp.WriteString("# this is a comment\n")
  31. require.NoError(t, err)
  32. records, err := LoadDHCPv4Records(tmp.Name())
  33. if !assert.NoError(t, err) {
  34. return
  35. }
  36. if assert.Equal(t, 2, len(records)) {
  37. if assert.Contains(t, records, "00:11:22:33:44:55") {
  38. assert.Equal(t, net.ParseIP("192.0.2.100"), records["00:11:22:33:44:55"])
  39. }
  40. if assert.Contains(t, records, "11:22:33:44:55:66") {
  41. assert.Equal(t, net.ParseIP("192.0.2.101"), records["11:22:33:44:55:66"])
  42. }
  43. }
  44. })
  45. t.Run("missing field", func(t *testing.T) {
  46. // setup temp leases file
  47. tmp, err := ioutil.TempFile("", "test_plugin_file")
  48. require.NoError(t, err)
  49. defer func() {
  50. tmp.Close()
  51. os.Remove(tmp.Name())
  52. }()
  53. // add line with too few fields
  54. _, err = tmp.WriteString("foo\n")
  55. require.NoError(t, err)
  56. _, err = LoadDHCPv4Records(tmp.Name())
  57. assert.Error(t, err)
  58. })
  59. t.Run("invalid MAC", func(t *testing.T) {
  60. // setup temp leases file
  61. tmp, err := ioutil.TempFile("", "test_plugin_file")
  62. require.NoError(t, err)
  63. defer func() {
  64. tmp.Close()
  65. os.Remove(tmp.Name())
  66. }()
  67. // add line with invalid MAC address to trigger an error
  68. _, err = tmp.WriteString("abcd 192.0.2.102\n")
  69. require.NoError(t, err)
  70. _, err = LoadDHCPv4Records(tmp.Name())
  71. assert.Error(t, err)
  72. })
  73. t.Run("invalid IP address", func(t *testing.T) {
  74. // setup temp leases file
  75. tmp, err := ioutil.TempFile("", "test_plugin_file")
  76. require.NoError(t, err)
  77. defer func() {
  78. tmp.Close()
  79. os.Remove(tmp.Name())
  80. }()
  81. // add line with invalid MAC address to trigger an error
  82. _, err = tmp.WriteString("22:33:44:55:66:77 bcde\n")
  83. require.NoError(t, err)
  84. _, err = LoadDHCPv4Records(tmp.Name())
  85. assert.Error(t, err)
  86. })
  87. t.Run("lease with IPv6 address", func(t *testing.T) {
  88. // setup temp leases file
  89. tmp, err := ioutil.TempFile("", "test_plugin_file")
  90. require.NoError(t, err)
  91. defer func() {
  92. tmp.Close()
  93. os.Remove(tmp.Name())
  94. }()
  95. // add line with IPv6 address instead to trigger an error
  96. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  97. require.NoError(t, err)
  98. _, err = LoadDHCPv4Records(tmp.Name())
  99. assert.Error(t, err)
  100. })
  101. }
  102. func TestLoadDHCPv6Records(t *testing.T) {
  103. t.Run("valid leases", func(t *testing.T) {
  104. // setup temp leases file
  105. tmp, err := ioutil.TempFile("", "test_plugin_file")
  106. require.NoError(t, err)
  107. defer func() {
  108. tmp.Close()
  109. os.Remove(tmp.Name())
  110. }()
  111. // fill temp file with valid lease lines and some comments
  112. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  113. require.NoError(t, err)
  114. _, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
  115. require.NoError(t, err)
  116. _, err = tmp.WriteString("# this is a comment\n")
  117. require.NoError(t, err)
  118. records, err := LoadDHCPv6Records(tmp.Name())
  119. if !assert.NoError(t, err) {
  120. return
  121. }
  122. if assert.Equal(t, 2, len(records)) {
  123. if assert.Contains(t, records, "00:11:22:33:44:55") {
  124. assert.Equal(t, net.ParseIP("2001:db8::10:1"), records["00:11:22:33:44:55"])
  125. }
  126. if assert.Contains(t, records, "11:22:33:44:55:66") {
  127. assert.Equal(t, net.ParseIP("2001:db8::10:2"), records["11:22:33:44:55:66"])
  128. }
  129. }
  130. })
  131. t.Run("missing field", func(t *testing.T) {
  132. // setup temp leases file
  133. tmp, err := ioutil.TempFile("", "test_plugin_file")
  134. require.NoError(t, err)
  135. defer func() {
  136. tmp.Close()
  137. os.Remove(tmp.Name())
  138. }()
  139. // add line with too few fields
  140. _, err = tmp.WriteString("foo\n")
  141. require.NoError(t, err)
  142. _, err = LoadDHCPv6Records(tmp.Name())
  143. assert.Error(t, err)
  144. })
  145. t.Run("invalid MAC", func(t *testing.T) {
  146. // setup temp leases file
  147. tmp, err := ioutil.TempFile("", "test_plugin_file")
  148. require.NoError(t, err)
  149. defer func() {
  150. tmp.Close()
  151. os.Remove(tmp.Name())
  152. }()
  153. // add line with invalid MAC address to trigger an error
  154. _, err = tmp.WriteString("abcd 2001:db8::10:3\n")
  155. require.NoError(t, err)
  156. _, err = LoadDHCPv6Records(tmp.Name())
  157. assert.Error(t, err)
  158. })
  159. t.Run("invalid IP address", func(t *testing.T) {
  160. // setup temp leases file
  161. tmp, err := ioutil.TempFile("", "test_plugin_file")
  162. require.NoError(t, err)
  163. defer func() {
  164. tmp.Close()
  165. os.Remove(tmp.Name())
  166. }()
  167. // add line with invalid MAC address to trigger an error
  168. _, err = tmp.WriteString("22:33:44:55:66:77 bcde\n")
  169. require.NoError(t, err)
  170. _, err = LoadDHCPv6Records(tmp.Name())
  171. assert.Error(t, err)
  172. })
  173. t.Run("lease with IPv4 address", func(t *testing.T) {
  174. // setup temp leases file
  175. tmp, err := ioutil.TempFile("", "test_plugin_file")
  176. require.NoError(t, err)
  177. defer func() {
  178. tmp.Close()
  179. os.Remove(tmp.Name())
  180. }()
  181. // add line with IPv4 address instead to trigger an error
  182. _, err = tmp.WriteString("00:11:22:33:44:55 192.0.2.100\n")
  183. require.NoError(t, err)
  184. _, err = LoadDHCPv6Records(tmp.Name())
  185. assert.Error(t, err)
  186. })
  187. }
  188. func TestHandler4(t *testing.T) {
  189. t.Run("unknown MAC", func(t *testing.T) {
  190. // prepare DHCPv4 request
  191. mac := "00:11:22:33:44:55"
  192. claddr, _ := net.ParseMAC(mac)
  193. req := &dhcpv4.DHCPv4{
  194. ClientHWAddr: claddr,
  195. }
  196. resp := &dhcpv4.DHCPv4{}
  197. assert.Nil(t, resp.ClientIPAddr)
  198. // if we handle this DHCP request, nothing should change since the lease is
  199. // unknown
  200. result, stop := Handler4(req, resp)
  201. assert.Same(t, result, resp)
  202. assert.False(t, stop)
  203. assert.Nil(t, result.YourIPAddr)
  204. })
  205. t.Run("known MAC", func(t *testing.T) {
  206. // prepare DHCPv4 request
  207. mac := "00:11:22:33:44:55"
  208. claddr, _ := net.ParseMAC(mac)
  209. req := &dhcpv4.DHCPv4{
  210. ClientHWAddr: claddr,
  211. }
  212. resp := &dhcpv4.DHCPv4{}
  213. assert.Nil(t, resp.ClientIPAddr)
  214. // add lease for the MAC in the lease map
  215. clIPAddr := net.ParseIP("192.0.2.100")
  216. StaticRecords = map[string]net.IP{
  217. mac: clIPAddr,
  218. }
  219. // if we handle this DHCP request, the YourIPAddr field should be set
  220. // in the result
  221. result, stop := Handler4(req, resp)
  222. assert.Same(t, result, resp)
  223. assert.True(t, stop)
  224. assert.Equal(t, clIPAddr, result.YourIPAddr)
  225. // cleanup
  226. StaticRecords = make(map[string]net.IP)
  227. })
  228. }
  229. func TestHandler6(t *testing.T) {
  230. t.Run("unknown MAC", func(t *testing.T) {
  231. // prepare DHCPv6 request
  232. mac := "11:22:33:44:55:66"
  233. claddr, _ := net.ParseMAC(mac)
  234. req, err := dhcpv6.NewSolicit(claddr)
  235. require.NoError(t, err)
  236. resp, err := dhcpv6.NewAdvertiseFromSolicit(req)
  237. require.NoError(t, err)
  238. assert.Equal(t, 0, len(resp.GetOption(dhcpv6.OptionIANA)))
  239. // if we handle this DHCP request, nothing should change since the lease is
  240. // unknown
  241. result, stop := Handler6(req, resp)
  242. assert.False(t, stop)
  243. assert.Equal(t, 0, len(result.GetOption(dhcpv6.OptionIANA)))
  244. })
  245. t.Run("known MAC", func(t *testing.T) {
  246. // prepare DHCPv6 request
  247. mac := "11:22:33:44:55:66"
  248. claddr, _ := net.ParseMAC(mac)
  249. req, err := dhcpv6.NewSolicit(claddr)
  250. require.NoError(t, err)
  251. resp, err := dhcpv6.NewAdvertiseFromSolicit(req)
  252. require.NoError(t, err)
  253. assert.Equal(t, 0, len(resp.GetOption(dhcpv6.OptionIANA)))
  254. // add lease for the MAC in the lease map
  255. clIPAddr := net.ParseIP("2001:db8::10:1")
  256. StaticRecords = map[string]net.IP{
  257. mac: clIPAddr,
  258. }
  259. // if we handle this DHCP request, there should be a specific IANA option
  260. // set in the resulting response
  261. result, stop := Handler6(req, resp)
  262. assert.False(t, stop)
  263. if assert.Equal(t, 1, len(result.GetOption(dhcpv6.OptionIANA))) {
  264. opt := result.GetOneOption(dhcpv6.OptionIANA)
  265. assert.Contains(t, opt.String(), "IP=2001:db8::10:1")
  266. }
  267. // cleanup
  268. StaticRecords = make(map[string]net.IP)
  269. })
  270. }
  271. func TestSetupFile(t *testing.T) {
  272. // too few arguments
  273. _, _, err := setupFile(false)
  274. assert.Error(t, err)
  275. // empty file name
  276. _, _, err = setupFile(false, "")
  277. assert.Error(t, err)
  278. // trigger error in LoadDHCPv*Records
  279. _, _, err = setupFile(false, "/foo/bar")
  280. assert.Error(t, err)
  281. _, _, err = setupFile(true, "/foo/bar")
  282. assert.Error(t, err)
  283. // setup temp leases file
  284. tmp, err := ioutil.TempFile("", "test_plugin_file")
  285. require.NoError(t, err)
  286. defer func() {
  287. tmp.Close()
  288. os.Remove(tmp.Name())
  289. }()
  290. t.Run("typical case", func(t *testing.T) {
  291. _, err = tmp.WriteString("00:11:22:33:44:55 2001:db8::10:1\n")
  292. require.NoError(t, err)
  293. _, err = tmp.WriteString("11:22:33:44:55:66 2001:db8::10:2\n")
  294. require.NoError(t, err)
  295. assert.Equal(t, 0, len(StaticRecords))
  296. // leases should show up in StaticRecords
  297. _, _, err = setupFile(true, tmp.Name())
  298. if assert.NoError(t, err) {
  299. assert.Equal(t, 2, len(StaticRecords))
  300. }
  301. })
  302. t.Run("autorefresh enabled", func(t *testing.T) {
  303. _, _, err = setupFile(true, tmp.Name(), autoRefreshArg)
  304. if assert.NoError(t, err) {
  305. assert.Equal(t, 2, len(StaticRecords))
  306. }
  307. // we add more leases to the file
  308. // this should trigger an event to refresh the leases database
  309. // without calling setupFile again
  310. _, err = tmp.WriteString("22:33:44:55:66:77 2001:db8::10:3\n")
  311. require.NoError(t, err)
  312. // since the event is processed asynchronously, give it a little time
  313. time.Sleep(time.Millisecond * 100)
  314. // an additional record should show up in the database
  315. // but we should respect the locking first
  316. recLock.RLock()
  317. defer recLock.RUnlock()
  318. assert.Equal(t, 3, len(StaticRecords))
  319. })
  320. }