sftp.go (view raw)
1// An example SFTP server implementation using the golang SSH package.
2// Serves the whole filesystem visible to the user, and has a hard-coded username and password,
3// so not for real use!
4package main
5
6import (
7 "fmt"
8 "io"
9 "io/ioutil"
10 "log"
11 "net"
12 "os"
13 "path"
14 "path/filepath"
15 "strings"
16
17 "github.com/pkg/sftp"
18 "golang.org/x/crypto/ssh"
19)
20
// Connection is the per-session state handed to the sftp request server.
// Its Fileread, Filewrite, Filecmd and Filelist methods are used as the
// request handlers for one client session.
type Connection struct {
	// User is the account name used to look up the directory this
	// session is allowed to write into.
	User string
}
24
// Fileread opens the requested file read-only. The client path is resolved
// beneath the configured files directory (c.FilesDirectory). The opened
// *os.File is returned to the sftp request server as an io.ReaderAt.
//
// TODO: check user perms -- users must not read other users' hidden files.
// No such check exists yet, so any authenticated user can read any file
// under the files directory.
func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
	// Root the client path before cleaning it. path.Clean drops leading ".."
	// elements of a rooted path, so a relative request like "../../etc/x"
	// cannot climb out of c.FilesDirectory.
	fullpath := path.Join(c.FilesDirectory, path.Clean("/"+filepath.ToSlash(request.Filepath)))
	f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
	if err != nil {
		// Return a literal nil so the caller never sees a non-nil
		// io.ReaderAt that wraps a nil *os.File.
		return nil, err
	}
	return f, nil
}
34
// Filewrite opens the requested file for writing, creating it if missing
// and truncating it otherwise. The client path is resolved beneath
// c.FilesDirectory. Writes are only permitted at or below the directory
// getUserDirectory returns for the session's user. Any other target fails
// with "Invalid permissions".
func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
	// Root the client path before cleaning so ".." cannot climb out of
	// c.FilesDirectory (path.Clean drops leading ".." on rooted paths).
	fullpath := path.Join(c.FilesDirectory, path.Clean("/"+filepath.ToSlash(request.Filepath)))
	userDir := getUserDirectory(con.User) // NOTE -- not cross platform
	if !isWithinDir(fullpath, userDir) {
		return nil, fmt.Errorf("Invalid permissions")
	}
	f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		// Literal nil avoids a non-nil interface holding a nil *os.File.
		return nil, err
	}
	return f, nil
}

// isWithinDir reports whether the slash-separated, cleaned path p is dir
// itself or lies beneath it.
//
// A bare strings.HasPrefix is not enough, because "/files/alexander"
// has the prefix "/files/alex". So a separator is required after dir.
// An empty dir cleans to "." and matches no rooted path, which fails
// closed instead of allowing everything.
func isWithinDir(p, dir string) bool {
	dir = path.Clean(dir)
	if p == dir {
		return true
	}
	if !strings.HasSuffix(dir, "/") { // only "/" keeps its trailing slash after Clean
		dir += "/"
	}
	return strings.HasPrefix(p, dir)
}
49
// Filelist answers the sftp "List" and "Stat" requests for a path resolved
// beneath c.FilesDirectory:
//   - "List" returns every entry of the directory.
//   - "Stat" returns a single-entry listing describing the path itself.
//
// Any other method yields an "Invalid command" error.
func (con *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
	// Root the client path before cleaning so ".." cannot climb out of
	// c.FilesDirectory (path.Clean drops leading ".." on rooted paths).
	fullpath := path.Join(c.FilesDirectory, path.Clean("/"+filepath.ToSlash(request.Filepath)))
	switch request.Method {
	case "List":
		dir, err := os.Open(fullpath)
		if err != nil {
			return nil, err
		}
		// All entries are read eagerly below, so the handle is not needed
		// after this call; close it instead of leaking a descriptor per List.
		defer dir.Close()
		entries, err := dir.Readdir(-1)
		if err != nil {
			return nil, err
		}
		return listerat(entries), nil
	case "Stat":
		stat, err := os.Stat(fullpath)
		if err != nil {
			return nil, err
		}
		return listerat([]os.FileInfo{stat}), nil
	}
	return nil, fmt.Errorf("Invalid command")
}
72
// Filecmd handles the sftp requests that change the filesystem without
// transferring file data.
//
// None of them are implemented yet:
//   - "Setstat" is acknowledged as a no-op. This assumes clients may set
//     attributes after an upload and should not be failed for it; confirm
//     against the clients you support.
//   - Every other method returns an error, so the client is not told that
//     a remove, rename, mkdir, etc. succeeded when nothing happened.
//
// TODO: implement remove/rename/mkdir with the same per-user write check
// as Filewrite.
func (con *Connection) Filecmd(request *sftp.Request) error {
	switch request.Method {
	case "Setstat":
		return nil
	default:
		return fmt.Errorf("unsupported command %q", request.Method)
	}
}
77
// TODO hide hidden folders
// Users have write perms on their own files and read perms on all files.
80
81func buildHandlers(connection *Connection) sftp.Handlers {
82 return sftp.Handlers{
83 connection,
84 connection,
85 connection,
86 connection,
87 }
88}
89
// runSFTPServer runs a single-connection SFTP server. It:
//   - listens on 0.0.0.0:2024 with the host key read from ./id_rsa;
//   - accepts one TCP connection and authenticates it by password
//     (isOkUsername + checkLogin);
//   - serves SFTP on each "session" channel the client opens, scoped to
//     the authenticated user.
//
// It returns when the connection's channel stream is closed. Setup,
// channel-accept and serve errors terminate the process via log.Fatal.
//
// Based on example server code from golang.org/x/crypto/ssh and server_standalone
func runSFTPServer() {
	// An SSH server is represented by a ServerConfig, which holds
	// certificate details and handles authentication of ServerConns.
	config := &ssh.ServerConfig{
		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
			// Should use constant-time compare (or better, salt+hash) in
			// a production setting.
			if isOkUsername(c.User()) != nil { // extra check, probably unnecessary
				return nil, fmt.Errorf("Invalid username")
			}
			// TODO maybe give admin extra permissions?
			if _, _, err := checkLogin(c.User(), string(pass)); err != nil {
				// Log rejections separately so "Login:" only ever marks a success.
				fmt.Fprintf(os.Stderr, "Login rejected: %s\n", c.User())
				return nil, fmt.Errorf("password rejected for %q", c.User())
			}
			fmt.Fprintf(os.Stderr, "Login: %s\n", c.User())
			return nil, nil
		},
	}

	privateBytes, err := ioutil.ReadFile("id_rsa")
	if err != nil {
		log.Fatalf("Failed to load private key: %v", err)
	}

	private, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		log.Fatalf("Failed to parse private key: %v", err)
	}

	config.AddHostKey(private)

	// Once a ServerConfig has been configured, connections can be
	// accepted.
	listener, err := net.Listen("tcp", "0.0.0.0:2024")
	if err != nil {
		log.Fatalf("failed to listen for connection: %v", err)
	}
	fmt.Printf("Listening on %v\n", listener.Addr())

	nConn, err := listener.Accept()
	if err != nil {
		log.Fatalf("failed to accept incoming connection: %v", err)
	}

	// Before use, a handshake must be performed on the incoming net.Conn.
	sconn, chans, reqs, err := ssh.NewServerConn(nConn, config)
	if err != nil {
		log.Fatalf("failed to handshake: %v", err)
	}
	log.Println("login detected:", sconn.User())
	fmt.Fprintf(os.Stderr, "SSH server established\n")

	// The incoming Request channel must be serviced.
	go ssh.DiscardRequests(reqs)

	// Service the incoming Channel channel.
	for newChannel := range chans {
		// Channels have a type, depending on the application level
		// protocol intended. In the case of an SFTP session, this is "subsystem"
		// with a payload string of "<length=4>sftp"
		fmt.Fprintf(os.Stderr, "Incoming channel: %s\n", newChannel.ChannelType())
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			fmt.Fprintf(os.Stderr, "Unknown channel type: %s\n", newChannel.ChannelType())
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			log.Fatalf("could not accept channel: %v", err)
		}
		fmt.Fprintf(os.Stderr, "Channel accepted\n")

		go handleSessionRequests(requests)

		// Scope the handlers to the user who actually authenticated. A fixed
		// account name would give every user that account's write directory.
		connection := Connection{User: sconn.User()}
		root := buildHandlers(&connection)
		server := sftp.NewRequestServer(channel, root)
		if err := server.Serve(); err == io.EOF {
			server.Close()
			log.Print("sftp client exited session.")
		} else if err != nil {
			log.Fatalf("sftp server completed with error: %v", err)
		}
	}
}

// handleSessionRequests answers the out-of-band requests ("shell",
// "pty-req", "env", ...) of one session channel. It accepts only a
// "subsystem" request whose name is "sftp" and rejects everything else.
// It returns when the request channel is closed.
func handleSessionRequests(in <-chan *ssh.Request) {
	for req := range in {
		fmt.Fprintf(os.Stderr, "Request: %v\n", req.Type)
		ok := false
		// The subsystem payload is a 4-byte length prefix followed by the
		// name. It comes from the client, so check its length before slicing
		// to keep a short payload from panicking the server.
		if req.Type == "subsystem" && len(req.Payload) >= 4 {
			name := string(req.Payload[4:])
			fmt.Fprintf(os.Stderr, "Subsystem: %s\n", name)
			ok = name == "sftp"
		}
		fmt.Fprintf(os.Stderr, " - accepted: %v\n", ok)
		req.Reply(ok, nil)
	}
}
195
196type listerat []os.FileInfo
197
198// Modeled after strings.Reader's ReadAt() implementation
199func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
200 var n int
201 if offset >= int64(len(f)) {
202 return 0, io.EOF
203 }
204 n = copy(ls, f[offset:])
205 if n < len(ls) {
206 return n, io.EOF
207 }
208 return n, nil
209}