package app

import (
	"encoding/json"
	"errors"
	"io"
	"mime"
	"net/http"
	"strings"

	"git.warky.dev/wdevs/amcs/internal/tools"
)

const (
	// maxUploadBytes caps the size of an upload request body (100 MiB).
	maxUploadBytes = 100 << 20
	// multipartFormMemory is the maxMemory argument for
	// (*http.Request).ParseMultipartForm (32 MiB). Per the net/http docs,
	// file-part bytes beyond this are stored in temporary files on disk.
	multipartFormMemory = 32 << 20
)

// fileUploadHandler returns an HTTP handler that accepts a single file upload
// via POST and stores it with files.SaveDecoded.
//
// Two request shapes are accepted (see parseUploadRequest):
//   - multipart/form-data with the file in a form field named "file";
//   - any other Content-Type, where the raw request body is the file content.
//
// The request body is limited to maxUploadBytes. Responses:
//   - 201 Created with the JSON-encoded SaveDecoded result on success;
//   - 405 Method Not Allowed (with an Allow header) for non-POST requests;
//   - 413 Request Entity Too Large when the body exceeds maxUploadBytes;
//   - 400 Bad Request for any other parse or save failure.
func fileUploadHandler(files *tools.FilesTool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			w.Header().Set("Allow", http.MethodPost)
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)

		in, err := parseUploadRequest(r)
		if err != nil {
			http.Error(w, err.Error(), uploadErrorStatus(err))
			return
		}

		out, err := files.SaveDecoded(r.Context(), nil, in)
		if err != nil {
			// NOTE(review): every SaveDecoded failure is reported as 400, which
			// may mislabel server-side failures; distinguish them if tools
			// exposes typed errors.
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		// The status line has already been sent, so an encoding or write
		// failure can no longer be reported to the client; drop it deliberately.
		_ = json.NewEncoder(w).Encode(out)
	})
}

// uploadErrorStatus maps an error from parseUploadRequest to an HTTP status:
// 413 when the body exceeded the http.MaxBytesReader limit, 400 otherwise.
// errors.As is used because multipart parsing may wrap the underlying error.
// (http.MaxBytesError requires Go 1.19+.)
func uploadErrorStatus(err error) int {
	var tooLarge *http.MaxBytesError
	if errors.As(err, &tooLarge) {
		return http.StatusRequestEntityTooLarge
	}
	return http.StatusBadRequest
}

// parseUploadRequest dispatches on the request's Content-Type. A
// multipart/form-data request is handled by parseMultipartUpload; anything
// else, including a missing Content-Type, is handled by parseRawUpload.
func parseUploadRequest(r *http.Request) (tools.SaveFileDecodedInput, error) {
	// ParseMediaType lowercases the type and strips parameters such as the
	// boundary, so an exact comparison suffices. Its error is ignored on
	// purpose: a header yielding no media type takes the raw-body path.
	mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	if mediaType == "multipart/form-data" {
		return parseMultipartUpload(r)
	}
	return parseRawUpload(r)
}

// parseMultipartUpload reads a multipart/form-data upload.
//
// The file content comes from the form field "file". The optional fields
// "name", "media_type", "kind", "thought_id" and "project" supply metadata.
// "name" falls back to the part's filename, and "media_type" falls back to
// the part's Content-Type header.
func parseMultipartUpload(r *http.Request) (tools.SaveFileDecodedInput, error) {
	if err := r.ParseMultipartForm(multipartFormMemory); err != nil {
		return tools.SaveFileDecodedInput{}, err
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		return tools.SaveFileDecodedInput{}, errors.New("multipart upload requires a file field named \"file\"")
	}
	defer file.Close()

	content, err := io.ReadAll(file)
	if err != nil {
		return tools.SaveFileDecodedInput{}, err
	}

	// r.FormValue consults URL query parameters before multipart form fields,
	// so a query parameter of the same name takes precedence.
	// NOTE(review): Name is client-controlled and not sanitized here; confirm
	// SaveDecoded rejects path separators if it touches the filesystem.
	return tools.SaveFileDecodedInput{
		Name:      firstNonEmpty(r.FormValue("name"), header.Filename),
		Content:   content,
		MediaType: firstNonEmpty(r.FormValue("media_type"), header.Header.Get("Content-Type")),
		Kind:      r.FormValue("kind"),
		ThoughtID: r.FormValue("thought_id"),
		Project:   r.FormValue("project"),
	}, nil
}

// parseRawUpload treats the entire request body as the file content.
//
// The file name is required. It comes from the "name" query parameter or the
// X-File-Name header. kind, thought_id and project are read from the query
// string first, then from the X-File-Kind, X-Thought-Id and X-Project headers.
// The media type is the request's Content-Type header, passed through verbatim.
func parseRawUpload(r *http.Request) (tools.SaveFileDecodedInput, error) {
	query := r.URL.Query()

	// Validate the name before reading the body, so a request without a name
	// is rejected without buffering a potentially large payload first.
	name := firstNonEmpty(query.Get("name"), r.Header.Get("X-File-Name"))
	if name == "" {
		return tools.SaveFileDecodedInput{}, errors.New("raw upload requires a file name via query param \"name\" or X-File-Name header")
	}

	content, err := io.ReadAll(r.Body)
	if err != nil {
		return tools.SaveFileDecodedInput{}, err
	}

	return tools.SaveFileDecodedInput{
		Name:      name,
		Content:   content,
		MediaType: r.Header.Get("Content-Type"),
		Kind:      firstNonEmpty(query.Get("kind"), r.Header.Get("X-File-Kind")),
		ThoughtID: firstNonEmpty(query.Get("thought_id"), r.Header.Get("X-Thought-Id")),
		Project:   firstNonEmpty(query.Get("project"), r.Header.Get("X-Project")),
	}, nil
}

// firstNonEmpty returns the first value that is non-blank after
// strings.TrimSpace, in its trimmed form. It returns "" if every value is
// blank or no values are given.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
	}
	return ""
}