Add missing Authorization header on MAL refresh token request

* Add missing Authorization header on MAL refresh token request.

* Make sure to also close the response when it have failed.

Co-Authored-By: arkon <4098258+arkon@users.noreply.github.com>
This commit is contained in:
Jays2Kings 2022-08-21 01:51:51 -04:00
parent d64932a504
commit 41c085210c
3 changed files with 38 additions and 27 deletions

View file

@ -23,6 +23,7 @@ import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.long
import okhttp3.FormBody
import okhttp3.Headers
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody
@ -280,13 +281,21 @@ class MyAnimeListApi(private val client: OkHttpClient, interceptor: MyAnimeListI
.appendPath("my_list_status")
.build()
fun refreshTokenRequest(refreshToken: String): Request {
fun refreshTokenRequest(oauth: OAuth): Request {
val formBody: RequestBody = FormBody.Builder()
.add("client_id", clientId)
.add("refresh_token", refreshToken)
.add("refresh_token", oauth.refresh_token)
.add("grant_type", "refresh_token")
.build()
return POST("$baseOAuthUrl/token", body = formBody)
// Add the Authorization header manually as this particular
// request is called by the interceptor itself so it doesn't reach
// the part where the token is added automatically.
val headers = Headers.Builder()
.add("Authorization", "Bearer ${oauth.access_token}")
.build()
return POST("$baseOAuthUrl/token", body = formBody, headers = headers)
}
private fun getPkceChallengeCode(): String {

View file

@ -1,49 +1,50 @@
package eu.kanade.tachiyomi.data.track.myanimelist
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import eu.kanade.tachiyomi.network.parseAs
import okhttp3.Interceptor
import okhttp3.Response
import uy.kohesive.injekt.injectLazy
import okhttp3.internal.closeQuietly
import java.io.IOException
class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var token: String?) : Interceptor {
val scope = CoroutineScope(Job() + Dispatchers.Main)
private val json: Json by injectLazy()
private var oauth: OAuth? = null
set(value) {
field = value?.copy(expires_in = System.currentTimeMillis() + (value.expires_in * 1000))
}
override fun intercept(chain: Interceptor.Chain): Response {
val originalRequest = chain.request()
if (token.isNullOrEmpty()) {
throw Exception("Not authenticated with MyAnimeList")
throw IOException("Not authenticated with MyAnimeList")
}
if (oauth == null) {
oauth = myanimelist.loadOAuth()
}
// Refresh access token if null or expired.
if (oauth!!.isExpired()) {
chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!.refresh_token)).use {
if (it.isSuccessful) {
setAuth(json.decodeFromString(it.body!!.string()))
// Refresh access token if expired or created_at is freshly set
if (oauth != null &&
(oauth!!.isExpired() || oauth!!.created_at == System.currentTimeMillis())
) {
val newOauth = runCatching {
val oauthResponse = chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!))
if (oauthResponse.isSuccessful) {
oauthResponse.parseAs<OAuth>()
} else {
oauthResponse.closeQuietly()
null
}
}
}
// Throw on null auth.
if (newOauth.getOrNull() == null) {
throw IOException("Failed to refresh the access token")
}
setAuth(newOauth.getOrNull())
}
if (oauth == null) {
throw Exception("No authentication token")
throw IOException("No authentication token")
}
// Add the authorization header to the original request.
// Add the authorization header to the original request
val authRequest = originalRequest.newBuilder()
.addHeader("Authorization", "Bearer ${oauth!!.access_token}")
.build()

View file

@ -7,8 +7,9 @@ data class OAuth(
val refresh_token: String,
val access_token: String,
val token_type: String,
val created_at: Long = System.currentTimeMillis(),
val expires_in: Long,
) {
fun isExpired() = System.currentTimeMillis() > expires_in
fun isExpired() = System.currentTimeMillis() > created_at + (expires_in * 1000)
}