src/app/functions.go (view raw)
1package app
2
import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/birabittoh/auth-boilerplate/src/email"
)
17
// HabitDisplay is the view of a Habit built by toHabitDisplay.
//
// ID, Name and Disabled are copied from the habit. Class is one of classGood,
// classWarn, classBad or "" (no status). LastAck is a relative label produced
// by formatDuration (e.g. "Today"), or "-" when the habit was never
// acknowledged.
type HabitDisplay struct {
	ID                   uint
	Class, Name, LastAck string
	Disabled             bool
}
25
// Inclusive bounds on username length (in bytes) enforced by sanitizeUsername.
const (
	minUsernameLength = 3
	maxUsernameLength = 10
)

// Status classes assigned to HabitDisplay.Class by toHabitDisplay and
// getClassForAck.
const (
	classGood = "good"
	classWarn = "warn"
	classBad  = "bad"
)
34
35var (
36 validUsername = regexp.MustCompile(`(?i)^[a-z0-9._-]+$`)
37 validEmail = regexp.MustCompile(`^[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}$`)
38 validHabitName = regexp.MustCompile(`(?i)^[a-z0-9._,\s)(-]+$`)
39)
40
41func getUserByName(username string, excluding uint) (user User, err error) {
42 err = db.Model(&User{}).Where("upper(username) == upper(?) AND id != ?", username, excluding).First(&user).Error
43 return
44}
45
// sanitizeUsername returns username unchanged if it is between
// minUsernameLength and maxUsernameLength bytes long and matches
// validUsername; otherwise it returns "" and an "invalid username" error.
func sanitizeUsername(username string) (string, error) {
	n := len(username)
	switch {
	case n < minUsernameLength, n > maxUsernameLength, !validUsername.MatchString(username):
		return "", errors.New("invalid username")
	}
	return username, nil
}
53
// sanitizeEmail lowercases email and validates it against validEmail.
// It returns the lowercased address, or "" and an "invalid email" error.
func sanitizeEmail(email string) (string, error) {
	normalized := strings.ToLower(email)
	if validEmail.MatchString(normalized) {
		return normalized, nil
	}
	return "", errors.New("invalid email")
}
63
64func checkHabitName(name string) bool {
65 return len(name) < 50 && validHabitName.MatchString(name)
66}
67
// login starts a session for userID and sets its cookie on w.
//
// The session duration is durationWeek when remember is true and durationDay
// otherwise. That duration is passed to g.GenerateCookie and also used for the
// "session:<cookie value>" entry stored in ks. If the cookie cannot be
// generated, a 500 response is written and no session is created.
//
// NOTE(review): login has no return value, so callers cannot tell that it
// failed and may go on to write their own response (e.g. a redirect).
func login(w http.ResponseWriter, userID uint, remember bool) {
	var duration time.Duration
	if remember {
		duration = durationWeek
	} else {
		duration = durationDay
	}

	cookie, err := g.GenerateCookie(duration)
	if err != nil {
		http.Error(w, "Could not generate session cookie.", http.StatusInternalServerError)
		// Stop here. Continuing would use a cookie whose generation failed
		// (a nil dereference if GenerateCookie returns nil on error) and
		// would write headers after the error response.
		return
	}

	ks.Set("session:"+cookie.Value, userID, duration)
	http.SetCookie(w, cookie)
}
84
85func loadEmailConfig() *email.Client {
86 address := os.Getenv("APP_SMTP_EMAIL")
87 password := os.Getenv("APP_SMTP_PASSWORD")
88 host := os.Getenv("APP_SMTP_HOST")
89 port := os.Getenv("APP_SMTP_PORT")
90
91 if address == "" || password == "" || host == "" {
92 log.Println("Missing email configuration.")
93 return nil
94 }
95
96 if port == "" {
97 port = "587"
98 }
99
100 return email.NewClient(address, password, host, port)
101}
102
// sendEmail delivers mail through the package-level client m. It returns an
// error without sending when m is nil (not initialized).
func sendEmail(mail email.Email) error {
	if m != nil {
		return m.Send(mail)
	}
	return errors.New("email client is not initialized")
}
109
// sendResetEmail emails address a link to /reset-password-confirm (under
// baseUrl) carrying token as the "token" query parameter.
//
// The token is query-escaped so characters such as '+', '/', '=' or '&'
// reach the server intact.
//
// Delivery is best-effort. On failure, the error and the full reset link are
// logged (presumably so an operator can still pass the link on) and nothing
// is returned.
//
// NOTE(review): logging the link puts a live reset token in the logs.
func sendResetEmail(address, token string) {
	resetURL := fmt.Sprintf("%s/reset-password-confirm?token=%s", baseUrl, url.QueryEscape(token))
	err := sendEmail(email.Email{
		To:      []string{address},
		Subject: "Reset password",
		Body:    fmt.Sprintf("Use the following link to reset your password:\n%s", resetURL),
	})
	if err != nil {
		log.Printf("Could not send reset email for %s: %v. Link: %s", address, err, resetURL)
	}
}
121
// readSessionCookie reads the "session_token" cookie from r and looks up
// "session:<token>" in ks, returning the stored user ID.
//
// A missing cookie yields r.Cookie's error (http.ErrNoCookie) and a nil ID.
// NOTE(review): whether ks.Get can return a nil ID with a nil error (e.g. for
// an unknown or expired session) is not visible here; callers should check.
func readSessionCookie(r *http.Request) (userID *uint, err error) {
	c, cookieErr := r.Cookie("session_token")
	if cookieErr != nil {
		return nil, cookieErr
	}
	return ks.Get("session:" + c.Value)
}
129
// loginRequired is middleware that only runs next for requests with a usable
// session cookie.
//
// The session's user ID is stored in the request context under userContextKey
// (as a uint) before next is called. Any other request is redirected to
// /login with 302 Found.
func loginRequired(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		userID, err := readSessionCookie(r)
		// readSessionCookie returns a pointer straight from the key-store. Treat
		// a nil ID (e.g. unknown or expired session) as "not logged in" instead
		// of dereferencing it and panicking.
		if err != nil || userID == nil {
			http.Redirect(w, r, "/login", http.StatusFound)
			return
		}

		ctx := context.WithValue(r.Context(), userContextKey, *userID)
		next(w, r.WithContext(ctx))
	}
}
143
// getLoggedUser loads the user whose ID loginRequired stored in r's context
// under userContextKey.
//
// ok is false when the context holds no uint user ID or when the user cannot
// be loaded from the database, including when no row has that ID. user is
// the zero value whenever ok is false.
func getLoggedUser(r *http.Request) (user User, ok bool) {
	userID, ok := r.Context().Value(userContextKey).(uint)
	if !ok {
		return User{}, false
	}

	// First (unlike Find) reports a missing row as an error. The previous
	// check also never cleared ok, so a failed lookup returned a zero User
	// with ok == true.
	if err := db.First(&user, userID).Error; err != nil {
		return User{}, false
	}
	return user, true
}
156
// formatDuration reports how long ago t was.
//
// days is the number of whole 24-hour periods elapsed since t. s is a label:
//
//   - 0 days: "Today"
//   - 1 day: "Yesterday"
//   - 2-7 days: "N day(s) ago"
//   - 8-30 days: "N week(s) ago" or "N week(s), M day(s) ago"
//   - more: "N month(s) ago" or "N month(s), M day(s) ago", with a month
//     approximated as 30 days
//
// A t in the future (e.g. from clock skew) is clamped to 0 days, so it renders
// as "Today" instead of a negative count like "-2 day(s) ago".
func formatDuration(t time.Time) (days int, s string) {
	days = int(time.Since(t).Hours()) / 24
	if days < 0 {
		days = 0
	}

	switch {
	case days == 0:
		s = "Today"
	case days == 1:
		s = "Yesterday"
	case days <= 7:
		s = fmt.Sprintf("%d day(s) ago", days)
	case days <= 30:
		weeks, rest := days/7, days%7
		if rest == 0 {
			s = fmt.Sprintf("%d week(s) ago", weeks)
		} else {
			s = fmt.Sprintf("%d week(s), %d day(s) ago", weeks, rest)
		}
	default:
		months, rest := days/30, days%30
		if rest == 0 {
			s = fmt.Sprintf("%d month(s) ago", months)
		} else {
			s = fmt.Sprintf("%d month(s), %d day(s) ago", months, rest)
		}
	}
	return days, s
}
186
187func toHabitDisplay(habit Habit) HabitDisplay {
188 var (
189 days int
190 lastAck string
191 class string
192 )
193
194 if habit.LastAck != nil {
195 days, lastAck = formatDuration(*habit.LastAck)
196 class = getClassForAck(habit, days)
197 } else {
198 lastAck = "-"
199 if habit.Negative {
200 class = classGood
201 }
202
203 if !habit.Disabled {
204 class = classBad
205 }
206 }
207
208 return HabitDisplay{
209 ID: habit.ID,
210 Name: habit.Name,
211 LastAck: lastAck,
212 Disabled: habit.Disabled,
213 Class: class,
214 }
215}
216
217func getClassForAck(habit Habit, days int) string {
218 if habit.Negative {
219 switch {
220 case days <= 1:
221 return classBad
222 case days <= 7:
223 return classWarn
224 default:
225 return classGood
226 }
227 }
228
229 if habit.Disabled {
230 return ""
231 }
232
233 switch {
234 case days <= 1:
235 return classGood
236 case days <= int(habit.Days):
237 return classWarn
238 default:
239 return classBad
240 }
241}
242
243func getHabit(id uint) (habit Habit, err error) {
244 err = db.Model(&Habit{}).Find(&habit, id).Error
245 return
246}
247
// getHabitHelper resolves the habit named by the request's "id" path value
// and checks that it belongs to the logged-in user.
//
// On failure it has already written the HTTP error response and returns a
// non-nil err:
//   - 400 for a missing or invalid id
//   - 401 with no logged-in user
//   - 404 when getHabit fails
//   - 403 when the habit belongs to another user
//
// Callers should simply return when err != nil.
func getHabitHelper(w http.ResponseWriter, r *http.Request) (habit Habit, err error) {
	id := getID(r)
	if id == 0 {
		http.Error(w, "bad request", http.StatusBadRequest)
		return habit, errors.New("no id")
	}

	user, ok := getLoggedUser(r)
	if !ok {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return habit, errors.New("no logged user")
	}

	if habit, err = getHabit(id); err != nil {
		http.Error(w, "not found", http.StatusNotFound)
		return habit, err
	}

	if habit.UserID != user.ID {
		http.Error(w, "forbidden", http.StatusForbidden)
		return habit, errors.New("forbidden")
	}
	return habit, nil
}
275
// getAllHabits loads every habit owned by userID. It splits them into
// positive and negative HabitDisplay lists, keeping the order the query
// returned them in.
//
// userID must be non-zero. GORM ignores zero-valued fields in struct
// conditions, so Where(&Habit{UserID: 0}) would drop the filter and return
// every user's habits. A zero ID is rejected with an error instead.
func getAllHabits(userID uint) (positives []HabitDisplay, negatives []HabitDisplay, err error) {
	if userID == 0 {
		return nil, nil, errors.New("invalid user id")
	}

	var habits []Habit
	if err = db.Model(&Habit{}).Where(&Habit{UserID: userID}).Find(&habits).Error; err != nil {
		return nil, nil, err
	}

	for _, habit := range habits {
		display := toHabitDisplay(habit)
		if habit.Negative {
			negatives = append(negatives, display)
		} else {
			positives = append(positives, display)
		}
	}
	return positives, negatives, nil
}
293
// getID parses the "id" path value of r (from a ServeMux pattern containing
// {id}) as a base-10 uint.
//
// It returns 0 when the value is missing, malformed, negative, or too large
// for uint. Parsing with strconv.IntSize bits instead of a fixed 64 makes
// oversized IDs fail rather than silently truncate on 32-bit platforms.
func getID(r *http.Request) uint {
	id, err := strconv.ParseUint(r.PathValue("id"), 10, strconv.IntSize)
	if err != nil {
		return 0
	}
	return uint(id)
}