authenticator.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package socks
  2. import (
  3. "fmt"
  4. "io"
  5. )
  6. // Authenticator is the connection passed in as a reader/writer to support different authentication types
  7. type Authenticator interface {
  8. Handle(io.Reader, io.Writer) error
  9. }
  10. // NoAuthAuthenticator is used to handle the No Authentication mode
  11. type NoAuthAuthenticator struct{}
  12. // NewNoAuthAuthenticator creates a authless Authenticator
  13. func NewNoAuthAuthenticator() Authenticator {
  14. return &NoAuthAuthenticator{}
  15. }
  16. // Handle writes back the version and NoAuth
  17. func (a *NoAuthAuthenticator) Handle(reader io.Reader, writer io.Writer) error {
  18. _, err := writer.Write([]byte{socks5Version, NoAuth})
  19. return err
  20. }
  21. // UserPassAuthAuthenticator is used to handle the user/password mode
  22. type UserPassAuthAuthenticator struct {
  23. IsValid func(string, string) bool
  24. }
  25. // NewUserPassAuthAuthenticator creates a new username/password validator Authenticator
  26. func NewUserPassAuthAuthenticator(isValid func(string, string) bool) Authenticator {
  27. return &UserPassAuthAuthenticator{
  28. IsValid: isValid,
  29. }
  30. }
  31. // Handle writes back the version and NoAuth
  32. func (a *UserPassAuthAuthenticator) Handle(reader io.Reader, writer io.Writer) error {
  33. if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
  34. return err
  35. }
  36. // Get the version and username length
  37. header := []byte{0, 0}
  38. if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
  39. return err
  40. }
  41. // Ensure compatibility. Someone call E-harmony
  42. if header[0] != userAuthVersion {
  43. return fmt.Errorf("Unsupported auth version: %v", header[0])
  44. }
  45. // Get the user name
  46. userLen := int(header[1])
  47. user := make([]byte, userLen)
  48. if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
  49. return err
  50. }
  51. // Get the password length
  52. if _, err := reader.Read(header[:1]); err != nil {
  53. return err
  54. }
  55. // Get the password
  56. passLen := int(header[0])
  57. pass := make([]byte, passLen)
  58. if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
  59. return err
  60. }
  61. // Verify the password
  62. if a.IsValid(string(user), string(pass)) {
  63. _, err := writer.Write([]byte{userAuthVersion, authSuccess})
  64. return err
  65. }
  66. // password failed. Write back failure
  67. if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
  68. return err
  69. }
  70. return fmt.Errorf("User authentication failed")
  71. }