// sftp.go: SFTP server for users with Flounder accounts.
// A lot of this is copied from SFTPGo, but simplified for our use case.
3package main
4
5import (
6 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
9 "encoding/pem"
10 "fmt"
11 "io"
12 "io/ioutil"
13 "log"
14 "net"
15 "os"
16 "path"
17 "path/filepath"
18 "runtime/debug"
19 "time"
20
21 "github.com/pkg/sftp"
22 "golang.org/x/crypto/ssh"
23)
24
// Connection is the per-session handler state for one authenticated SFTP
// client; User holds that client's SSH login name and selects the directory
// (via getUserDirectory) that every request path is resolved against.
type Connection struct{ User string }
28
29func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
30 // check user perms -- cant read others hidden files
31 userDir := getUserDirectory(con.User) // NOTE -- not cross platform
32 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
33 f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
34 if err != nil {
35 return nil, err
36 }
37 return f, nil
38}
39
40func (conn *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
41 // check user perms -- cant write others files
42 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
43 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
44 err := checkIfValidFile(conn.User, fullpath, []byte{})
45 if err != nil {
46 return nil, err
47 }
48 f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
49 if err != nil {
50 return nil, err
51 }
52 return f, nil
53}
54
55func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
56 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
57 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
58 switch request.Method {
59 case "List":
60 f, err := os.Open(fullpath)
61 if err != nil {
62 return nil, err
63 }
64 fileInfo, err := f.Readdir(-1)
65 if err != nil {
66 return nil, err
67 }
68 return listerat(fileInfo), nil
69 case "Stat":
70 stat, err := os.Stat(fullpath)
71 if err != nil {
72 return nil, err
73 }
74 return listerat([]os.FileInfo{stat}), nil
75 }
76 return nil, fmt.Errorf("Invalid command")
77}
78
79func (conn *Connection) Filecmd(request *sftp.Request) error {
80 // remove, rename, setstat? find out
81 userDir := getUserDirectory(conn.User) // NOTE -- not cross platform
82 fullpath := path.Join(userDir, filepath.Clean(request.Filepath))
83 targetPath := path.Join(userDir, filepath.Clean(request.Target))
84 var err error
85 switch request.Method {
86 case "Remove":
87 err = os.Remove(fullpath)
88 case "Mkdir":
89 err = os.Mkdir(fullpath, 0755)
90 case "Rename":
91 err := checkIfValidFile(conn.User, targetPath, []byte{})
92 if err != nil {
93 return err
94 }
95 err = os.Rename(fullpath, targetPath)
96 }
97 if err != nil {
98 return err
99 }
100 // Rename, Mkdir
101 return nil
102}
103
// TODO: hide hidden folders.
// Users have write perms on their own files and read perms on all files.
106
107func buildHandlers(connection *Connection) sftp.Handlers {
108 return sftp.Handlers{
109 connection,
110 connection,
111 connection,
112 connection,
113 }
114}
115
116// Based on example server code from golang.org/x/crypto/ssh and server_standalone
117func runSFTPServer() {
118 if !c.EnableSFTP {
119 return
120 }
121 // An SSH server is represented by a ServerConfig, which holds
122 // certificate details and handles authentication of ServerConns.
123 config := &ssh.ServerConfig{
124 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
125 // Should use constant-time compare (or better, salt+hash) in
126 // a production setting.
127 if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
128 return nil, fmt.Errorf("Invalid username")
129 }
130 _, _, err := checkLogin(c.User(), string(pass))
131 // TODO maybe give admin extra permissions?
132 if err != nil {
133 return nil, fmt.Errorf("password rejected for %q", c.User())
134 } else {
135 log.Printf("Login: %s\n", c.User())
136 return nil, nil
137 }
138 },
139 }
140
141 // TODO generate key automatically
142 if _, err := os.Stat(c.HostKeyPath); os.IsNotExist(err) {
143 // path/to/whatever does not exist
144 log.Println("Host key not found, generating host key")
145 err := GenerateRSAKeys()
146 if err != nil {
147 log.Fatal(err)
148 }
149 }
150
151 privateBytes, err := ioutil.ReadFile(c.HostKeyPath)
152 if err != nil {
153 log.Fatal("Failed to load private key", err)
154 }
155
156 private, err := ssh.ParsePrivateKey(privateBytes)
157 if err != nil {
158 log.Fatal("Failed to parse private key", err)
159 }
160
161 config.AddHostKey(private)
162
163 listener, err := net.Listen("tcp", "0.0.0.0:2024")
164 if err != nil {
165 log.Fatal("failed to listen for connection", err)
166 }
167
168 log.Printf("SFTP server listening on %v\n", listener.Addr())
169
170 for {
171 conn, err := listener.Accept()
172 if err != nil {
173 log.Fatal(err)
174 }
175 go acceptInboundConnection(conn, config)
176 }
177}
178
179func acceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
180 defer func() {
181 if r := recover(); r != nil {
182 log.Printf("panic in AcceptInboundConnection: %#v stack strace: %v", r, string(debug.Stack()))
183 }
184 }()
185 ipAddr := GetIPFromRemoteAddress(conn.RemoteAddr().String())
186 log.Println("Request from IP " + ipAddr)
187 limiter := getVisitor(ipAddr)
188 if limiter.Allow() == false {
189 conn.Close()
190 return
191 }
192 // Before beginning a handshake must be performed on the incoming net.Conn
193 // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
194 conn.SetDeadline(time.Now().Add(2 * time.Minute))
195
196 // Before use, a handshake must be performed on the incoming net.Conn.
197 sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
198 if err != nil {
199 log.Printf("failed to accept an incoming connection: %v", err)
200 return
201 }
202 log.Println("login detected:", sconn.User())
203 fmt.Fprintf(os.Stderr, "SSH server established\n")
204 // handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
205 conn.SetDeadline(time.Time{})
206
207 defer conn.Close()
208
209 // The incoming Request channel must be serviced.
210 go ssh.DiscardRequests(reqs)
211
212 // Service the incoming Channel channel.
213 channelCounter := int64(0)
214 for newChannel := range chans {
215 // Channels have a type, depending on the application level
216 // protocol intended. In the case of an SFTP session, this is "subsystem"
217 // with a payload string of "<length=4>sftp"
218 fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
219 if newChannel.ChannelType() != "session" {
220 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
221 fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
222 continue
223 }
224 channel, requests, err := newChannel.Accept()
225 if err != nil {
226 log.Println("could not accept channel.", err)
227 continue
228 }
229
230 channelCounter++
231 fmt.Fprintf(os.Stderr, "Channel accepted\n")
232
233 // Sessions have out-of-band requests such as "shell",
234 // "pty-req" and "env". Here we handle only the
235 // "subsystem" request.
236 go func(in <-chan *ssh.Request) {
237 for req := range in {
238 fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
239 ok := false
240 switch req.Type {
241 case "subsystem":
242 fmt.Fprintf(os.Stderr, "Subsystem: %s\n", req.Payload[4:])
243 if string(req.Payload[4:]) == "sftp" {
244 ok = true
245 }
246 }
247 fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
248 req.Reply(ok, nil)
249 }
250 }(requests)
251 connection := Connection{sconn.User()}
252 root := buildHandlers(&connection)
253 server := sftp.NewRequestServer(channel, root)
254 if err := server.Serve(); err == io.EOF {
255 server.Close()
256 log.Println("sftp client exited session.")
257 } else if err != nil {
258 log.Println("sftp server completed with error:", err)
259 return
260 }
261 }
262}
263
264// GenerateRSAKeys generate rsa private and public keys and write the
265// private key to specified file and the public key to the specified
266// file adding the .pub suffix
267func GenerateRSAKeys() error {
268 key, err := rsa.GenerateKey(rand.Reader, 4096)
269 if err != nil {
270 return err
271 }
272
273 o, err := os.OpenFile(c.HostKeyPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
274 if err != nil {
275 return err
276 }
277 defer o.Close()
278
279 priv := &pem.Block{
280 Type: "RSA PRIVATE KEY",
281 Bytes: x509.MarshalPKCS1PrivateKey(key),
282 }
283
284 if err := pem.Encode(o, priv); err != nil {
285 return err
286 }
287
288 pub, err := ssh.NewPublicKey(&key.PublicKey)
289 if err != nil {
290 return err
291 }
292 return ioutil.WriteFile(c.HostKeyPath+".pub", ssh.MarshalAuthorizedKey(pub), 0600)
293}
294
// listerat is an in-memory directory listing; Filelist returns it as an
// sftp.ListerAt.
type listerat []os.FileInfo

// ListAt copies the entries starting at index offset into ls and reports how
// many were copied. Like strings.Reader.ReadAt, it returns io.EOF when offset
// is at or past the end of the listing, and also together with a short count
// whenever ls could not be filled completely.
func (la listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if int64(len(la)) <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, la[offset:])
	var err error
	if copied != len(ls) {
		// Fewer entries remained than ls can hold: the listing is exhausted.
		err = io.EOF
	}
	return copied, err
}