123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411 |
- // License: GPLv3 Copyright: 2023, Kovid Goyal, <kovid at kovidgoyal.net>
- package ssh
- import (
- "archive/tar"
- "encoding/json"
- "errors"
- "fmt"
- "io/fs"
- "os"
- "path"
- "path/filepath"
- "strings"
- "time"
- "kitty/tools/config"
- "kitty/tools/utils"
- "kitty/tools/utils/paths"
- "kitty/tools/utils/shlex"
- "github.com/bmatcuk/doublestar/v4"
- "golang.org/x/sys/unix"
- )
- var _ = fmt.Print
- type EnvInstruction struct {
- key, val string
- delete_on_remote, copy_from_local, literal_quote bool
- }
- func quote_for_sh(val string, literal_quote bool) string {
- if literal_quote {
- return utils.QuoteStringForSH(val)
- }
- // See https://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
- b := strings.Builder{}
- b.Grow(len(val) + 16)
- b.WriteRune('"')
- runes := []rune(val)
- for i, ch := range runes {
- if ch == '\\' || ch == '`' || ch == '"' || (ch == '$' && i+1 < len(runes) && runes[i+1] == '(') {
- // special chars are escaped
- // $( is escaped to prevent execution
- b.WriteRune('\\')
- }
- b.WriteRune(ch)
- }
- b.WriteRune('"')
- return b.String()
- }
- func (self *EnvInstruction) Serialize(for_python bool, get_local_env func(string) (string, bool)) string {
- var unset func() string
- var export func(string) string
- if for_python {
- dumps := func(x ...any) string {
- ans, _ := json.Marshal(x)
- return utils.UnsafeBytesToString(ans)
- }
- export = func(val string) string {
- if val == "" {
- return fmt.Sprintf("export %s", dumps(self.key))
- }
- return fmt.Sprintf("export %s", dumps(self.key, val, self.literal_quote))
- }
- unset = func() string {
- return fmt.Sprintf("unset %s", dumps(self.key))
- }
- } else {
- kq := utils.QuoteStringForSH(self.key)
- unset = func() string {
- return fmt.Sprintf("unset %s", kq)
- }
- export = func(val string) string {
- return fmt.Sprintf("export %s=%s", kq, quote_for_sh(val, self.literal_quote))
- }
- }
- if self.delete_on_remote {
- return unset()
- }
- if self.copy_from_local {
- val, found := get_local_env(self.key)
- if !found {
- return ""
- }
- return export(val)
- }
- return export(self.val)
- }
- func final_env_instructions(for_python bool, get_local_env func(string) (string, bool), env ...*EnvInstruction) string {
- seen := make(map[string]int, len(env))
- ans := make([]string, 0, len(env))
- for _, ei := range env {
- q := ei.Serialize(for_python, get_local_env)
- if q != "" {
- if pos, found := seen[ei.key]; found {
- ans[pos] = q
- } else {
- seen[ei.key] = len(ans)
- ans = append(ans, q)
- }
- }
- }
- return strings.Join(ans, "\n")
- }
- type CopyInstruction struct {
- local_path, arcname string
- exclude_patterns []string
- }
- func ParseEnvInstruction(spec string) (ans []*EnvInstruction, err error) {
- const COPY_FROM_LOCAL string = "_kitty_copy_env_var_"
- ei := &EnvInstruction{}
- found := false
- ei.key, ei.val, found = strings.Cut(spec, "=")
- ei.key = strings.TrimSpace(ei.key)
- if found {
- ei.val = strings.TrimSpace(ei.val)
- if ei.val == COPY_FROM_LOCAL {
- ei.val = ""
- ei.copy_from_local = true
- }
- } else {
- ei.delete_on_remote = true
- }
- if ei.key == "" {
- err = fmt.Errorf("The env directive must not be empty")
- }
- ans = []*EnvInstruction{ei}
- return
- }
- var paths_ctx *paths.Ctx
- func resolve_file_spec(spec string, is_glob bool) ([]string, error) {
- if paths_ctx == nil {
- paths_ctx = &paths.Ctx{}
- }
- ans := os.ExpandEnv(paths_ctx.ExpandHome(spec))
- if !filepath.IsAbs(ans) {
- ans = paths_ctx.AbspathFromHome(ans)
- }
- if is_glob {
- files, err := doublestar.FilepathGlob(ans)
- if err != nil {
- return nil, fmt.Errorf("%s is not a valid glob pattern with error: %w", spec, err)
- }
- if len(files) == 0 {
- return nil, fmt.Errorf("%s matches no files", spec)
- }
- return files, nil
- }
- err := unix.Access(ans, unix.R_OK)
- if err != nil {
- if errors.Is(err, os.ErrNotExist) {
- return nil, fmt.Errorf("%s does not exist", spec)
- }
- return nil, fmt.Errorf("Cannot read from: %s with error: %w", spec, err)
- }
- return []string{ans}, nil
- }
- func get_arcname(loc, dest, home string) (arcname string) {
- if dest != "" {
- arcname = dest
- } else {
- arcname = filepath.Clean(loc)
- if strings.HasPrefix(arcname, home) {
- ra, err := filepath.Rel(home, arcname)
- if err == nil {
- arcname = ra
- }
- }
- }
- prefix := "home/"
- if strings.HasPrefix(arcname, "/") {
- prefix = "root"
- }
- return prefix + arcname
- }
- func ParseCopyInstruction(spec string) (ans []*CopyInstruction, err error) {
- args, err := shlex.Split("copy " + spec)
- if err != nil {
- return nil, err
- }
- opts, args, err := parse_copy_args(args)
- if err != nil {
- return nil, err
- }
- locations := make([]string, 0, len(args))
- for _, arg := range args {
- locs, err := resolve_file_spec(arg, opts.Glob)
- if err != nil {
- return nil, err
- }
- locations = append(locations, locs...)
- }
- if len(locations) == 0 {
- return nil, fmt.Errorf("No files to copy specified")
- }
- if len(locations) > 1 && opts.Dest != "" {
- return nil, fmt.Errorf("Specifying a remote location with more than one file is not supported")
- }
- home := paths_ctx.HomePath()
- ans = make([]*CopyInstruction, 0, len(locations))
- for _, loc := range locations {
- ci := CopyInstruction{local_path: loc, exclude_patterns: opts.Exclude}
- if opts.SymlinkStrategy != "preserve" {
- ci.local_path, err = filepath.EvalSymlinks(loc)
- if err != nil {
- return nil, fmt.Errorf("Failed to resolve symlinks in %#v with error: %w", loc, err)
- }
- }
- if opts.SymlinkStrategy == "resolve" {
- ci.arcname = get_arcname(ci.local_path, opts.Dest, home)
- } else {
- ci.arcname = get_arcname(loc, opts.Dest, home)
- }
- ans = append(ans, &ci)
- }
- return
- }
- type file_unique_id struct {
- dev, inode uint64
- }
- func excluded(pattern, path string) bool {
- if !strings.ContainsRune(pattern, '/') {
- path = filepath.Base(path)
- }
- if matched, err := doublestar.PathMatch(pattern, path); matched && err == nil {
- return true
- }
- return false
- }
- func get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string, local_path, arcname string, exclude_patterns []string) error {
- s, err := os.Lstat(local_path)
- if err != nil {
- return err
- }
- u, ok := s.Sys().(unix.Stat_t)
- cb := func(h *tar.Header, data []byte, arcname string) error {
- h.Name = arcname
- if h.Typeflag == tar.TypeDir {
- h.Name = strings.TrimRight(h.Name, "/") + "/"
- }
- h.Size = int64(len(data))
- h.Mode = int64(s.Mode().Perm())
- h.ModTime = s.ModTime()
- h.Format = tar.FormatPAX
- if ok {
- h.AccessTime = time.Unix(0, u.Atim.Nano())
- h.ChangeTime = time.Unix(0, u.Ctim.Nano())
- }
- return callback(h, data)
- }
- // we only copy regular files, directories and symlinks
- switch s.Mode().Type() {
- case fs.ModeSymlink:
- target, err := os.Readlink(local_path)
- if err != nil {
- return err
- }
- err = cb(&tar.Header{
- Typeflag: tar.TypeSymlink,
- Linkname: target,
- }, nil, arcname)
- if err != nil {
- return err
- }
- case fs.ModeDir:
- local_path = filepath.Clean(local_path)
- type entry struct {
- path, arcname string
- }
- stack := []entry{{local_path, arcname}}
- for len(stack) > 0 {
- x := stack[0]
- stack = stack[1:]
- entries, err := os.ReadDir(x.path)
- if err != nil {
- if x.path == local_path {
- return err
- }
- continue
- }
- err = cb(&tar.Header{Typeflag: tar.TypeDir}, nil, x.arcname)
- if err != nil {
- return err
- }
- for _, e := range entries {
- entry_path := filepath.Join(x.path, e.Name())
- aname := path.Join(x.arcname, e.Name())
- ok := true
- for _, pat := range exclude_patterns {
- if excluded(pat, entry_path) {
- ok = false
- break
- }
- }
- if !ok {
- continue
- }
- if e.IsDir() {
- stack = append(stack, entry{entry_path, aname})
- } else {
- err = get_file_data(callback, seen, entry_path, aname, exclude_patterns)
- if err != nil {
- return err
- }
- }
- }
- }
- case 0: // Regular file
- fid := file_unique_id{dev: uint64(u.Dev), inode: uint64(u.Ino)}
- if prev, ok := seen[fid]; ok { // Hard link
- err = cb(&tar.Header{Typeflag: tar.TypeLink, Linkname: prev}, nil, arcname)
- if err != nil {
- return err
- }
- }
- seen[fid] = arcname
- data, err := os.ReadFile(local_path)
- if err != nil {
- return err
- }
- err = cb(&tar.Header{Typeflag: tar.TypeReg}, data, arcname)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (ci *CopyInstruction) get_file_data(callback func(h *tar.Header, data []byte) error, seen map[file_unique_id]string) (err error) {
- ep := ci.exclude_patterns
- for _, folder_name := range []string{"__pycache__", ".DS_Store"} {
- ep = append(ep, "**/"+folder_name, "**/"+folder_name+"/**")
- }
- return get_file_data(callback, seen, ci.local_path, ci.arcname, ep)
- }
- type ConfigSet struct {
- all_configs []*Config
- }
- func config_for_hostname(hostname_to_match, username_to_match string, cs *ConfigSet) *Config {
- matcher := func(q *Config) bool {
- for _, pat := range strings.Split(q.Hostname, " ") {
- upat := "*"
- if strings.Contains(pat, "@") {
- upat, pat, _ = strings.Cut(pat, "@")
- }
- var host_matched, user_matched bool
- if matched, err := filepath.Match(pat, hostname_to_match); matched && err == nil {
- host_matched = true
- }
- if matched, err := filepath.Match(upat, username_to_match); matched && err == nil {
- user_matched = true
- }
- if host_matched && user_matched {
- return true
- }
- }
- return false
- }
- for _, c := range utils.Reversed(cs.all_configs) {
- if matcher(c) {
- return c
- }
- }
- return cs.all_configs[0]
- }
- func (self *ConfigSet) line_handler(key, val string) error {
- c := self.all_configs[len(self.all_configs)-1]
- if key == "hostname" {
- c = NewConfig()
- self.all_configs = append(self.all_configs, c)
- }
- return c.Parse(key, val)
- }
- func load_config(hostname_to_match string, username_to_match string, overrides []string, paths ...string) (*Config, []config.ConfigLine, error) {
- ans := &ConfigSet{all_configs: []*Config{NewConfig()}}
- p := config.ConfigParser{LineHandler: ans.line_handler}
- err := p.LoadConfig("ssh.conf", paths, nil)
- if err != nil {
- return nil, nil, err
- }
- final_conf := config_for_hostname(hostname_to_match, username_to_match, ans)
- bad_lines := p.BadLines()
- if len(overrides) > 0 {
- h := final_conf.Hostname
- override_parser := config.ConfigParser{LineHandler: final_conf.Parse}
- if err = override_parser.ParseOverrides(overrides...); err != nil {
- return nil, nil, err
- }
- bad_lines = append(bad_lines, override_parser.BadLines()...)
- final_conf.Hostname = h
- }
- return final_conf, bad_lines, nil
- }
|