src/app/functions.go (view raw)
1package app
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log"
8 "net/http"
9 "os"
10 "regexp"
11 "strings"
12 "time"
13
14 "github.com/birabittoh/auth-boilerplate/src/email"
15)
16
// HabitDisplay is the presentation form of a Habit, built by toHabitDisplay
// and returned (split into positives/negatives) by getAllHabits.
type HabitDisplay struct {
	// Class holds a display class; toHabitDisplay currently always sets
	// classGood (classWarn/classBad are declared alongside it but not used
	// in the code shown here).
	Class string
	// Name is copied from Habit.Name.
	Name string
	// LastAck is "-" when Habit.LastAck is nil, otherwise the formatted time
	// elapsed since Habit.LastAck (see formatDuration).
	LastAck string
	// Disabled is copied from Habit.Disabled.
	Disabled bool
}
23
// Inclusive bounds on len(username) enforced by sanitizeUsername.
const (
	minUsernameLength = 3
	maxUsernameLength = 10
)

// Values assigned to HabitDisplay.Class.
const (
	classGood = "good"
	classWarn = "warn"
	classBad  = "bad"
)
32
// validUsername whitelists the characters allowed in a username: letters in
// either case, digits, '.', '_' and '-'. Length is checked separately by
// sanitizeUsername.
var validUsername = regexp.MustCompile(`(?i)^[a-z0-9._-]+$`)

// validEmail checks the basic shape local@domain.tld. It has no (?i) flag and
// only matches lower-case letters; sanitizeEmail lower-cases input first.
var validEmail = regexp.MustCompile(`^[a-z0-9._%+-]+@[a-z0-9.-]+\.[a-z]{2,}$`)

// validHabitName allows letters in either case, digits, whitespace (\s) and
// the punctuation characters . _ , ( ) -.
var validHabitName = regexp.MustCompile(`(?i)^[a-z0-9._,\s)(-]+$`)
38
39func getUserByName(username string, excluding uint) (user User, err error) {
40 err = db.Model(&User{}).Where("upper(username) == upper(?) AND id != ?", username, excluding).First(&user).Error
41 return
42}
43
44func sanitizeUsername(username string) (string, error) {
45 if !validUsername.MatchString(username) || len(username) < minUsernameLength || len(username) > maxUsernameLength {
46 return "", errors.New("invalid username")
47 }
48
49 return username, nil
50}
51
52func sanitizeEmail(email string) (string, error) {
53 email = strings.ToLower(email)
54
55 if !validEmail.MatchString(email) {
56 return "", fmt.Errorf("invalid email")
57 }
58
59 return email, nil
60}
61
62func checkHabitName(name string) bool {
63 return len(name) < 50 && validHabitName.MatchString(name)
64}
65
66func login(w http.ResponseWriter, userID uint, remember bool) {
67 var duration time.Duration
68 if remember {
69 duration = durationWeek
70 } else {
71 duration = durationDay
72 }
73
74 cookie, err := g.GenerateCookie(duration)
75 if err != nil {
76 http.Error(w, "Could not generate session cookie.", http.StatusInternalServerError)
77 }
78
79 ks.Set("session:"+cookie.Value, userID, duration)
80 http.SetCookie(w, cookie)
81}
82
83func loadEmailConfig() *email.Client {
84 address := os.Getenv("APP_SMTP_EMAIL")
85 password := os.Getenv("APP_SMTP_PASSWORD")
86 host := os.Getenv("APP_SMTP_HOST")
87 port := os.Getenv("APP_SMTP_PORT")
88
89 if address == "" || password == "" || host == "" {
90 log.Println("Missing email configuration.")
91 return nil
92 }
93
94 if port == "" {
95 port = "587"
96 }
97
98 return email.NewClient(address, password, host, port)
99}
100
101func sendEmail(mail email.Email) error {
102 if m == nil {
103 return errors.New("email client is not initialized")
104 }
105 return m.Send(mail)
106}
107
108func sendResetEmail(address, token string) {
109 resetURL := fmt.Sprintf("%s/reset-password-confirm?token=%s", baseUrl, token)
110 err := sendEmail(email.Email{
111 To: []string{address},
112 Subject: "Reset password",
113 Body: fmt.Sprintf("Use the following link to reset your password:\n%s", resetURL),
114 })
115 if err != nil {
116 log.Printf("Could not send reset email for %s. Link: %s", address, resetURL)
117 }
118}
119
120func readSessionCookie(r *http.Request) (userID *uint, err error) {
121 cookie, err := r.Cookie("session_token")
122 if err != nil {
123 return
124 }
125 return ks.Get("session:" + cookie.Value)
126}
127
128// Middleware to check if the user is logged in
129func loginRequired(next http.HandlerFunc) http.HandlerFunc {
130 return func(w http.ResponseWriter, r *http.Request) {
131 userID, err := readSessionCookie(r)
132 if err != nil {
133 http.Redirect(w, r, "/login", http.StatusFound)
134 return
135 }
136
137 ctx := context.WithValue(r.Context(), userContextKey, *userID)
138 next(w, r.WithContext(ctx))
139 }
140}
141
142func getLoggedUser(r *http.Request) (user User, ok bool) {
143 userID, ok := r.Context().Value(userContextKey).(uint)
144 db.Find(&user, userID)
145 return user, ok
146}
147
// formatDuration renders d for display in time.Duration's String format,
// rounded to the nearest whole second (halfway values round away from zero).
// Values coming from time.Since therefore show e.g. "48h1m13s" rather than
// "48h1m13.456789012s".
func formatDuration(d time.Duration) string {
	// TODO: 48h1m13s --> 2.01 days
	return d.Round(time.Second).String()
}
152
153func toHabitDisplay(habit Habit) HabitDisplay {
154 var lastAck string
155 if habit.LastAck == nil {
156 lastAck = "-"
157 } else {
158 lastAck = formatDuration(time.Since(*habit.LastAck))
159 }
160 return HabitDisplay{
161 Name: habit.Name,
162 LastAck: lastAck,
163 Disabled: habit.Disabled,
164 Class: classGood,
165 }
166}
167
168func getAllHabits(userID uint) (positives []HabitDisplay, negatives []HabitDisplay, err error) {
169 var habits []Habit
170 err = db.Model(&Habit{}).Where(&Habit{UserID: userID}).Find(&habits).Error
171 if err != nil {
172 return
173 }
174
175 for _, habit := range habits {
176 habitDisplay := toHabitDisplay(habit)
177 if habit.Negative {
178 negatives = append(negatives, habitDisplay)
179 } else {
180 positives = append(positives, habitDisplay)
181 }
182 }
183 return
184}