all repos — flounder @ f8d68c8141b6f949e42d900e1def34d1d8198e10

A small site builder for the Gemini protocol

sftp.go (view raw)

  1// SFTP server for users with Flounder accounts
  2// 	A lot of this is copied from SFTPGo, but simplified for our use case.
  3package main
  4
  5import (
  6	"crypto/rand"
  7	"crypto/rsa"
  8	"crypto/x509"
  9	"encoding/pem"
 10	"fmt"
 11	"io"
 12	"io/ioutil"
 13	"log"
 14	"net"
 15	"os"
 16	"path"
 17	"path/filepath"
 18	"runtime/debug"
 19	"time"
 20
 21	"github.com/pkg/sftp"
 22	"golang.org/x/crypto/ssh"
 23)
 24
 25type Connection struct {
 26	User string
 27}
 28
 29func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 30	// check user perms -- cant read others hidden files
 31	fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
 32	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
 33	if err != nil {
 34		return nil, err
 35	}
 36	return f, nil
 37}
 38
 39func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 40	// check user perms -- cant write others files
 41	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
 42	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 43	f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
 44	if err != nil {
 45		return nil, err
 46	}
 47	return f, nil
 48}
 49
 50func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 51	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 52	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 53	switch request.Method {
 54	case "List":
 55		f, err := os.Open(fullpath)
 56		if err != nil {
 57			return nil, err
 58		}
 59		fileInfo, err := f.Readdir(-1)
 60		if err != nil {
 61			return nil, err
 62		}
 63		return listerat(fileInfo), nil
 64	case "Stat":
 65		stat, err := os.Stat(fullpath)
 66		if err != nil {
 67			return nil, err
 68		}
 69		return listerat([]os.FileInfo{stat}), nil
 70	}
 71	return nil, fmt.Errorf("Invalid command")
 72}
 73
 74func (conn *Connection) Filecmd(request *sftp.Request) error {
 75	// remove, rename, setstat? find out
 76	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 77	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 78	var err error
 79	switch request.Method {
 80	case "Remove":
 81		err = os.Remove(fullpath)
 82	case "Mkdir":
 83		err = os.Mkdir(fullpath, 0755)
 84	}
 85	if err != nil {
 86		return err
 87	}
 88	// Rename, Mkdir
 89	return nil
 90}
 91
// TODO hide hidden folders
// Users have write perms on their own files and read perms on all files.
 94
 95func buildHandlers(connection *Connection) sftp.Handlers {
 96	return sftp.Handlers{
 97		connection,
 98		connection,
 99		connection,
100		connection,
101	}
102}
103
104// Based on example server code from golang.org/x/crypto/ssh and server_standalone
105func runSFTPServer() {
106	if !c.EnableSFTP {
107		return
108	}
109	// An SSH server is represented by a ServerConfig, which holds
110	// certificate details and handles authentication of ServerConns.
111	config := &ssh.ServerConfig{
112		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
113			// Should use constant-time compare (or better, salt+hash) in
114			// a production setting.
115			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
116				return nil, fmt.Errorf("Invalid username")
117			}
118			_, _, err := checkLogin(c.User(), string(pass))
119			// TODO maybe give admin extra permissions?
120			if err != nil {
121				return nil, fmt.Errorf("password rejected for %q", c.User())
122			} else {
123				log.Printf("Login: %s\n", c.User())
124				return nil, nil
125			}
126		},
127	}
128
129	// TODO generate key automatically
130	if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
131		// path/to/whatever does not exist
132		log.Println("Host key not found, generating host key")
133		err := GenerateRSAKeys()
134		if err != nil {
135			log.Fatal(err)
136		}
137	}
138
139	privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
140	if err != nil {
141		log.Fatal("Failed to load private key", err)
142	}
143
144	private, err := ssh.ParsePrivateKey(privateBytes)
145	if err != nil {
146		log.Fatal("Failed to parse private key", err)
147	}
148
149	config.AddHostKey(private)
150
151	listener, err := net.Listen("tcp", "0.0.0.0:2024")
152	if err != nil {
153		log.Fatal("failed to listen for connection", err)
154	}
155
156	log.Printf("SFTP server listening on %v\n", listener.Addr())
157
158	for {
159		conn, err := listener.Accept()
160		if err != nil {
161			log.Fatal(err)
162		}
163		go acceptInboundConnection(conn, config)
164	}
165}
166
167func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
168	defer func() {
169		if r := recover(); r != nil {
170			log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
171		}
172	}()
173	ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
174	log.Println("Request from IP " + ipAddr)
175	limiter := getVisitor(ipAddr)
176	if limiter.Allow() == false {
177		conn.Close()
178		return
179	}
180	// Before beginning a handshake must be performed on the incoming net.Conn
181	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
182	conn.SetDeadline(time.Now().Add(2 * time.Minute))
183
184	// Before use, a handshake must be performed on the incoming net.Conn.
185	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
186	if err != nil {
187		log.Printf("failed to accept an incoming connection: %v", err)
188		return
189	}
190	log.Println("login detected:", sconn.User())
191	fmt.Fprintf(os.Stderr, "SSH server established\n")
192	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
193	conn.SetDeadline(time.Time{})
194
195	defer conn.Close()
196
197	// The incoming Request channel must be serviced.
198	go ssh.DiscardRequests(reqs)
199
200	// Service the incoming Channel channel.
201	channelCounter := int64(0)
202	for newChannel := range chans {
203		// Channels have a type, depending on the application level
204		// protocol intended. In the case of an SFTP session, this is "subsystem"
205		// with a payload string of "<length=4>sftp"
206		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
207		if newChannel.ChannelType() != "session" {
208			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
209			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
210			continue
211		}
212		channel, requests, err := newChannel.Accept()
213		if err != nil {
214			log.Println("could not accept channel.", err)
215			continue
216		}
217
218		channelCounter++
219		fmt.Fprintf(os.Stderr, "Channel accepted\n")
220
221		// Sessions have out-of-band requests such as "shell",
222		// "pty-req" and "env".  Here we handle only the
223		// "subsystem" request.
224		go func(in <-chan *ssh.Request) {
225			for req := range in {
226				fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
227				ok := false
228				switch req.Type {
229				case "subsystem":
230					fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
231					if string(req.Payload[4:]) == "sftp" {
232						ok = true
233					}
234				}
235				fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
236				req.Reply(ok, nil)
237			}
238		}(requests)
239		connection := Connection{sconn.User()}
240		root := buildHandlers(&connection)
241		server := sftp.NewRequestServer(channel, root)
242		if err := server.Serve(); err == io.EOF {
243			server.Close()
244			log.Println("sftp client exited session.")
245		} else if err != nil {
246			log.Println("sftp server completed with error:", err)
247			return
248		}
249	}
250}
251
// GenerateRSAKeys creates a new 4096-bit RSA host key and writes two files:
//   - the private key, PEM-encoded as a PKCS#1 "RSA PRIVATE KEY", to
//     c.HostKeyPath (mode 0600);
//   - the matching public key, in authorized_keys format, to
//     c.HostKeyPath + ".pub" (mode 0600).
//
// The error from closing the private key file is returned too, because a
// failed Close can mean the key never fully reached disk. If encoding the
// private key fails, the partial file is removed. Otherwise the next start
// would find a truncated key and fail to parse it instead of regenerating.
func GenerateRSAKeys() error {
	key, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return err
	}

	o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}

	priv := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}

	if err := pem.Encode(o, priv); err != nil {
		// The encode error is the one worth reporting; cleanup is best-effort.
		o.Close()
		os.Remove(c.HostKeyPath)
		return err
	}
	if err := o.Close(); err != nil {
		return err
	}

	pub, err := ssh.NewPublicKey(&key.PublicKey)
	if err != nil {
		return err
	}
	return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
}
282
283type listerat []os.FileInfo
284
285// Modeled after strings.Reader's ReadAt() implementation
286func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
287	var n int
288	if offset >= int64(len(f)) {
289		return 0, io.EOF
290	}
291	n = copy(ls, f[offset:])
292	if n < len(ls) {
293		return n, io.EOF
294	}
295	return n, nil
296}