all repos — flounder @ 2892b866cd0a7127e8aa8ef037e1a9dc405e2d79

A small site builder for the Gemini protocol

sftp.go (view raw)

  1// SFTP server for users with Flounder accounts
  2// 	A lot of this is copied from SFTPGo, but simplified for our use case.
  3package main
  4
  5import (
  6	"crypto/rand"
  7	"crypto/rsa"
  8	"crypto/x509"
  9	"encoding/pem"
 10	"fmt"
 11	"io"
 12	"io/ioutil"
 13	"log"
 14	"net"
 15	"os"
 16	"path"
 17	"path/filepath"
 18	"runtime/debug"
 19	"time"
 20
 21	"github.com/pkg/sftp"
 22	"golang.org/x/crypto/ssh"
 23)
 24
 25type Connection struct {
 26	User string
 27}
 28
 29func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 30	// check user perms -- cant read others hidden files
 31	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
 32	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 33	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
 34	if err != nil {
 35		return nil, err
 36	}
 37	return f, nil
 38}
 39
 40func (conn *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
 41	// check user perms -- cant write others files
 42	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 43	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 44	err := checkIfValidFile(conn.User, fullpath, []byte{})
 45	if err != nil {
 46		return nil, err
 47	}
 48	f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
 49	if err != nil {
 50		return nil, err
 51	}
 52	return f, nil
 53}
 54
 55func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
 56	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 57	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 58	switch request.Method {
 59	case "List":
 60		f, err := os.Open(fullpath)
 61		if err != nil {
 62			return nil, err
 63		}
 64		fileInfo, err := f.Readdir(-1)
 65		if err != nil {
 66			return nil, err
 67		}
 68		return listerat(fileInfo), nil
 69	case "Stat":
 70		stat, err := os.Stat(fullpath)
 71		if err != nil {
 72			return nil, err
 73		}
 74		return listerat([]os.FileInfo{stat}), nil
 75	}
 76	return nil, fmt.Errorf("Invalid command")
 77}
 78
 79func (conn *Connection) Filecmd(request *sftp.Request) error {
 80	// remove, rename, setstat? find out
 81	userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
 82	fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
 83	targetPath := path.Join(userDir, filepath.Clean(request.Target))
 84	var err error
 85	switch request.Method {
 86	case "Remove":
 87		err = os.Remove(fullpath)
 88	case "Mkdir":
 89		err = os.Mkdir(fullpath, 0755)
 90	case "Rename":
 91		err := checkIfValidFile(conn.User, targetPath, []byte{})
 92		if err != nil {
 93			return err
 94		}
 95		err = os.Rename(fullpath, targetPath)
 96	}
 97	if err != nil {
 98		return err
 99	}
100	// Rename, Mkdir
101	return nil
102}
103
104// TODO hide hidden folders
 105// Users have write perms on their files, read perms on all
106
107func buildHandlers(connection *Connection) sftp.Handlers {
108	return sftp.Handlers{
109		connection,
110		connection,
111		connection,
112		connection,
113	}
114}
115
116// Based on example server code from golang.org/x/crypto/ssh and server_standalone
117func runSFTPServer() {
118	if !c.EnableSFTP {
119		return
120	}
121	// An SSH server is represented by a ServerConfig, which holds
122	// certificate details and handles authentication of ServerConns.
123	config := &ssh.ServerConfig{
124		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
125			// Should use constant-time compare (or better, salt+hash) in
126			// a production setting.
127			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
128				return nil, fmt.Errorf("Invalid username")
129			}
130			_, _, err := checkLogin(c.User(), string(pass))
131			// TODO maybe give admin extra permissions?
132			if err != nil {
133				return nil, fmt.Errorf("password rejected for %q", c.User())
134			} else {
135				log.Printf("Login: %s\n", c.User())
136				return nil, nil
137			}
138		},
139	}
140
141	// TODO generate key automatically
142	if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
143		// path/to/whatever does not exist
144		log.Println("Host key not found, generating host key")
145		err := GenerateRSAKeys()
146		if err != nil {
147			log.Fatal(err)
148		}
149	}
150
151	privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
152	if err != nil {
153		log.Fatal("Failed to load private key", err)
154	}
155
156	private, err := ssh.ParsePrivateKey(privateBytes)
157	if err != nil {
158		log.Fatal("Failed to parse private key", err)
159	}
160
161	config.AddHostKey(private)
162
163	listener, err := net.Listen("tcp", "0.0.0.0:2024")
164	if err != nil {
165		log.Fatal("failed to listen for connection", err)
166	}
167
168	log.Printf("SFTP server listening on %v\n", listener.Addr())
169
170	for {
171		conn, err := listener.Accept()
172		if err != nil {
173			log.Fatal(err)
174		}
175		go acceptInboundConnection(conn, config)
176	}
177}
178
179func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
180	defer func() {
181		if r := recover(); r != nil {
182			log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
183		}
184	}()
185	ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
186	log.Println("Request from IP " + ipAddr)
187	limiter := getVisitor(ipAddr)
188	if limiter.Allow() == false {
189		conn.Close()
190		return
191	}
192	// Before beginning a handshake must be performed on the incoming net.Conn
193	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
194	conn.SetDeadline(time.Now().Add(2 * time.Minute))
195
196	// Before use, a handshake must be performed on the incoming net.Conn.
197	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
198	if err != nil {
199		log.Printf("failed to accept an incoming connection: %v", err)
200		return
201	}
202	log.Println("login detected:", sconn.User())
203	fmt.Fprintf(os.Stderr, "SSH server established\n")
204	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
205	conn.SetDeadline(time.Time{})
206
207	defer conn.Close()
208
209	// The incoming Request channel must be serviced.
210	go ssh.DiscardRequests(reqs)
211
212	// Service the incoming Channel channel.
213	channelCounter := int64(0)
214	for newChannel := range chans {
215		// Channels have a type, depending on the application level
216		// protocol intended. In the case of an SFTP session, this is "subsystem"
217		// with a payload string of "<length=4>sftp"
218		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
219		if newChannel.ChannelType() != "session" {
220			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
221			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
222			continue
223		}
224		channel, requests, err := newChannel.Accept()
225		if err != nil {
226			log.Println("could not accept channel.", err)
227			continue
228		}
229
230		channelCounter++
231		fmt.Fprintf(os.Stderr, "Channel accepted\n")
232
233		// Sessions have out-of-band requests such as "shell",
234		// "pty-req" and "env".  Here we handle only the
235		// "subsystem" request.
236		go func(in <-chan *ssh.Request) {
237			for req := range in {
238				fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
239				ok := false
240				switch req.Type {
241				case "subsystem":
242					fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
243					if string(req.Payload[4:]) == "sftp" {
244						ok = true
245					}
246				}
247				fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
248				req.Reply(ok, nil)
249			}
250		}(requests)
251		connection := Connection{sconn.User()}
252		root := buildHandlers(&connection)
253		server := sftp.NewRequestServer(channel, root)
254		if err := server.Serve(); err == io.EOF {
255			server.Close()
256			log.Println("sftp client exited session.")
257		} else if err != nil {
258			log.Println("sftp server completed with error:", err)
259			return
260		}
261	}
262}
263
264// GenerateRSAKeys generate rsa private and public keys and write the
265// private key to specified file and the public key to the specified
266// file adding the .pub suffix
267func GenerateRSAKeys() error {
268	key, err := rsa.GenerateKey(rand.Reader, 4096)
269	if err != nil {
270		return err
271	}
272
273	o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
274	if err != nil {
275		return err
276	}
277	defer o.Close()
278
279	priv := &pem.Block{
280		Type:  "RSA PRIVATE KEY",
281		Bytes: x509.MarshalPKCS1PrivateKey(key),
282	}
283
284	if err := pem.Encode(o, priv); err != nil {
285		return err
286	}
287
288	pub, err := ssh.NewPublicKey(&key.PublicKey)
289	if err != nil {
290		return err
291	}
292	return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
293}
294
// listerat is an in-memory slice of file entries that can be read in pages
// through its ListAt method.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and returns how many were
// copied. It returns io.EOF when offset is at or past the end of the list,
// or when fewer entries than len(ls) remained to copy.
//
// Modeled after strings.Reader's ReadAt() implementation
func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if int64(len(f)) <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, f[offset:])
	if copied >= len(ls) {
		// The destination was filled completely; more entries may remain.
		return copied, nil
	}
	return copied, io.EOF
}