sftp.go (view raw)
1// An example SFTP server implementation using the golang SSH package.
2// Serves the whole filesystem visible to the user, and has a hard-coded username and password,
3// so not for real use!
4package main
5
6import (
7 "flag"
8 "fmt"
9 "io"
10 "io/ioutil"
11 "log"
12 "net"
13 "os"
14 "path"
15 "path/filepath"
16 "strings"
17
18 "github.com/pkg/sftp"
19 "golang.org/x/crypto/ssh"
20)
21
// Connection carries the per-session state handed to the SFTP request
// handlers. The same value serves as the read, write, command and list
// handler for a session.
type Connection struct {
	// User is the login name this session is served for; the write handler
	// passes it to getUserDirectory to decide where writes are allowed.
	User string
}
25
26func (con *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
27 // check user perms -- cant read others hidden files
28 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
29 f, err := os.OpenFile(fullpath, os.O_RDONLY, 0)
30 if err != nil {
31 return nil, err
32 }
33 return f, nil
34}
35
36func (con *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
37 // check user perms -- cant write others files
38 // check if file is inside your directory -- strings prefix?
39 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
40 userDir := getUserDirectory(con.User) // NOTE -- not cross platform
41 if strings.HasPrefix(fullpath, userDir) {
42 f, err := os.OpenFile(fullpath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
43 if err != nil {
44 return nil, err
45 }
46 return f, nil
47 } else {
48 return nil, fmt.Errorf("Invalid permissions")
49 }
50}
51
52func (conn *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) {
53 fullpath := path.Join(c.FilesDirectory, filepath.Clean(request.Filepath))
54 switch request.Method {
55 case "List":
56 f, err := os.Open(fullpath)
57 if err != nil {
58 return nil, err
59 }
60 fileInfo, err := f.Readdir(-1)
61 if err != nil {
62 return nil, err
63 }
64 return listerat(fileInfo), nil
65 case "Stat":
66 stat, err := os.Stat(fullpath)
67 if err != nil {
68 return nil, err
69 }
70 return listerat([]os.FileInfo{stat}), nil
71 }
72 return nil, fmt.Errorf("Invalid command")
73}
74
75func (c *Connection) Filecmd(request *sftp.Request) error {
76 // remove, rename, setstat? find out
77 return nil
78}
79
80// TODO hide hidden folders
// Users have write perms on their own files and read perms on all files.
82
83func buildHandlers(connection *Connection) sftp.Handlers {
84 return sftp.Handlers{
85 connection,
86 connection,
87 connection,
88 connection,
89 }
90}
91
92// Based on example server code from golang.org/x/crypto/ssh and server_standalone
93func runSFTPServer() {
94
95 var (
96 readOnly bool
97 debugStderr bool
98 )
99
100 flag.BoolVar(&readOnly, "R", false, "read-only server")
101 flag.BoolVar(&debugStderr, "e", false, "debug to stderr")
102 flag.Parse()
103
104 debugStream := ioutil.Discard
105 if debugStderr {
106 debugStream = os.Stderr
107 }
108
109 // An SSH server is represented by a ServerConfig, which holds
110 // certificate details and handles authentication of ServerConns.
111 config := &ssh.ServerConfig{
112 // PublicKeyCallback
113 PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
114 // Should use constant-time compare (or better, salt+hash) in
115 // a production setting.
116 fmt.Fprintf(debugStream, "Login: %s\n", c.User())
117 if c.User() == "alex" && string(pass) == "alex" {
118 return nil, nil
119 }
120 return nil, fmt.Errorf("password rejected for %q", c.User())
121 },
122 }
123
124 privateBytes, err := ioutil.ReadFile("id_rsa")
125 if err != nil {
126 log.Fatal("Failed to load private key", err)
127 }
128
129 private, err := ssh.ParsePrivateKey(privateBytes)
130 if err != nil {
131 log.Fatal("Failed to parse private key", err)
132 }
133
134 config.AddHostKey(private)
135
136 // Once a ServerConfig has been configured, connections can be
137 // accepted.
138 listener, err := net.Listen("tcp", "0.0.0.0:2024")
139 if err != nil {
140 log.Fatal("failed to listen for connection", err)
141 }
142 fmt.Printf("Listening on %v\n", listener.Addr())
143
144 nConn, err := listener.Accept()
145 if err != nil {
146 log.Fatal("failed to accept incoming connection", err)
147 }
148
149 // Before use, a handshake must be performed on the incoming net.Conn.
150 sconn, chans, reqs, err := ssh.NewServerConn(nConn, config)
151 if err != nil {
152 log.Fatal("failed to handshake", err)
153 }
154 log.Println("login detected:", sconn.User())
155 fmt.Fprintf(debugStream, "SSH server established\n")
156
157 // The incoming Request channel must be serviced.
158 go ssh.DiscardRequests(reqs)
159
160 // Service the incoming Channel channel.
161 for newChannel := range chans {
162 // Channels have a type, depending on the application level
163 // protocol intended. In the case of an SFTP session, this is "subsystem"
164 // with a payload string of "<length=4>sftp"
165 fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
166 if newChannel.ChannelType() != "session" {
167 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
168 fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
169 continue
170 }
171 channel, requests, err := newChannel.Accept()
172 if err != nil {
173 log.Fatal("could not accept channel.", err)
174 }
175 fmt.Fprintf(debugStream, "Channel accepted\n")
176
177 // Sessions have out-of-band requests such as "shell",
178 // "pty-req" and "env". Here we handle only the
179 // "subsystem" request.
180 go func(in <-chan *ssh.Request) {
181 for req := range in {
182 fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
183 ok := false
184 switch req.Type {
185 case "subsystem":
186 fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
187 if string(req.Payload[4:]) == "sftp" {
188 ok = true
189 }
190 }
191 fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
192 req.Reply(ok, nil)
193 }
194 }(requests)
195 connection := Connection{"alex"}
196 root := buildHandlers(&connection)
197 server := sftp.NewRequestServer(channel, root)
198 if err := server.Serve(); err == io.EOF {
199 server.Close()
200 log.Print("sftp client exited session.")
201 } else if err != nil {
202 log.Fatal("sftp server completed with error:", err)
203 }
204 }
205}
206
// listerat is an in-memory slice of file infos that can be read in windows
// through ListAt.
type listerat []os.FileInfo

// ListAt copies entries of f, starting at index offset, into ls and returns
// the number of entries copied.
//
// Modeled after strings.Reader's ReadAt() implementation: a negative offset
// is an error, an offset at or past the end returns (0, io.EOF), and a copy
// that leaves ls only partly filled returns the count together with io.EOF.
func (f listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	// Slicing with a negative index would panic, so reject it explicitly.
	if offset < 0 {
		return 0, fmt.Errorf("listerat.ListAt: negative offset")
	}
	if offset >= int64(len(f)) {
		return 0, io.EOF
	}
	n := copy(ls, f[offset:])
	if n < len(ls) {
		// Fewer entries remained than the caller asked for: end of listing.
		return n, io.EOF
	}
	return n, nil
}