Commit 4e771771 authored by Ilham Maulana's avatar Ilham Maulana 💻

fix: token authentication and user registration

parent 230799c6
from django.contrib.auth.models import User
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.authtoken.models import Token
class IsStaffUser(IsAuthenticated): class IsStaffUser(IsAuthenticated):
def has_permission(self, request, view): def has_permission(self, request, view):
refresh_token = request.session.get("refresh_token") header = request.headers.get("Authorization")
token = header.split(" ")[1]
verified_token = None
try:
verified_token = Token.objects.get(key=token)
except Token.DoesNotExist:
return False
return bool( return bool(
refresh_token is not None header is not None
and request.user and verified_token.exists()
and request.user.is_authenticated and verified_token.user.is_staff
and request.user.is_staff
) )
class IsNotStaffUser(IsAuthenticated): class IsNotStaffUser(IsAuthenticated):
def has_permission(self, request, view): def has_permission(self, request, view):
refresh_token = request.session.get("refresh_token") header = request.headers.get("Authorization")
token = header.split(" ")[1]
verified_token = None
try:
verified_token = Token.objects.get(key=token)
except Token.DoesNotExist:
return False
return bool( return bool(
refresh_token is not None header is not None
and request.user and verified_token is not None
and request.user.is_authenticated and not verified_token.user.is_staff
and not request.user.is_staff
) )
...@@ -11,11 +11,20 @@ class UserSerializer(serializers.ModelSerializer): ...@@ -11,11 +11,20 @@ class UserSerializer(serializers.ModelSerializer):
"id", "id",
"username", "username",
"email", "email",
"password",
"first_name", "first_name",
"last_name", "last_name",
"is_staff", "is_staff",
] ]
extra_kwargs = {"password": {"write_only": True}}
def create(self, validated_data):
password = validated_data.get("password")
print(validated_data)
user = User.objects.create(**validated_data)
user.set_password(password)
user.save()
return user
def update(self, instance, validated_data): def update(self, instance, validated_data):
partial = validated_data.get("is_partial", False) partial = validated_data.get("is_partial", False)
......
import json import json
from django.contrib.auth import authenticate from django.contrib.auth import authenticate, login, logout
from django.contrib.auth.tokens import default_token_generator from django.contrib.auth.tokens import default_token_generator
from django.core.mail import send_mail from django.core.mail import send_mail
...@@ -7,8 +7,7 @@ from django.core.mail import send_mail ...@@ -7,8 +7,7 @@ from django.core.mail import send_mail
from rest_framework import views, viewsets, status from rest_framework import views, viewsets, status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.filters import SearchFilter from rest_framework.filters import SearchFilter
from rest_framework_simplejwt.views import TokenObtainPairView from rest_framework.authtoken.models import Token
from rest_framework_simplejwt.tokens import AccessToken, TokenError
from .serializers import ( from .serializers import (
User, User,
...@@ -18,7 +17,6 @@ from .serializers import ( ...@@ -18,7 +17,6 @@ from .serializers import (
LoginHistorySerializer, LoginHistorySerializer,
Member, Member,
MemberSerializer, MemberSerializer,
TokenSerializer,
User, User,
UserSerializer, UserSerializer,
) )
...@@ -85,16 +83,15 @@ class UserDetailView(views.APIView): ...@@ -85,16 +83,15 @@ class UserDetailView(views.APIView):
{"message": "Unauthorized"}, status=status.HTTP_401_UNAUTHORIZED {"message": "Unauthorized"}, status=status.HTTP_401_UNAUTHORIZED
) )
try: token = header.split(" ")[1]
token = header.split(" ")[1] verified_token = Token.objects.filter(key=token)
verified_token = AccessToken(token=token) if not verified_token.exists():
except TokenError:
return Response( return Response(
{"message": "Token is invalid or expired"}, {"message": "Token is invalid or expired"},
status=status.HTTP_401_UNAUTHORIZED, status=status.HTTP_401_UNAUTHORIZED,
) )
user_id = verified_token.payload.get("user_id") user_id = verified_token[0].user.id
user = User.objects.get(pk=user_id) user = User.objects.get(pk=user_id)
data = { data = {
"id": user.pk, "id": user.pk,
...@@ -108,25 +105,18 @@ class UserDetailView(views.APIView): ...@@ -108,25 +105,18 @@ class UserDetailView(views.APIView):
return Response(data, status=status.HTTP_200_OK) return Response(data, status=status.HTTP_200_OK)
class LoginBaseView(TokenObtainPairView): class LoginBaseView(views.APIView):
serializer_class = TokenSerializer
user = None
def post(self, request, *args, **kwargs): def post(self, request):
response = super().post(request, *args, **kwargs) user = authenticate(
username = request.data.get("username") username=request.data["username"], password=request.data["password"]
password = request.data.get("password") )
user = authenticate(username=username, password=password) if user:
login(request=request, user=user)
if user is None: token, created = Token.objects.get_or_create(user=user)
return Response( return Response({"token": token.key})
{"message": "Invalid username or password"}, else:
status=status.HTTP_403_FORBIDDEN, return Response({"error": "Invalid credentials"}, status=401)
)
self.user = user
request.session["refresh_token"] = response.data.get("refresh")
return response
class LibrarianLoginView(LoginBaseView): class LibrarianLoginView(LoginBaseView):
...@@ -135,7 +125,7 @@ class LibrarianLoginView(LoginBaseView): ...@@ -135,7 +125,7 @@ class LibrarianLoginView(LoginBaseView):
response = super().post(request, *args, **kwargs) response = super().post(request, *args, **kwargs)
if response.status_code == 200: if response.status_code == 200:
if not self.user.is_staff: if not request.user.is_staff:
return Response( return Response(
{"message": "Account does not have access"}, {"message": "Account does not have access"},
status=status.HTTP_403_FORBIDDEN, status=status.HTTP_403_FORBIDDEN,
...@@ -151,7 +141,6 @@ class RegisterBaseView(views.APIView): ...@@ -151,7 +141,6 @@ class RegisterBaseView(views.APIView):
def post(self, request): def post(self, request):
data = request.data data = request.data
data["message"] = "Register as librarian success"
serializer = self.serializer_class(data=data) serializer = self.serializer_class(data=data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save() serializer.save()
...@@ -160,11 +149,19 @@ class RegisterBaseView(views.APIView): ...@@ -160,11 +149,19 @@ class RegisterBaseView(views.APIView):
class LibrarianRegisterView(RegisterBaseView): class LibrarianRegisterView(RegisterBaseView):
serializer_class = LibrarianSerializer serializer_class = UserSerializer
class MemberRegisterView(RegisterBaseView): class MemberRegisterView(RegisterBaseView):
serializer_class = MemberSerializer serializer_class = UserSerializer
def post(self, request):
response = super().post(request)
user_id = response.data.get("id")
user = User.objects.get(pk=user_id)
Member.objects.create(user=user)
return response
class MemberLoginView(LoginBaseView): class MemberLoginView(LoginBaseView):
...@@ -173,7 +170,7 @@ class MemberLoginView(LoginBaseView): ...@@ -173,7 +170,7 @@ class MemberLoginView(LoginBaseView):
response = super().post(request, *args, **kwargs) response = super().post(request, *args, **kwargs)
if response.status_code == 200: if response.status_code == 200:
if self.user.is_staff: if request.user.is_staff:
return Response( return Response(
{"message": "Account does not have access"}, {"message": "Account does not have access"},
status=status.HTTP_403_FORBIDDEN, status=status.HTTP_403_FORBIDDEN,
...@@ -187,14 +184,22 @@ class MemberLoginView(LoginBaseView): ...@@ -187,14 +184,22 @@ class MemberLoginView(LoginBaseView):
class LogoutView(views.APIView): class LogoutView(views.APIView):
def get(self, request): def get(self, request):
refresh = request.session.get("refresh_token") header = request.headers.get("Authorization")
if refresh is None: if header is None:
return Response( return Response(
{"detail": "You do not have permission to perform this action."}, {"message": "Unauthorized"}, status=status.HTTP_401_UNAUTHORIZED
status=status.HTTP_403_FORBIDDEN, )
token = header.split(" ")[1]
verified_token = Token.objects.filter(key=token)
if not verified_token.exists():
return Response(
{"message": "Token is invalid or expired"},
status=status.HTTP_401_UNAUTHORIZED,
) )
del request.session["refresh_token"] verified_token.delete()
logout(request=request)
return Response({"message": "Logout success"}, status=status.HTTP_200_OK) return Response({"message": "Logout success"}, status=status.HTTP_200_OK)
......
...@@ -45,6 +45,7 @@ INSTALLED_APPS = [ ...@@ -45,6 +45,7 @@ INSTALLED_APPS = [
"dashboard.apps.DashboardConfig", "dashboard.apps.DashboardConfig",
# 3rd party # 3rd party
"rest_framework", "rest_framework",
"rest_framework.authtoken",
"rest_framework_simplejwt", "rest_framework_simplejwt",
"django_filters", "django_filters",
] ]
...@@ -55,7 +56,7 @@ INSTALLED_APPS = [ ...@@ -55,7 +56,7 @@ INSTALLED_APPS = [
REST_FRAMEWORK = { REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": [ "DEFAULT_AUTHENTICATION_CLASSES": [
"rest_framework_simplejwt.authentication.JWTAuthentication", "rest_framework.authentication.TokenAuthentication",
], ],
"DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"],
"DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination",
......
...@@ -23,9 +23,6 @@ from django.contrib.auth.views import ( ...@@ -23,9 +23,6 @@ from django.contrib.auth.views import (
PasswordResetConfirmView, PasswordResetConfirmView,
PasswordResetCompleteView, PasswordResetCompleteView,
) )
from rest_framework_simplejwt.views import (
TokenRefreshView,
)
from dashboard.views import UpcomingLoanView, OverduedLoanView from dashboard.views import UpcomingLoanView, OverduedLoanView
...@@ -54,7 +51,6 @@ urlpatterns = [ ...@@ -54,7 +51,6 @@ urlpatterns = [
name="password_reset_complete", name="password_reset_complete",
), ),
# api # api
path("api/v1/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
path("api/v1/", include("api.urls"), name="API_V1"), path("api/v1/", include("api.urls"), name="API_V1"),
# 3rd party # 3rd party
path("api-auth/", include("rest_framework.urls"), name="api_auth"), path("api-auth/", include("rest_framework.urls"), name="api_auth"),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment