145 lines
3.8 KiB
Go
145 lines
3.8 KiB
Go
package app
|
|
|
|
import (
	"encoding/json"
	"errors"
	"io"
	"mime"
	"net/http"
	"strconv"
	"strings"

	"git.warky.dev/wdevs/amcs/internal/tools"
)
|
|
|
|
const (
	// maxUploadBytes caps the size of an upload request body (100 MiB).
	// fileHandler enforces it by wrapping the body in http.MaxBytesReader.
	maxUploadBytes = 100 * 1024 * 1024

	// multipartFormMemory is the maxMemory argument handed to
	// Request.ParseMultipartForm (32 MiB). Per net/http, file parts beyond
	// this total are spilled to temporary files on disk.
	multipartFormMemory = 32 * 1024 * 1024
)
|
|
|
|
func fileHandler(files *tools.FilesTool) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
id := r.PathValue("id")
|
|
if id != "" {
|
|
fileDownloadHandler(files, id, w, r)
|
|
return
|
|
}
|
|
|
|
if r.Method != http.MethodPost {
|
|
w.Header().Set("Allow", http.MethodPost)
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)
|
|
|
|
in, err := parseUploadRequest(r)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
out, err := files.SaveDecoded(r.Context(), nil, in)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusCreated)
|
|
_ = json.NewEncoder(w).Encode(out)
|
|
})
|
|
}
|
|
|
|
// fileDownloadHandler writes the stored file identified by id as an
// attachment. Only GET and HEAD are allowed (405 with an Allow header
// otherwise); HEAD sends the same headers as GET but no body.
//
// Response headers:
//   - Content-Type: the stored media type, or application/octet-stream when
//     none was recorded,
//   - Content-Disposition: attachment with a correctly quoted/encoded filename,
//   - Content-Length: the size of the stored content,
//   - X-File-Kind and X-File-SHA256: the stored metadata, passed through.
func fileDownloadHandler(files *tools.FilesTool, id string, w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet && r.Method != http.MethodHead {
		w.Header().Set("Allow", "GET, HEAD")
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// NOTE(review): every lookup failure is reported as 404; a storage error
	// would be better surfaced as 5xx if tools exposes a not-found sentinel.
	file, err := files.GetRaw(r.Context(), id)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}

	// An empty Content-Type header is not a usable media type; fall back to
	// the generic binary type.
	mediaType := strings.TrimSpace(file.MediaType)
	if mediaType == "" {
		mediaType = "application/octet-stream"
	}

	// FormatMediaType quotes the filename when it contains non-token
	// characters and RFC 2231-encodes non-ASCII names, so spaces, ';', quotes
	// or UTF-8 in the stored name cannot corrupt the header.
	disposition := mime.FormatMediaType("attachment", map[string]string{"filename": file.Name})
	if disposition == "" {
		disposition = "attachment"
	}

	h := w.Header()
	h.Set("Content-Type", mediaType)
	h.Set("Content-Disposition", disposition)
	// Explicit length so HEAD reports the same size GET would send.
	h.Set("Content-Length", strconv.Itoa(len(file.Content)))
	// The media type is user-supplied at upload time; stop browsers from
	// second-guessing it.
	h.Set("X-Content-Type-Options", "nosniff")
	h.Set("X-File-Kind", file.Kind)
	h.Set("X-File-SHA256", file.SHA256)
	w.WriteHeader(http.StatusOK)

	if r.Method == http.MethodHead {
		return
	}
	// Headers are already sent; a write error means the client went away
	// and there is nothing left to report.
	_, _ = w.Write(file.Content)
}
|
|
|
|
// parseUploadRequest picks the upload decoder from the request Content-Type:
// multipart/form-data bodies go to parseMultipartUpload, anything else is
// handed to parseRawUpload.
func parseUploadRequest(r *http.Request) (tools.SaveFileDecodedInput, error) {
	// The parse error is intentionally ignored. A header that cannot be
	// parsed at all yields an empty media type and is treated as a raw upload.
	mt, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	if !strings.HasPrefix(mt, "multipart/form-data") {
		return parseRawUpload(r)
	}
	return parseMultipartUpload(r)
}
|
|
|
|
func parseMultipartUpload(r *http.Request) (tools.SaveFileDecodedInput, error) {
|
|
if err := r.ParseMultipartForm(multipartFormMemory); err != nil {
|
|
return tools.SaveFileDecodedInput{}, err
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
return tools.SaveFileDecodedInput{}, errors.New("multipart upload requires a file field named \"file\"")
|
|
}
|
|
defer file.Close()
|
|
|
|
content, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return tools.SaveFileDecodedInput{}, err
|
|
}
|
|
|
|
return tools.SaveFileDecodedInput{
|
|
Name: firstNonEmpty(r.FormValue("name"), header.Filename),
|
|
Content: content,
|
|
MediaType: firstNonEmpty(r.FormValue("media_type"), header.Header.Get("Content-Type")),
|
|
Kind: r.FormValue("kind"),
|
|
ThoughtID: r.FormValue("thought_id"),
|
|
Project: r.FormValue("project"),
|
|
}, nil
|
|
}
|
|
|
|
// parseRawUpload decodes an upload whose request body is the raw file content.
//
// Metadata is read from query parameters first, then from headers:
//
//	name       ?name=        or X-File-Name   (required)
//	kind       ?kind=        or X-File-Kind
//	thought_id ?thought_id=  or X-Thought-Id
//	project    ?project=     or X-Project
//
// Values are whitespace-trimmed via firstNonEmpty. The media type is the
// request's Content-Type header. A missing name is reported before the body
// is read; a body read error (e.g. exceeding a MaxBytesReader limit) is
// returned unchanged.
func parseRawUpload(r *http.Request) (tools.SaveFileDecodedInput, error) {
	// Parse the query string once rather than once per field.
	query := r.URL.Query()

	// Validate the required metadata before buffering the body, so a request
	// without a name is rejected without reading up to maxUploadBytes.
	name := firstNonEmpty(query.Get("name"), r.Header.Get("X-File-Name"))
	if name == "" {
		return tools.SaveFileDecodedInput{}, errors.New("raw upload requires a file name via query param \"name\" or X-File-Name header")
	}

	content, err := io.ReadAll(r.Body)
	if err != nil {
		return tools.SaveFileDecodedInput{}, err
	}

	return tools.SaveFileDecodedInput{
		Name:      name,
		Content:   content,
		MediaType: r.Header.Get("Content-Type"),
		Kind:      firstNonEmpty(query.Get("kind"), r.Header.Get("X-File-Kind")),
		ThoughtID: firstNonEmpty(query.Get("thought_id"), r.Header.Get("X-Thought-Id")),
		Project:   firstNonEmpty(query.Get("project"), r.Header.Get("X-Project")),
	}, nil
}
|
|
|
|
// firstNonEmpty scans values in order and returns the first one that is not
// blank once surrounding whitespace is removed, in its trimmed form. It
// returns "" when no value qualifies (including when called with no values).
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|