462 lines
18 KiB
Swift

//
// Client.swift
// Pachyderm
//
// Created by Shadowfacts on 9/8/18.
// Copyright © 2018 Shadowfacts. All rights reserved.
//
import Foundation
/**
 The base Mastodon API client.

 A `Client` is bound to one instance (`baseURL`) and, optionally, to one set of
 OAuth credentials. The instance methods perform network requests; the static
 methods only build `Request` values, which are executed with `run(_:completion:)`.
 */
public class Client {

    /// Completion handler invoked with the outcome of a request.
    public typealias Callback<Result: Decodable> = (Response<Result>) -> Void

    /// Base URL of the instance. Only its scheme/host/port are effectively used,
    /// because the path is replaced by the endpoint path of each request.
    let baseURL: URL
    /// Session used to create the data tasks.
    let session: URLSession

    /// Sent as a `Bearer` token in the `Authorization` header when non-nil.
    public var accessToken: String?
    /// Populated by `registerApp` on success.
    public var appID: String?
    /// Populated by `registerApp` on success; sent by `getAccessToken`.
    public var clientID: String?
    /// Populated by `registerApp` on success; sent by `getAccessToken`.
    public var clientSecret: String?

    /// Timeout applied to every `URLRequest` created by this client, in seconds.
    public var timeoutInterval: TimeInterval = 60

    /// Creates the fixed-format timestamp formatter shared by `decoder` and `encoder`.
    private static func makeTimestampFormatter() -> DateFormatter {
        let formatter = DateFormatter()
        formatter.dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SZ"
        formatter.timeZone = TimeZone(abbreviation: "UTC")
        formatter.locale = Locale(identifier: "en_US_POSIX")
        return formatter
    }

    /// Returns the query parameters carrying an optional `limit`; empty when `limit` is nil.
    private static func limitParameters(_ limit: Int?) -> [Parameter] {
        guard let limit = limit else { return [] }
        return ["limit" => limit]
    }

    /// JSON decoder used for every response body. Dates are decoded from strings,
    /// first with the fixed timestamp format and then, as a fallback, with
    /// `ISO8601DateFormatter`'s default options.
    static let decoder: JSONDecoder = {
        let jsonDecoder = JSONDecoder()
        let formatter = Client.makeTimestampFormatter()
        let iso8601 = ISO8601DateFormatter()
        jsonDecoder.dateDecodingStrategy = .custom({ (decoder) in
            let container = try decoder.singleValueContainer()
            let str = try container.decode(String.self)
            // for the next time mastodon accidentally changes date formats >.>
            if let date = formatter.date(from: str) ?? iso8601.date(from: str) {
                return date
            }
            throw DecodingError.typeMismatch(Date.self, .init(codingPath: container.codingPath, debugDescription: "unexpected date format"))
        })
        return jsonDecoder
    }()

    /// JSON encoder that writes dates using the fixed timestamp format.
    static let encoder: JSONEncoder = {
        let jsonEncoder = JSONEncoder()
        jsonEncoder.dateEncodingStrategy = .formatted(Client.makeTimestampFormatter())
        return jsonEncoder
    }()

    /// Creates a client for the instance at `baseURL`.
    /// - Parameters:
    ///   - baseURL: URL of the instance.
    ///   - accessToken: Optional OAuth access token to authenticate requests with.
    ///   - session: Session used for network requests. Defaults to `URLSession.shared`.
    public init(baseURL: URL, accessToken: String? = nil, session: URLSession = .shared) {
        self.baseURL = baseURL
        self.accessToken = accessToken
        self.session = session
    }

    /// Executes `request` and reports the decoded result through `completion`.
    ///
    /// `completion` is called exactly once: synchronously with `.invalidRequest`
    /// when no `URLRequest` could be built, otherwise from the data task's
    /// completion handler. Any status code other than 200 is reported as a failure.
    ///
    /// - Returns: The already-resumed task, or `nil` when the request could not be built.
    @discardableResult
    public func run<Result>(_ request: Request<Result>, completion: @escaping Callback<Result>) -> URLSessionTask? {
        guard let urlRequest = createURLRequest(request: request) else {
            completion(.failure(Error(request: request, type: .invalidRequest)))
            return nil
        }

        let task = session.dataTask(with: urlRequest) { data, response, error in
            if let error = error {
                completion(.failure(Error(request: request, type: .networkError(error))))
                return
            }
            guard let data = data,
                  let response = response as? HTTPURLResponse else {
                completion(.failure(Error(request: request, type: .invalidResponse)))
                return
            }
            guard response.statusCode == 200 else {
                // Prefer the error message supplied in the body; fall back to the bare status code.
                let type: ErrorType
                if let mastodonError = try? Client.decoder.decode(MastodonError.self, from: data) {
                    type = .mastodonError(mastodonError.description)
                } else {
                    type = .unexpectedStatus(response.statusCode)
                }
                completion(.failure(Error(request: request, type: type)))
                return
            }

            let result: Result
            do {
                result = try Client.decoder.decode(Result.self, from: data)
            } catch {
                completion(.failure(Error(request: request, type: .invalidModel(error))))
                return
            }

            // Pagination is optional: it is nil when there is no usable Link header.
            let linkHeader = response.allHeaderFields["Link"] as? String
            let pagination = linkHeader.flatMap(Pagination.init)
            completion(.success(result, pagination))
        }
        task.resume()
        return task
    }

    /// Builds the `URLRequest` for `request` against `baseURL`.
    ///
    /// The path of `baseURL` is replaced by the endpoint path. Returns `nil` when
    /// the URL components cannot be created or do not form a valid URL.
    func createURLRequest<Result>(request: Request<Result>) -> URLRequest? {
        guard var components = URLComponents(url: baseURL, resolvingAgainstBaseURL: true) else {
            return nil
        }
        components.path = request.endpoint.path
        // An empty item list would still append a bare "?" to the URL, so use nil instead.
        if request.queryParameters.isEmpty {
            components.queryItems = nil
        } else {
            components.queryItems = request.queryParameters.queryItems
        }
        guard let url = components.url else {
            return nil
        }

        var urlRequest = URLRequest(url: url, timeoutInterval: timeoutInterval)
        urlRequest.httpMethod = request.method.name
        urlRequest.httpBody = request.body.data
        if let mimeType = request.body.mimeType {
            urlRequest.setValue(mimeType, forHTTPHeaderField: "Content-Type")
        }
        if let accessToken = accessToken {
            urlRequest.setValue("Bearer \(accessToken)", forHTTPHeaderField: "Authorization")
        }
        return urlRequest
    }

    // MARK: - Authorization

    /// Sends `POST /api/v1/apps`. On success, `appID`, `clientID` and
    /// `clientSecret` are stored on this client before `completion` is called.
    public func registerApp(name: String, redirectURI: String, scopes: [Scope], website: URL? = nil, completion: @escaping Callback<RegisteredApplication>) {
        let request = Request<RegisteredApplication>(method: .post, path: "/api/v1/apps", body: ParametersBody([
            "client_name" => name,
            "redirect_uris" => redirectURI,
            "scopes" => scopes.scopeString,
            "website" => website?.absoluteString
        ]))
        run(request) { result in
            if case let .success(application, _) = result {
                self.appID = application.id
                self.clientID = application.clientID
                self.clientSecret = application.clientSecret
            }
            completion(result)
        }
    }

    /// Sends `POST /oauth/token` with the `authorization_code` grant, using the
    /// current `clientID`/`clientSecret`. On success, `accessToken` is stored on
    /// this client before `completion` is called.
    public func getAccessToken(authorizationCode: String, redirectURI: String, completion: @escaping Callback<LoginSettings>) {
        let request = Request<LoginSettings>(method: .post, path: "/oauth/token", body: ParametersBody([
            "client_id" => clientID,
            "client_secret" => clientSecret,
            "grant_type" => "authorization_code",
            "code" => authorizationCode,
            "redirect_uri" => redirectURI
        ]))
        run(request) { result in
            if case let .success(loginSettings, _) = result {
                self.accessToken = loginSettings.accessToken
            }
            completion(result)
        }
    }

    /// Fetches `/.well-known/nodeinfo` and then the schema 2.0 node info document it links to.
    ///
    /// The linked document is only requested when its host matches the host of
    /// `baseURL`. If no such link exists, `completion` receives an
    /// `.invalidResponse` failure, so it is always called.
    public func nodeInfo(completion: @escaping Callback<NodeInfo>) {
        let wellKnownRequest = Request<WellKnown>(method: .get, path: "/.well-known/nodeinfo")
        run(wellKnownRequest) { result in
            switch result {
            case let .failure(error):
                completion(.failure(error))

            case let .success(wellKnown, _):
                guard let link = wellKnown.links.first(where: { $0.rel == "http://nodeinfo.diaspora.software/ns/schema/2.0" }),
                      let components = URLComponents(string: link.href),
                      components.host == self.baseURL.host else {
                    // Without this the caller would never hear back.
                    completion(.failure(Error(request: wellKnownRequest, type: .invalidResponse)))
                    return
                }
                let nodeInfo = Request<NodeInfo>(method: .get, path: Endpoint(stringLiteral: components.path))
                self.run(nodeInfo, completion: completion)
            }
        }
    }

    // MARK: - Self

    /// Request for `GET /api/v1/accounts/verify_credentials`.
    public static func getSelfAccount() -> Request<Account> {
        return Request<Account>(method: .get, path: "/api/v1/accounts/verify_credentials")
    }

    /// Request for `GET /api/v1/favourites`.
    public static func getFavourites() -> Request<[Status]> {
        return Request<[Status]>(method: .get, path: "/api/v1/favourites")
    }

    /// Request for `GET /api/v1/accounts/relationships`, passing `accounts` as the `id` query parameter.
    public static func getRelationships(accounts: [String]? = nil) -> Request<[Relationship]> {
        return Request<[Relationship]>(method: .get, path: "/api/v1/accounts/relationships", queryParameters: "id" => accounts)
    }

    /// Request for `GET /api/v1/instance`.
    public static func getInstance() -> Request<Instance> {
        return Request<Instance>(method: .get, path: "/api/v1/instance")
    }

    /// Request for `GET /api/v1/custom_emojis`.
    public static func getCustomEmoji() -> Request<[Emoji]> {
        return Request<[Emoji]>(method: .get, path: "/api/v1/custom_emojis")
    }

    // MARK: - Accounts

    /// Request for `GET /api/v1/accounts/{id}`.
    public static func getAccount(id: String) -> Request<Account> {
        return Request<Account>(method: .get, path: "/api/v1/accounts/\(id)")
    }

    /// Request for `GET /api/v1/accounts/search` with the `q`, `limit` and `following` query parameters.
    public static func searchForAccount(query: String, limit: Int? = nil, following: Bool? = nil) -> Request<[Account]> {
        return Request<[Account]>(method: .get, path: "/api/v1/accounts/search", queryParameters: [
            "q" => query,
            "limit" => limit,
            "following" => following
        ])
    }

    // MARK: - Blocks

    /// Request for `GET /api/v1/blocks`.
    public static func getBlocks() -> Request<[Account]> {
        return Request<[Account]>(method: .get, path: "/api/v1/blocks")
    }

    /// Request for `GET /api/v1/domain_blocks`.
    public static func getDomainBlocks() -> Request<[String]> {
        // The leading slash is required: URLComponents cannot produce a URL from a
        // host combined with a relative path, which made this request always fail.
        return Request<[String]>(method: .get, path: "/api/v1/domain_blocks")
    }

    /// Request for `POST /api/v1/domain_blocks` with `domain` in the body.
    public static func block(domain: String) -> Request<Empty> {
        return Request<Empty>(method: .post, path: "/api/v1/domain_blocks", body: ParametersBody([
            "domain" => domain
        ]))
    }

    /// Request for `DELETE /api/v1/domain_blocks` with `domain` in the body.
    public static func unblock(domain: String) -> Request<Empty> {
        return Request<Empty>(method: .delete, path: "/api/v1/domain_blocks", body: ParametersBody([
            "domain" => domain
        ]))
    }

    // MARK: - Filters

    /// Request for `GET /api/v1/filters`.
    public static func getFilters() -> Request<[Filter]> {
        return Request<[Filter]>(method: .get, path: "/api/v1/filters")
    }

    /// Request for `POST /api/v1/filters` with the given phrase, contexts and options in the body.
    public static func createFilter(phrase: String, context: [Filter.Context], irreversible: Bool? = nil, wholeWord: Bool? = nil, expiresAt: Date? = nil) -> Request<Filter> {
        return Request<Filter>(method: .post, path: "/api/v1/filters", body: ParametersBody([
            "phrase" => phrase,
            "irreversible" => irreversible,
            "whole_word" => wholeWord,
            "expires_at" => expiresAt
        ] + "context" => context.contextStrings))
    }

    /// Request for `GET /api/v1/filters/{id}`.
    public static func getFilter(id: String) -> Request<Filter> {
        return Request<Filter>(method: .get, path: "/api/v1/filters/\(id)")
    }

    // MARK: - Follows

    /// Request for `GET /api/v1/follow_requests`, limited to `range`.
    public static func getFollowRequests(range: RequestRange = .default) -> Request<[Account]> {
        var request = Request<[Account]>(method: .get, path: "/api/v1/follow_requests")
        request.range = range
        return request
    }

    /// Request for `GET /api/v1/suggestions`.
    public static func getFollowSuggestions() -> Request<[Account]> {
        return Request<[Account]>(method: .get, path: "/api/v1/suggestions")
    }

    /// Request for `POST /api/v1/follows` with `acct` sent as the `uri` body parameter.
    public static func followRemote(acct: String) -> Request<Account> {
        return Request<Account>(method: .post, path: "/api/v1/follows", body: ParametersBody(["uri" => acct]))
    }

    // MARK: - Lists

    /// Request for `GET /api/v1/lists`.
    public static func getLists() -> Request<[List]> {
        return Request<[List]>(method: .get, path: "/api/v1/lists")
    }

    /// Request for `GET /api/v1/lists/{id}`.
    public static func getList(id: String) -> Request<List> {
        return Request<List>(method: .get, path: "/api/v1/lists/\(id)")
    }

    /// Request for `POST /api/v1/lists` with `title` in the body.
    public static func createList(title: String) -> Request<List> {
        return Request<List>(method: .post, path: "/api/v1/lists", body: ParametersBody(["title" => title]))
    }

    // MARK: - Media

    /// Request for `POST /api/v1/media`, sending `attachment` as form data together
    /// with the optional `description` and `focus` parameters.
    public static func upload(attachment: FormAttachment, description: String? = nil, focus: (Float, Float)? = nil) -> Request<Attachment> {
        return Request<Attachment>(method: .post, path: "/api/v1/media", body: FormDataBody([
            "description" => description,
            "focus" => focus
        ], attachment))
    }

    // MARK: - Mutes

    /// Request for `GET /api/v1/mutes`, limited to `range`.
    public static func getMutes(range: RequestRange = .default) -> Request<[Account]> {
        var request = Request<[Account]>(method: .get, path: "/api/v1/mutes")
        request.range = range
        return request
    }

    // MARK: - Notifications

    /// Request for `GET /api/v1/notifications`, limited to `range`, passing the raw
    /// values of `excludeTypes` as the `exclude_types` query parameter.
    public static func getNotifications(excludeTypes: [Notification.Kind], range: RequestRange = .default) -> Request<[Notification]> {
        var request = Request<[Notification]>(method: .get, path: "/api/v1/notifications", queryParameters:
            "exclude_types" => excludeTypes.map { $0.rawValue }
        )
        request.range = range
        return request
    }

    /// Request for `POST /api/v1/notifications/clear`.
    public static func clearNotifications() -> Request<Empty> {
        return Request<Empty>(method: .post, path: "/api/v1/notifications/clear")
    }

    // MARK: - Reports

    /// Request for `GET /api/v1/reports`.
    public static func getReports() -> Request<[Report]> {
        return Request<[Report]>(method: .get, path: "/api/v1/reports")
    }

    /// Request for `POST /api/v1/reports`, sending the account ID, the IDs of `statuses` and `comment`.
    public static func report(account: Account, statuses: [Status], comment: String) -> Request<Report> {
        return Request<Report>(method: .post, path: "/api/v1/reports", body: ParametersBody([
            "account_id" => account.id,
            "comment" => comment
        ] + "status_ids" => statuses.map { $0.id }))
    }

    // MARK: - Search

    /// Request for `GET /api/v2/search` with the `q`, `resolve`, `limit` and `types` query parameters.
    public static func search(query: String, types: [SearchResultType]? = nil, resolve: Bool? = nil, limit: Int? = nil) -> Request<SearchResults> {
        return Request<SearchResults>(method: .get, path: "/api/v2/search", queryParameters: [
            "q" => query,
            "resolve" => resolve,
            "limit" => limit,
        ] + "types" => types?.map { $0.rawValue })
    }

    // MARK: - Statuses

    /// Request for `GET /api/v1/statuses/{id}`.
    public static func getStatus(id: String) -> Request<Status> {
        return Request<Status>(method: .get, path: "/api/v1/statuses/\(id)")
    }

    /// Request for `POST /api/v1/statuses`.
    ///
    /// `media` is sent as the IDs of the given attachments (`media_ids`), and the
    /// poll arguments are sent as `poll[options]`, `poll[expires_in]` and `poll[multiple]`.
    public static func createStatus(text: String,
                                    contentType: StatusContentType = .plain,
                                    inReplyTo: String? = nil,
                                    media: [Attachment]? = nil,
                                    sensitive: Bool? = nil,
                                    spoilerText: String? = nil,
                                    visibility: Status.Visibility? = nil,
                                    language: String? = nil,
                                    pollOptions: [String]? = nil,
                                    pollExpiresIn: Int? = nil,
                                    pollMultiple: Bool? = nil,
                                    localOnly: Bool? = nil /* hometown only, not glitch */) -> Request<Status> {
        return Request<Status>(method: .post, path: "/api/v1/statuses", body: ParametersBody([
            "status" => text,
            "content_type" => contentType.mimeType,
            "in_reply_to_id" => inReplyTo,
            "sensitive" => sensitive,
            "spoiler_text" => spoilerText,
            "visibility" => visibility?.rawValue,
            "language" => language,
            "poll[expires_in]" => pollExpiresIn,
            "poll[multiple]" => pollMultiple,
            "local_only" => localOnly,
        ] + "media_ids" => media?.map { $0.id } + "poll[options]" => pollOptions))
    }

    // MARK: - Timelines

    /// Request for the statuses of `timeline`, as built by `Timeline.request(range:)`.
    public static func getStatuses(timeline: Timeline, range: RequestRange = .default) -> Request<[Status]> {
        return timeline.request(range: range)
    }

    // MARK: - Bookmarks

    /// Request for `GET /api/v1/bookmarks`, limited to `range`.
    public static func getBookmarks(range: RequestRange = .default) -> Request<[Status]> {
        var request = Request<[Status]>(method: .get, path: "/api/v1/bookmarks")
        request.range = range
        return request
    }

    // MARK: - Instance

    /// Request for `GET /api/v1/trends`; `limit` is only sent when non-nil.
    public static func getTrendingHashtags(limit: Int? = nil) -> Request<[Hashtag]> {
        return Request<[Hashtag]>(method: .get, path: "/api/v1/trends", queryParameters: limitParameters(limit))
    }

    /// Request for `GET /api/v1/trends/statuses`; `limit` is only sent when non-nil.
    public static func getTrendingStatuses(limit: Int? = nil) -> Request<[Status]> {
        return Request(method: .get, path: "/api/v1/trends/statuses", queryParameters: limitParameters(limit))
    }

    /// Request for `GET /api/v1/trends/links`; `limit` is only sent when non-nil.
    public static func getTrendingLinks(limit: Int? = nil) -> Request<[Card]> {
        return Request(method: .get, path: "/api/v1/trends/links", queryParameters: limitParameters(limit))
    }

    /// Request for `GET /api/v1/directory` with the `order` and `local` query
    /// parameters; `offset` and `limit` are only sent when non-nil.
    public static func getFeaturedProfiles(local: Bool, order: DirectoryOrder, offset: Int? = nil, limit: Int? = nil) -> Request<[Account]> {
        var parameters = [
            "order" => order.rawValue,
            "local" => local,
        ]
        if let offset = offset {
            parameters.append("offset" => offset)
        }
        if let limit = limit {
            parameters.append("limit" => limit)
        }
        return Request<[Account]>(method: .get, path: "/api/v1/directory", queryParameters: parameters)
    }

}
extension Client {
    /// A failure produced while running a request, together with the method and
    /// endpoint of the request that failed.
    public struct Error: LocalizedError {
        /// HTTP method of the failed request.
        public let requestMethod: Method
        /// Endpoint of the failed request.
        public let requestEndpoint: Endpoint
        /// What went wrong.
        public let type: ErrorType

        #if DEBUG
        /// Placeholder error available in debug builds only.
        public static let debug = Error(request: Client.getStatuses(timeline: .home), type: .invalidResponse)
        #endif

        /// Captures the method and endpoint of `request` alongside the failure `type`.
        init<ResultType: Decodable>(request: Request<ResultType>, type: ErrorType) {
            self.requestMethod = request.method
            self.requestEndpoint = request.endpoint
            self.type = type
        }

        /// `LocalizedError` requirement. Without it, the description below is lost
        /// as soon as the value is handled as a plain `Swift.Error` (or bridged to
        /// `NSError`), because `localizedDescription` itself is not a protocol
        /// requirement and is therefore not dynamically dispatched.
        public var errorDescription: String? {
            return localizedDescription
        }

        /// Human-readable description of the failure.
        public var localizedDescription: String {
            switch type {
            case .networkError(let error):
                return "Network Error: \(error.localizedDescription)"
            // todo: support more status codes
            case .unexpectedStatus(413):
                return "HTTP 413: Payload Too Large"
            case .unexpectedStatus(let code):
                return "HTTP Code \(code)"
            case .invalidRequest:
                return "Invalid Request"
            case .invalidResponse:
                return "Invalid Response"
            case .invalidModel:
                return "Invalid Model"
            case .mastodonError(let error):
                return "Server Error: \(error)"
            }
        }
    }

    /// The ways a request can fail.
    public enum ErrorType: LocalizedError {
        /// The data task reported an error.
        case networkError(Swift.Error)
        /// The response status was not 200 and the body held no decodable error message.
        case unexpectedStatus(Int)
        /// No `URLRequest` could be built for the request.
        case invalidRequest
        /// The response carried no data or was not an HTTP response.
        case invalidResponse
        /// The response body could not be decoded into the expected type.
        case invalidModel(Swift.Error)
        /// The response status was not 200 and the body held this error message.
        case mastodonError(String)
    }
}