all repos — flounder @ deb0c5ae2a24efdbaa79345b0fbddfd3ad74e972

A small site builder for the Gemini protocol

sftp.go (view raw)

  1// SFTP server for users with Flounder accounts
  2// 	A lot of this is copied from SFTPGo, but simplified for our use case.
  3package main
  4
  5import (
  6	"fmt"
  7	"io"
  8	"io/ioutil"
  9	"log"
 10	"net"
 11	"os"
 12	"path"
 13	"path/filepath"
 14	"runtime/debug"
 15	"strings"
 16	"time"
 17
 18	"github.com/pkg/sftp"
 19	"golang.org/x/crypto/ssh"
 20)
 21
 22type Connection struct {
 23	User string
 24}
 25
 26func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 27	// check user perms -- cant read others hidden files
 28	fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
 29	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
 30	if err != nil {
 31		return nil, err
 32	}
 33	return f, nil
 34}
 35
 36func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 37	// check user perms -- cant write others files
 38	fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
 39	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
 40	if strings.HasPrefix(fullpath, userDir) {
 41		f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
 42		if err != nil {
 43			return nil, err
 44		}
 45		return f, nil
 46	} else {
 47		return nil, fmt.Errorf("Invalid permissions")
 48	}
 49}
 50
 51func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 52	fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
 53	if strings.Contains(request.Filepath, ".hidden") {
 54		return nil, fmt.Errorf("Invalid permissions") // TODO fix better
 55	}
 56	switch request.Method {
 57	case "List":
 58		f, err := os.Open(fullpath)
 59		if err != nil {
 60			return nil, err
 61		}
 62		fileInfo, err := f.Readdir(-1)
 63		if err != nil {
 64			return nil, err
 65		}
 66		return listerat(fileInfo), nil
 67	case "Stat":
 68		stat, err := os.Stat(fullpath)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return listerat([]os.FileInfo{stat}), nil
 73	}
 74	return nil, fmt.Errorf("Invalid command")
 75}
 76
 77func (conn *Connection) Filecmd(request *sftp.Request) error {
 78	// remove, rename, setstat? find out
 79	fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
 80	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 81	writePerms := strings.HasPrefix(fullpath, userDir)
 82	var err error
 83	if writePerms {
 84		switch request.Method {
 85		case "Remove":
 86			err = os.Remove(fullpath)
 87		case "Mkdir":
 88			err = os.Mkdir(fullpath, 0755)
 89		}
 90		if err != nil {
 91			return err
 92		}
 93	} else {
 94		return fmt.Errorf("Unauthorized")
 95	}
 96	// Rename, Mkdir
 97	return nil
 98}
 99
100// TODO hide hidden folders
// Users have write perms on their own files, read perms on all
102
103func buildHandlers(connection *Connection) sftp.Handlers {
104	return sftp.Handlers{
105		connection,
106		connection,
107		connection,
108		connection,
109	}
110}
111
112// Based on example server code from golang.org/x/crypto/ssh and server_standalone
113func runSFTPServer() {
114	if !c.EnableSFTP {
115		return
116	}
117	// An SSH server is represented by a ServerConfig, which holds
118	// certificate details and handles authentication of ServerConns.
119	config := &ssh.ServerConfig{
120		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
121			// Should use constant-time compare (or better, salt+hash) in
122			// a production setting.
123			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
124				return nil, fmt.Errorf("Invalid username")
125			}
126			_, _, err := checkLogin(c.User(), string(pass))
127			// TODO maybe give admin extra permissions?
128			if err != nil {
129				return nil, fmt.Errorf("password rejected for %q", c.User())
130			} else {
131				log.Printf("Login: %s\n", c.User())
132				return nil, nil
133			}
134		},
135	}
136
137	// TODO generate key automatically
138	privateBytes, err := ioutil.ReadFile("id_rsa")
139	if err != nil {
140		log.Fatal("Failed to load private key", err)
141	}
142
143	private, err := ssh.ParsePrivateKey(privateBytes)
144	if err != nil {
145		log.Fatal("Failed to parse private key", err)
146	}
147
148	config.AddHostKey(private)
149
150	listener, err := net.Listen("tcp", "0.0.0.0:2024")
151	if err != nil {
152		log.Fatal("failed to listen for connection", err)
153	}
154
155	log.Printf("SFTP server listening on %v\n", listener.Addr())
156
157	for {
158		conn, err := listener.Accept()
159		if err != nil {
160			log.Fatal(err)
161		}
162		go acceptInboundConnection(conn, config)
163	}
164}
165
166func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
167	defer func() {
168		if r := recover(); r != nil {
169			log.Println("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
170		}
171	}()
172	ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
173	log.Println("Request from IP " + ipAddr)
174	limiter := getVisitor(ipAddr)
175	if limiter.Allow() == false {
176		conn.Close()
177		return
178	}
179	// Before beginning a handshake must be performed on the incoming net.Conn
180	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
181	conn.SetDeadline(time.Now().Add(2 * time.Minute))
182
183	// Before use, a handshake must be performed on the incoming net.Conn.
184	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
185	if err != nil {
186		log.Printf("failed to accept an incoming connection: %v", err)
187		return
188	}
189	log.Println("login detected:", sconn.User())
190	fmt.Fprintf(os.Stderr, "SSH server established\n")
191	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
192	conn.SetDeadline(time.Time{})
193
194	defer conn.Close()
195
196	// The incoming Request channel must be serviced.
197	go ssh.DiscardRequests(reqs)
198
199	// Service the incoming Channel channel.
200	channelCounter := int64(0)
201	for newChannel := range chans {
202		// Channels have a type, depending on the application level
203		// protocol intended. In the case of an SFTP session, this is "subsystem"
204		// with a payload string of "<length=4>sftp"
205		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
206		if newChannel.ChannelType() != "session" {
207			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
208			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
209			continue
210		}
211		channel, requests, err := newChannel.Accept()
212		if err != nil {
213			log.Println("could not accept channel.", err)
214			continue
215		}
216
217		channelCounter++
218		fmt.Fprintf(os.Stderr, "Channel accepted\n")
219
220		// Sessions have out-of-band requests such as "shell",
221		// "pty-req" and "env".  Here we handle only the
222		// "subsystem" request.
223		go func(in <-chan *ssh.Request) {
224			for req := range in {
225				fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
226				ok := false
227				switch req.Type {
228				case "subsystem":
229					fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
230					if string(req.Payload[4:]) == "sftp" {
231						ok = true
232					}
233				}
234				fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
235				req.Reply(ok, nil)
236			}
237		}(requests)
238		connection := Connection{sconn.User()}
239		root := buildHandlers(&connection)
240		server := sftp.NewRequestServer(channel, root)
241		if err := server.Serve(); err == io.EOF {
242			server.Close()
243			log.Println("sftp client exited session.")
244		} else if err != nil {
245			log.Println("sftp server completed with error:", err)
246			return
247		}
248	}
249}
250
// listerat adapts an in-memory slice of file descriptions to the paged
// listing interface expected by the SFTP request server.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and returns how many
// were copied. It returns io.EOF when offset is at or past the end of the
// listing, and also when fewer entries than len(ls) remained to be copied.
//
// Modeled after strings.Reader's ReadAt() implementation
func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if total := int64(len(f)); offset >= total {
		return 0, io.EOF
	}

	remaining := f[offset:]
	copied := copy(ls, remaining)

	switch {
	case copied < len(ls):
		// The destination was not filled, so the listing is exhausted.
		return copied, io.EOF
	default:
		return copied, nil
	}
}